Arcadia822 commited on
Commit
1543ec3
·
unverified ·
1 Parent(s): 1010b83

feat: :sparkles: Configuration (#22)

Browse files

* feat: :sparkles: Configuration

Now we can config prompt, gpt4, knowledge base in gradio. Hooray!

edu_assistant/learning_tasks/coding_problem.py CHANGED
@@ -7,7 +7,11 @@ from pydantic import BaseModel, Field
7
  from pydantic_redis import Model, Store
8
 
9
  from edu_assistant.learning_tasks.base import BaseTask
10
- from edu_assistant.utils.langchain_utils import escape_for_prompt, load_llm
 
 
 
 
11
  from edu_assistant.utils.redis_utils import get_redis_config
12
 
13
  TEMPLATE = """The following is a friendly conversation between a human and an ai.
@@ -29,10 +33,40 @@ Student's code:
29
  ---
30
 
31
  Current conversation:
32
- {{history}}
33
  Human: {{input}}
34
  AI:"""
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  DEFAULT_INSTRUCTION = """Act as a c++ professional to check student's code.
37
  The code is written by a student aged 5-10 and mostly like to buggy or bad performanced.
38
  """
@@ -59,14 +93,15 @@ class CodingProblem(Model):
59
  store.register_model(CodingProblem)
60
 
61
  def expr(self, lang=""):
62
- expr = f"## Question\n\n```\n{escape_for_prompt(self.question)}\n```\n\n"
63
  expr += (
64
- f"""## Standard Answer (There might be others)\n\n```{lang}\n{escape_for_prompt(self.standard_answer)}\n```
 
65
  """
66
  if self.standard_answer
67
  else ""
68
  )
69
- expr += f"## Analysis\n\n```\n{escape_for_prompt(self.analysis)}\n```\n\n" if self.analysis else ""
70
  expr += "## Extra\n\n" + escape_for_prompt("".join(self.extra)) + "\n"
71
  return expr
72
 
@@ -88,14 +123,35 @@ class CodingAnswer(BaseModel):
88
 
89
 
90
  class CodingProblemAnalysis(BaseTask):
91
- def __init__(self, instruction: str = DEFAULT_INSTRUCTION, lang: str = "", knowledge: BaseRetriever = None):
92
- assert lang in ["python", "cpp", "java", "javascript", "go", "c#", ""]
93
-
94
- self.lang = lang
 
 
 
 
 
 
95
  self.instruction = instruction
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  self._session_store = {}
97
  self._knowledge = knowledge
98
 
 
 
99
  @staticmethod
100
  def build_coding_problem(question: str, standard_answer: str = "", analysis: str = "", extra: list[str] = None):
101
  extra = [] if extra is None else extra
@@ -120,7 +176,13 @@ class CodingProblemAnalysis(BaseTask):
120
  session_id = self._create_session_id()
121
  self._session_store[session_id] = chain
122
 
123
- result = chain({"input": first_question if first_question else DEFAULT_FIRST_QUESTION, "history": ""})
 
 
 
 
 
 
124
 
125
  result["session_id"] = session_id
126
 
@@ -143,35 +205,52 @@ class CodingProblemAnalysis(BaseTask):
143
 
144
  chain = self._session_store[session_id]
145
 
146
- result = chain({"input": question})
 
 
 
 
 
147
 
148
  result["session_id"] = session_id
149
 
150
  return result
151
 
 
 
 
 
152
  def _build_chain(self, problem: CodingProblem, answer: CodingAnswer) -> Chain:
153
- llm = load_llm()
154
- memory = ConversationBufferMemory()
155
- prompt = PromptTemplate.from_template(
156
- TEMPLATE.format(
157
- instruction=self.instruction,
158
- problem=problem.expr(lang=problem.language or self.lang or ""),
159
- answer=answer.expr(lang=problem.language or self.lang or ""),
160
- )
161
  )
162
 
163
  if not self._knowledge:
 
 
 
 
 
 
 
164
  return ConversationChain(
165
- llm=llm,
166
  memory=memory,
167
  prompt=prompt,
168
  )
169
  else:
 
 
 
 
 
 
 
170
  return ConversationalRetrievalChain.from_llm(
171
- llm=llm,
172
  memory=memory,
173
  retriever=self._knowledge,
174
- condense_question_llm=llm,
175
  return_source_documents=True,
176
  combine_docs_chain_kwargs={"prompt": prompt},
177
  )
 
7
  from pydantic_redis import Model, Store
8
 
9
  from edu_assistant.learning_tasks.base import BaseTask
10
+ from edu_assistant.utils.langchain_utils import (
11
+ escape_for_prompt,
12
+ load_gpt4_llm,
13
+ load_llm,
14
+ )
15
  from edu_assistant.utils.redis_utils import get_redis_config
16
 
17
  TEMPLATE = """The following is a friendly conversation between a human and an ai.
 
33
  ---
34
 
35
  Current conversation:
36
+ {{chat_history}}
37
  Human: {{input}}
38
  AI:"""
39
 
40
+ KNOWLEDGE_TEMPLATE = """The following is a friendly conversation between a human and an ai.
41
+ The ai is talkative and provides lots of specific details from its context.
42
+ If the ai does not know the answer to a question, it truthfully says it does not know.
43
+ The ai act following below instructions:
44
+ ---
45
+ {instruction}
46
+ ---
47
+
48
+ The coding problem:
49
+ ---
50
+ {problem}
51
+ ---
52
+
53
+ Student's code:
54
+ ```
55
+ {answer}
56
+ ```
57
+
58
+ Extra Information might be helpful for you:
59
+ ---
60
+ {{context}}
61
+ ---
62
+
63
+ Current conversation:
64
+ {{chat_history}}
65
+ Human: {{question}}
66
+ AI:
67
+ """
68
+
69
+
70
  DEFAULT_INSTRUCTION = """Act as a c++ professional to check student's code.
71
  The code is written by a student aged 5-10 and mostly like to buggy or bad performanced.
72
  """
 
93
  store.register_model(CodingProblem)
94
 
95
  def expr(self, lang=""):
96
+ expr = f"## Question\n\n---\n{escape_for_prompt(self.question)}\n---\n\n"
97
  expr += (
98
+ f"""## Standard Answer (There might be others)\n\n```{lang if lang else self.language}
99
+ {escape_for_prompt(self.standard_answer)}\n```
100
  """
101
  if self.standard_answer
102
  else ""
103
  )
104
+ expr += f"## Analysis\n\n---\n{escape_for_prompt(self.analysis)}\n---\n\n" if self.analysis else ""
105
  expr += "## Extra\n\n" + escape_for_prompt("".join(self.extra)) + "\n"
106
  return expr
107
 
 
123
 
124
 
125
  class CodingProblemAnalysis(BaseTask):
126
+ HISTORY_KEY = "chat_history"
127
+
128
+ def __init__(
129
+ self,
130
+ instruction: str = DEFAULT_INSTRUCTION,
131
+ first_question: str = DEFAULT_FIRST_QUESTION,
132
+ lang: str = "",
133
+ knowledge: BaseRetriever = None,
134
+ enable_gpt4: bool = False,
135
+ ):
136
  self.instruction = instruction
137
+ self.first_question = first_question
138
+ self.lang = lang
139
+ self.enable_gpt4 = enable_gpt4
140
+ # TODO: load threshold key from implement. value from config
141
+ self.vectordbkwargs = {"score_threshold": 0.9} # Qdrant cosine. higher is better.
142
+
143
+ if knowledge:
144
+ self._input_key = "question"
145
+ self._output_key = "answer"
146
+ else:
147
+ self._input_key = "input"
148
+ self._output_key = "response"
149
+
150
  self._session_store = {}
151
  self._knowledge = knowledge
152
 
153
+ self._init_llm()
154
+
155
  @staticmethod
156
  def build_coding_problem(question: str, standard_answer: str = "", analysis: str = "", extra: list[str] = None):
157
  extra = [] if extra is None else extra
 
176
  session_id = self._create_session_id()
177
  self._session_store[session_id] = chain
178
 
179
+ args = {self._input_key: first_question if first_question else self.first_question, self.HISTORY_KEY: ""}
180
+
181
+ # TODO: ConversationalRetrievalChain should support vectordbkwargs
182
+ # if self._knowledge:
183
+ # args["vectordbkwargs"] = self.vectordbkwargs
184
+
185
+ result = chain(args)
186
 
187
  result["session_id"] = session_id
188
 
 
205
 
206
  chain = self._session_store[session_id]
207
 
208
+ args = {self._input_key: question}
209
+
210
+ # if self._knowledge:
211
+ # args["vectordbkwargs"] = self.vectordbkwargs
212
+
213
+ result = chain(args)
214
 
215
  result["session_id"] = session_id
216
 
217
  return result
218
 
219
+ def _init_llm(self):
220
+ self._main_llm = load_gpt4_llm() if self.enable_gpt4 else load_llm()
221
+ self._secondary_llm = load_llm()
222
+
223
  def _build_chain(self, problem: CodingProblem, answer: CodingAnswer) -> Chain:
224
+ memory = ConversationBufferMemory(
225
+ memory_key=self.HISTORY_KEY, output_key=self._output_key, return_messages=True
 
 
 
 
 
 
226
  )
227
 
228
  if not self._knowledge:
229
+ prompt = PromptTemplate.from_template(
230
+ TEMPLATE.format(
231
+ instruction=self.instruction,
232
+ problem=problem.expr(lang=problem.language or self.lang),
233
+ answer=answer.expr(lang=problem.language or self.lang),
234
+ )
235
+ )
236
  return ConversationChain(
237
+ llm=self._main_llm,
238
  memory=memory,
239
  prompt=prompt,
240
  )
241
  else:
242
+ prompt = PromptTemplate.from_template(
243
+ KNOWLEDGE_TEMPLATE.format(
244
+ instruction=self.instruction,
245
+ problem=problem.expr(lang=problem.language or self.lang),
246
+ answer=answer.expr(lang=problem.language or self.lang),
247
+ )
248
+ )
249
  return ConversationalRetrievalChain.from_llm(
250
+ llm=self._main_llm,
251
  memory=memory,
252
  retriever=self._knowledge,
253
+ condense_question_llm=self._secondary_llm,
254
  return_source_documents=True,
255
  combine_docs_chain_kwargs={"prompt": prompt},
256
  )
edu_assistant/learning_tasks/qa.py CHANGED
@@ -9,7 +9,7 @@ from langchain.memory import ConversationBufferMemory
9
  from langchain.schema import BaseRetriever
10
 
11
  from edu_assistant.learning_tasks.base import BaseTask
12
- from edu_assistant.utils.langchain_utils import load_llm
13
 
14
  TEMPLATE_CHAT = """The following is a friendly conversation between a human and an ai.
15
  The ai is talkative and provides lots of specific details from its context.
@@ -68,6 +68,9 @@ Useful context for you to answer the question:
68
  {{input}}
69
  """
70
 
 
 
 
71
 
72
  class QaTask(BaseTask):
73
  _session_store: dict
@@ -76,7 +79,7 @@ class QaTask(BaseTask):
76
 
77
  HISTORY_KEY = "chat_history"
78
 
79
- def __init__(self, instruction: str = "", knowledge: BaseRetriever = None):
80
  """Create a new QaTask service.
81
 
82
  Args:
@@ -88,6 +91,10 @@ class QaTask(BaseTask):
88
  If not set, will use internal memory to store chat history. Which will be lost after restart and might
89
  cost huge memory.
90
  """
 
 
 
 
91
  if knowledge:
92
  self._chat_prompt = PromptTemplate.from_template(TEMPLATE_CHAT_CONTEXT.format(instruction=instruction))
93
  self._once_prompt = PromptTemplate.from_template(TEMPLATE_ONCE_CONTEXT.format(instruction=instruction))
@@ -102,42 +109,9 @@ class QaTask(BaseTask):
102
  self._session_store = {}
103
  self._knowledge = knowledge
104
 
105
- self._qa_once = self._build_once_chain()
106
-
107
- def _build_once_chain(self):
108
- if not self._knowledge:
109
- return LLMChain(
110
- llm=load_llm(),
111
- prompt=self._once_prompt,
112
- )
113
- else:
114
- return RetrievalQA.from_llm(
115
- llm=load_llm(),
116
- retriever=self._knowledge,
117
- return_source_documents=True,
118
- prompt=self._once_prompt,
119
- )
120
 
121
- def _build_chat_chain(self):
122
- if not self._knowledge:
123
- return ConversationChain(
124
- llm=load_llm(),
125
- memory=ConversationBufferMemory(
126
- memory_key=QaTask.HISTORY_KEY, output_key=self._output_key, return_messages=True
127
- ),
128
- prompt=self._chat_prompt,
129
- )
130
- else:
131
- return ConversationalRetrievalChain.from_llm(
132
- llm=load_llm(),
133
- retriever=self._knowledge,
134
- condense_question_llm=load_llm(),
135
- return_source_documents=True,
136
- combine_docs_chain_kwargs={"prompt": self._chat_prompt},
137
- memory=ConversationBufferMemory(
138
- memory_key=QaTask.HISTORY_KEY, output_key=self._output_key, return_messages=True
139
- ),
140
- )
141
 
142
  def ask(
143
  self,
@@ -176,6 +150,9 @@ class QaTask(BaseTask):
176
  if session_mem:
177
  chain.memory = session_mem
178
 
 
 
 
179
  result = chain(args)
180
 
181
  if session_id:
@@ -183,6 +160,45 @@ class QaTask(BaseTask):
183
 
184
  return result
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  def _create_session_chain(self, session_id) -> ConversationChain:
187
  chain = self._build_chat_chain()
188
  self._session_store[session_id] = chain
 
9
  from langchain.schema import BaseRetriever
10
 
11
  from edu_assistant.learning_tasks.base import BaseTask
12
+ from edu_assistant.utils.langchain_utils import load_gpt4_llm, load_llm
13
 
14
  TEMPLATE_CHAT = """The following is a friendly conversation between a human and an ai.
15
  The ai is talkative and provides lots of specific details from its context.
 
68
  {{input}}
69
  """
70
 
71
+ DEFAULT_INSTRUCTION = """Act as a c++ professional to answer student aged 5-10 questions. Answer properly and politely.
72
+ Don't extend conversation multiple times. Only add one time saying."""
73
+
74
 
75
  class QaTask(BaseTask):
76
  _session_store: dict
 
79
 
80
  HISTORY_KEY = "chat_history"
81
 
82
+ def __init__(self, instruction: str = DEFAULT_INSTRUCTION, knowledge: BaseRetriever = None, enable_gpt4=False):
83
  """Create a new QaTask service.
84
 
85
  Args:
 
91
  If not set, will use internal memory to store chat history. Which will be lost after restart and might
92
  cost huge memory.
93
  """
94
+ self.enable_gpt4 = enable_gpt4
95
+ # TODO: load threshold key from implement. value from config
96
+ self.vectordbkwargs = {"score_threshold": 0.9} # Qdrant cosine. higher is better.
97
+
98
  if knowledge:
99
  self._chat_prompt = PromptTemplate.from_template(TEMPLATE_CHAT_CONTEXT.format(instruction=instruction))
100
  self._once_prompt = PromptTemplate.from_template(TEMPLATE_ONCE_CONTEXT.format(instruction=instruction))
 
109
  self._session_store = {}
110
  self._knowledge = knowledge
111
 
112
+ self._init_llm()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
+ self._qa_once = self._build_once_chain()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  def ask(
117
  self,
 
150
  if session_mem:
151
  chain.memory = session_mem
152
 
153
+ # TODO: ConversationalRetrievalChain should support vectordbkwargs
154
+ # if self._knowledge:
155
+ # args["vectordbkwargs"] = self.vectordbkwargs
156
  result = chain(args)
157
 
158
  if session_id:
 
160
 
161
  return result
162
 
163
+ def _init_llm(self):
164
+ self._main_llm = load_gpt4_llm() if self.enable_gpt4 else load_llm()
165
+ self._secondary_llm = load_llm()
166
+
167
+ def _build_once_chain(self):
168
+ if not self._knowledge:
169
+ return LLMChain(
170
+ llm=self._main_llm,
171
+ prompt=self._once_prompt,
172
+ )
173
+ else:
174
+ return RetrievalQA.from_llm(
175
+ llm=self._main_llm,
176
+ retriever=self._knowledge,
177
+ return_source_documents=True,
178
+ prompt=self._once_prompt,
179
+ )
180
+
181
+ def _build_chat_chain(self):
182
+ if not self._knowledge:
183
+ return ConversationChain(
184
+ llm=self._main_llm,
185
+ memory=ConversationBufferMemory(
186
+ memory_key=QaTask.HISTORY_KEY, output_key=self._output_key, return_messages=True
187
+ ),
188
+ prompt=self._chat_prompt,
189
+ )
190
+ else:
191
+ return ConversationalRetrievalChain.from_llm(
192
+ llm=self._main_llm,
193
+ retriever=self._knowledge,
194
+ condense_question_llm=self._secondary_llm,
195
+ return_source_documents=True,
196
+ combine_docs_chain_kwargs={"prompt": self._chat_prompt},
197
+ memory=ConversationBufferMemory(
198
+ memory_key=QaTask.HISTORY_KEY, output_key=self._output_key, return_messages=True
199
+ ),
200
+ )
201
+
202
  def _create_session_chain(self, session_id) -> ConversationChain:
203
  chain = self._build_chat_chain()
204
  self._session_store[session_id] = chain
edu_assistant/utils/langchain_utils.py CHANGED
@@ -4,6 +4,7 @@ from functools import lru_cache
4
  from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
5
  from langchain.chat_models.base import BaseChatModel
6
  from langchain.embeddings import OpenAIEmbeddings
 
7
  from langchain.vectorstores import Qdrant, VectorStore
8
 
9
  from edu_assistant.utils.qdrant_utils import load_qdrant_client
@@ -42,6 +43,11 @@ def load_gpt4_llm() -> BaseChatModel:
42
  return llm
43
 
44
 
 
 
 
 
 
45
  @lru_cache(maxsize=1)
46
  def load_embeddings():
47
  if os.environ.get("AZURE_OPENAI"):
@@ -84,3 +90,14 @@ def escape_for_prompt(text: str) -> str:
84
  str: escaped string.
85
  """
86
  return text.replace("{", "{{").replace("}", "}}")
 
 
 
 
 
 
 
 
 
 
 
 
4
  from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
5
  from langchain.chat_models.base import BaseChatModel
6
  from langchain.embeddings import OpenAIEmbeddings
7
+ from langchain.schema import Document
8
  from langchain.vectorstores import Qdrant, VectorStore
9
 
10
  from edu_assistant.utils.qdrant_utils import load_qdrant_client
 
43
  return llm
44
 
45
 
46
+ @lru_cache(maxsize=1)
47
+ def load_gpt4_flag() -> bool:
48
+ return os.environ.get("CODEDOG_ENABLE_GPT4") is not None
49
+
50
+
51
  @lru_cache(maxsize=1)
52
  def load_embeddings():
53
  if os.environ.get("AZURE_OPENAI"):
 
90
  str: escaped string.
91
  """
92
  return text.replace("{", "{{").replace("}", "}}")
93
+
94
+
95
+ def shrink_docs(docs: list[Document], max_size=50):
96
+ """shrink source docs content size for display.
97
+
98
+ Args:
99
+ docs (dict): Retrieval Chain returned docs.
100
+ """
101
+ for doc in docs:
102
+ doc.page_content = doc.page_content[:max_size] + ".."
103
+ return docs
tests/learning_tasks/test_coding_problem.py CHANGED
@@ -5,6 +5,7 @@ from edu_assistant.learning_tasks import CodingProblemAnalysis
5
 
6
 
7
  class TestCodingProblemAnalysis(TestCase):
 
8
  def setUp(self):
9
  self.analysis = CodingProblemAnalysis()
10
 
 
5
 
6
 
7
  class TestCodingProblemAnalysis(TestCase):
8
+ @patch.object(CodingProblemAnalysis, "_init_llm", MagicMock())
9
  def setUp(self):
10
  self.analysis = CodingProblemAnalysis()
11
 
tests/learning_tasks/test_qa.py CHANGED
@@ -6,8 +6,9 @@ from edu_assistant.learning_tasks import QaTask
6
  from edu_assistant.learning_tasks.qa import TEMPLATE_CHAT, TEMPLATE_ONCE
7
 
8
 
 
9
  @patch.object(QaTask, "_build_once_chain")
10
- def test_init_without_knowledge(mocked_build_once_chain):
11
  task = QaTask(instruction="test")
12
 
13
  assert task._chat_prompt == PromptTemplate.from_template(TEMPLATE_CHAT.format(instruction="test"))
@@ -16,9 +17,10 @@ def test_init_without_knowledge(mocked_build_once_chain):
16
  mocked_build_once_chain.assert_called_once()
17
 
18
 
 
19
  @patch.object(QaTask, "_build_once_chain")
20
  @patch.object(QaTask, "_create_session_chain")
21
- def test_ask_with_session(mocked_create_session_chain, mocked_build_once_chain):
22
  mocked_chain = MagicMock(return_value={"response": "ok"})
23
  mocked_build_once_chain.return_value = mocked_chain
24
  mocked_create_session_chain.return_value = mocked_chain
@@ -30,15 +32,15 @@ def test_ask_with_session(mocked_create_session_chain, mocked_build_once_chain):
30
  result = task.ask("how are you?", session=True)
31
 
32
  mock_create_id.assert_called_once()
33
- mocked_create_session_chain.assert_called_once_with(123)
34
  assert "session_id" in result
35
  assert result["session_id"] == 123
36
  assert "response" in result
37
  assert result["response"] == "ok"
38
 
39
 
 
40
  @patch.object(QaTask, "_build_once_chain")
41
- def test_ask_without_session(mocked_build_once_chain):
42
  mocked_llm = MagicMock()
43
  mocked_llm.run.return_value = {"result": "ok"}
44
  mocked_build_once_chain.return_value = mocked_llm
 
6
  from edu_assistant.learning_tasks.qa import TEMPLATE_CHAT, TEMPLATE_ONCE
7
 
8
 
9
+ @patch.object(QaTask, "_init_llm")
10
  @patch.object(QaTask, "_build_once_chain")
11
+ def test_init_without_knowledge(mocked_build_once_chain, mocked_init_llm):
12
  task = QaTask(instruction="test")
13
 
14
  assert task._chat_prompt == PromptTemplate.from_template(TEMPLATE_CHAT.format(instruction="test"))
 
17
  mocked_build_once_chain.assert_called_once()
18
 
19
 
20
+ @patch.object(QaTask, "_init_llm")
21
  @patch.object(QaTask, "_build_once_chain")
22
  @patch.object(QaTask, "_create_session_chain")
23
+ def test_ask_with_session(mocked_create_session_chain, mocked_build_once_chain, mocked_init_llm):
24
  mocked_chain = MagicMock(return_value={"response": "ok"})
25
  mocked_build_once_chain.return_value = mocked_chain
26
  mocked_create_session_chain.return_value = mocked_chain
 
32
  result = task.ask("how are you?", session=True)
33
 
34
  mock_create_id.assert_called_once()
 
35
  assert "session_id" in result
36
  assert result["session_id"] == 123
37
  assert "response" in result
38
  assert result["response"] == "ok"
39
 
40
 
41
+ @patch.object(QaTask, "_init_llm")
42
  @patch.object(QaTask, "_build_once_chain")
43
+ def test_ask_without_session(mocked_build_once_chain, mocked_init_llm):
44
  mocked_llm = MagicMock()
45
  mocked_llm.run.return_value = {"result": "ok"}
46
  mocked_build_once_chain.return_value = mocked_llm
webui/coding_problem.py CHANGED
@@ -6,192 +6,256 @@ from langchain.callbacks import get_openai_callback
6
 
7
  from edu_assistant.learning_tasks.coding_problem import (
8
  DEFAULT_FIRST_QUESTION,
 
9
  CodingProblem,
10
  CodingProblemAnalysis,
11
  )
 
12
 
13
  CodingProblem.enable_redis_orm()
14
- task = CodingProblemAnalysis()
15
-
16
-
17
- def get_problems() -> list[str]:
18
- data = CodingProblem.select(columns=["title"])
19
- if not data:
20
- return []
21
- titles = [problem_data["title"] for problem_data in data]
22
- return titles
23
-
24
-
25
- def update_problems():
26
- titles = get_problems()
27
- gr.Info("更新题目列表成功")
28
- return gr.Dropdown.update(choices=titles)
29
-
30
-
31
- def select_problem(title: str):
32
- problem: CodingProblem = CodingProblem.select(ids=[title])[0]
33
- return (
34
- problem.expr(),
35
- problem.title,
36
- problem.language,
37
- problem.question,
38
- problem.analysis,
39
- problem.standard_answer,
40
- json.dumps(problem.extra, ensure_ascii=False, indent=4),
41
- )
42
-
43
-
44
- def update_problem(title, language, problem, analysis, answer, extra):
45
- # TODO: add language
46
- try:
47
- extra_data = json.loads(extra)
48
- except json.JSONDecodeError:
49
- extra_data = [extra]
50
-
51
- CodingProblem.update(
52
- title,
53
- data={
54
- "title": title,
55
- "language": language,
56
- "question": problem,
57
- "analysis": analysis,
58
- "standard_answer": answer,
59
- "extra": extra_data,
60
- },
61
- )
62
- gr.Info("更新题目成功")
63
-
64
-
65
- def delete_problem(title):
66
- CodingProblem.delete(ids=[title])
67
- gr.Info("删除题目成功")
68
-
69
- return "", "", "", "", "", "", "", ""
70
-
71
-
72
- def analysis_problem(title, code, extra: str = ""):
73
- problem = CodingProblem.select(ids=[title])[0]
74
- answer = CodingProblemAnalysis.build_coding_answer(answer=code)
75
-
76
- with get_openai_callback() as cb:
77
- result = task.start_analysis(problem, answer)
78
- status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
79
-
80
- answer = result["response"]
81
- session_id = result["session_id"]
82
- docs = jsonable_encoder(result.get("source_documents", []))
83
- return [(DEFAULT_FIRST_QUESTION, answer)], session_id, status, docs
84
-
85
-
86
- def chat(message, chat_history, session_id):
87
- if not session_id:
88
- return "", "", {"tokens": 0}, []
89
- with get_openai_callback() as cb:
90
- result = task.ask(message, session_id=session_id)
91
- if not result:
92
- raise gr.Error("Session expired. Please recreate a new problem analysis session.")
93
 
94
- session_id = result["session_id"]
95
- docs = jsonable_encoder(result.get("source_documents", []))
96
 
97
- bot_message = result["response"]
98
- chat_history.append((message, bot_message))
 
 
 
 
 
 
 
 
 
99
 
100
- status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
 
101
 
102
- return "", chat_history, status, docs
 
 
 
 
 
 
 
 
 
103
 
 
 
104
 
105
- with gr.Blocks() as coding_problem_ui:
106
- with gr.Row():
107
- with gr.Column(scale=6):
108
- problem_selector = gr.Dropdown(choices=get_problems(), show_label=False, interactive=True)
109
- with gr.Column():
110
- refresh_btn = gr.Button(value="刷新")
111
- with gr.Row():
112
- with gr.Column(scale=6):
113
- with gr.Tab(label="错误代码分析"):
114
- with gr.Row():
115
- with gr.Column(scale=3):
116
- with gr.Row():
117
- problem_view = gr.Markdown(label="题目")
118
- with gr.Row():
119
- code_view = gr.Textbox(label="代码", lines=10, interactive=True)
120
- with gr.Column(scale=3):
121
- with gr.Row():
122
- chat_box = gr.Chatbot(height=500)
123
- with gr.Row():
124
- chat_input = gr.Textbox(interactive=True)
125
- with gr.Column():
 
 
 
 
 
 
 
 
126
  with gr.Row():
127
- analysis_btn = gr.Button(value="分析")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  with gr.Row():
129
- status = gr.JSON(value="""{"tokens":0}""")
 
 
 
 
 
 
 
 
 
 
 
 
130
  with gr.Row():
131
- session_id = gr.Textbox(label="Session", interactive=False, value="")
 
 
 
132
  with gr.Row():
133
- docs = gr.JSON(value="""["docs"]""", label="Docs")
134
  with gr.Row():
135
- clear = gr.ClearButton([problem_view, code_view, problem_selector, session_id, docs])
136
 
137
- with gr.Tab(label="题库管理"):
138
- with gr.Row():
139
- with gr.Column(scale=6):
140
- with gr.Row():
141
- title_edit = gr.Textbox(label="标题", interactive=True)
142
- with gr.Row():
143
- language_edit = gr.Dropdown(
144
- choices=["python", "cpp", "java"],
145
- label="语言",
146
- interactive=True,
147
- allow_custom_value=True,
148
- )
149
- with gr.Column():
150
- manage_update = gr.Button(value="更新")
151
- manage_delete = gr.Button(value="删除", variant="stop")
152
- with gr.Row():
153
- with gr.Column():
154
- problem_edit = gr.Textbox(label="题目", lines=10, max_lines=100, interactive=True)
155
- with gr.Column():
156
- analysis_edit = gr.Textbox(label="解析", lines=10, max_lines=100, interactive=True)
157
- with gr.Row():
158
- answer_edit = gr.Textbox(label="标准答案", lines=10, max_lines=100, interactive=True)
159
- with gr.Row():
160
- extra_edit = gr.Textbox(label="额外信息", lines=10, max_lines=100, interactive=True)
161
-
162
- refresh_btn.click(update_problems, [], [problem_selector])
163
- problem_selector.select(
164
- select_problem,
165
- [
166
- problem_selector,
167
- ],
168
- [problem_view, title_edit, language_edit, problem_edit, analysis_edit, answer_edit, extra_edit],
169
- )
170
- analysis_btn.click(
171
- analysis_problem,
172
- [problem_selector, code_view],
173
- [chat_box, session_id, status, docs],
174
- )
175
- chat_input.submit(chat, [chat_input, chat_box, session_id], [chat_input, chat_box, status, docs])
176
-
177
- manage_update.click(
178
- update_problem, [title_edit, language_edit, problem_edit, analysis_edit, answer_edit, extra_edit], []
179
- )
180
- manage_delete.click(
181
- delete_problem,
182
- [problem_selector],
183
- [
184
- problem_selector,
185
- problem_view,
186
- title_edit,
187
- language_edit,
188
- problem_edit,
189
- analysis_edit,
190
- answer_edit,
191
- extra_edit,
192
- ],
193
- )
194
-
195
- if __name__ == "__main__":
196
- coding_problem_ui.queue()
197
- coding_problem_ui.launch(max_threads=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from edu_assistant.learning_tasks.coding_problem import (
8
  DEFAULT_FIRST_QUESTION,
9
+ DEFAULT_INSTRUCTION,
10
  CodingProblem,
11
  CodingProblemAnalysis,
12
  )
13
+ from edu_assistant.utils.langchain_utils import load_vectorstore, shrink_docs
14
 
15
  CodingProblem.enable_redis_orm()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
 
 
17
 
18
+ class CodingProblemUI:
19
+ def __init__(
20
+ self,
21
+ *,
22
+ instruction: str = DEFAULT_INSTRUCTION,
23
+ first_question: str = DEFAULT_FIRST_QUESTION,
24
+ knowledge_name: str = "example",
25
+ enable_gpt4: bool = False,
26
+ ):
27
+ self._init_task(instruction, first_question, knowledge_name, enable_gpt4)
28
+ self._init_ui()
29
 
30
+ def ui_render(self):
31
+ self.ui.render()
32
 
33
+ def ui_reload(
34
+ self,
35
+ *,
36
+ instruction: str = DEFAULT_INSTRUCTION,
37
+ first_question: str = DEFAULT_FIRST_QUESTION,
38
+ knowledge_name: str = "example",
39
+ enable_gpt4: bool = False,
40
+ refresh: bool = True,
41
+ ):
42
+ self._init_task(instruction, first_question, knowledge_name, enable_gpt4)
43
 
44
+ if refresh:
45
+ self.ui_render()
46
 
47
+ def get_instruction(self):
48
+ return self.instruction
49
+
50
+ def get_first_question(self):
51
+ return self.first_question
52
+
53
+ def _init_task(self, instruction: str, first_question: str, knowledge_name: str, enable_gpt4: bool):
54
+ self.instruction = instruction
55
+ self.first_question = first_question
56
+ self.knowledge = knowledge_name
57
+ self.enable_gpt4 = enable_gpt4
58
+ self.task = CodingProblemAnalysis(
59
+ instruction=instruction,
60
+ first_question=first_question,
61
+ knowledge=load_vectorstore(knowledge_name).as_retriever(),
62
+ enable_gpt4=enable_gpt4,
63
+ )
64
+
65
+ def _init_ui(self):
66
+ self.ui = gr.Blocks()
67
+ with self.ui:
68
+ with gr.Row():
69
+ with gr.Column(scale=6):
70
+ problem_selector = gr.Dropdown(choices=self._get_problems(), show_label=False, interactive=True)
71
+ with gr.Column(scale=1):
72
+ refresh_btn = gr.Button(value="刷新")
73
+ with gr.Row():
74
+ with gr.Column(scale=6):
75
+ with gr.Tab(label="错误代码分析"):
76
  with gr.Row():
77
+ with gr.Column(scale=3):
78
+ with gr.Row():
79
+ problem_view = gr.Markdown(label="题目")
80
+ with gr.Row():
81
+ code_view = gr.Textbox(label="代码", lines=10, interactive=True)
82
+ with gr.Column(scale=3):
83
+ with gr.Row():
84
+ chat_box = gr.Chatbot(height=500, label="聊天记录")
85
+ with gr.Row():
86
+ chat_input = gr.Textbox(show_label=False)
87
+ with gr.Column():
88
+ with gr.Row():
89
+ analysis_btn = gr.Button(value="分析")
90
+ with gr.Row():
91
+ clear = gr.ClearButton()
92
+ with gr.Row():
93
+ session_id = gr.Textbox(label="Session", interactive=False, value="")
94
+ with gr.Row():
95
+ status = gr.JSON(value={"tokens": 0}, label="Status")
96
+ with gr.Row():
97
+ docs = gr.JSON(value=["docs"], label="Docs")
98
+
99
+ with gr.Tab(label="题库管理"):
100
  with gr.Row():
101
+ with gr.Column(scale=6):
102
+ with gr.Row():
103
+ title_edit = gr.Textbox(label="标题", interactive=True)
104
+ with gr.Row():
105
+ language_edit = gr.Dropdown(
106
+ choices=["python", "cpp", "java"],
107
+ label="语言",
108
+ interactive=True,
109
+ allow_custom_value=True,
110
+ )
111
+ with gr.Column(scale=1):
112
+ manage_update = gr.Button(value="更新")
113
+ manage_delete = gr.Button(value="删除", variant="stop")
114
  with gr.Row():
115
+ with gr.Column():
116
+ problem_edit = gr.Textbox(label="题目", lines=10, max_lines=100, interactive=True)
117
+ with gr.Column():
118
+ analysis_edit = gr.Textbox(label="解析", lines=10, max_lines=100, interactive=True)
119
  with gr.Row():
120
+ answer_edit = gr.Textbox(label="标准答案", lines=10, max_lines=100, interactive=True)
121
  with gr.Row():
122
+ extra_edit = gr.Textbox(label="额外信息", lines=10, max_lines=100, interactive=True)
123
 
124
+ refresh_btn.click(self._update_problems, [], [problem_selector])
125
+ problem_selector.select(
126
+ self._select_problem,
127
+ [
128
+ problem_selector,
129
+ ],
130
+ [problem_view, title_edit, language_edit, problem_edit, analysis_edit, answer_edit, extra_edit],
131
+ )
132
+ analysis_btn.click(
133
+ self._analysis_problem,
134
+ [problem_selector, code_view],
135
+ [chat_box, session_id, status, docs],
136
+ )
137
+ chat_input.submit(self._chat, [chat_input, chat_box, session_id], [chat_input, chat_box, status, docs])
138
+
139
+ manage_update.click(
140
+ self._update_problem,
141
+ [title_edit, language_edit, problem_edit, analysis_edit, answer_edit, extra_edit],
142
+ [],
143
+ )
144
+ manage_delete.click(
145
+ self._delete_problem,
146
+ [problem_selector],
147
+ [
148
+ problem_selector,
149
+ problem_view,
150
+ title_edit,
151
+ language_edit,
152
+ problem_edit,
153
+ analysis_edit,
154
+ answer_edit,
155
+ extra_edit,
156
+ ],
157
+ )
158
+ clear.click(
159
+ self._clear,
160
+ [],
161
+ [
162
+ problem_selector,
163
+ problem_view,
164
+ code_view,
165
+ chat_box,
166
+ session_id,
167
+ status,
168
+ docs,
169
+ problem_view,
170
+ title_edit,
171
+ language_edit,
172
+ problem_edit,
173
+ analysis_edit,
174
+ answer_edit,
175
+ extra_edit,
176
+ ],
177
+ )
178
+
179
+ def _get_problems(self) -> list[str]:
180
+ data = CodingProblem.select(columns=["title"])
181
+ if not data:
182
+ return []
183
+ titles = [problem_data["title"] for problem_data in data]
184
+ return titles
185
+
186
+ def _update_problems(self):
187
+ titles = self._get_problems()
188
+ gr.Info("更新题目列表成功")
189
+ return gr.Dropdown.update(choices=titles)
190
+
191
+ def _select_problem(self, title: str):
192
+ problem: CodingProblem = CodingProblem.select(ids=[title])[0]
193
+ return (
194
+ problem.expr(),
195
+ problem.title,
196
+ problem.language,
197
+ problem.question,
198
+ problem.analysis,
199
+ problem.standard_answer,
200
+ json.dumps(problem.extra, ensure_ascii=False, indent=4),
201
+ )
202
+
203
+ def _update_problem(self, title, language, problem, analysis, answer, extra):
204
+ # TODO: add language
205
+ try:
206
+ extra_data = json.loads(extra)
207
+ except json.JSONDecodeError:
208
+ extra_data = [extra]
209
+
210
+ CodingProblem.update(
211
+ title,
212
+ data={
213
+ "title": title,
214
+ "language": language,
215
+ "question": problem,
216
+ "analysis": analysis,
217
+ "standard_answer": answer,
218
+ "extra": extra_data,
219
+ },
220
+ )
221
+ gr.Info("更新题目成功")
222
+
223
+ def _delete_problem(self, title):
224
+ CodingProblem.delete(ids=[title])
225
+ gr.Info("删除题目成功")
226
+
227
+ return "", "", "", "", "", "", "", ""
228
+
229
+ def _analysis_problem(self, title, code, extra: str = ""):
230
+ problem = CodingProblem.select(ids=[title])[0]
231
+ answer = CodingProblemAnalysis.build_coding_answer(answer=code, extra=[extra])
232
+
233
+ with get_openai_callback() as cb:
234
+ result = self.task.start_analysis(problem, answer)
235
+ status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
236
+
237
+ answer = result["answer"]
238
+ session_id = result["session_id"]
239
+ docs = jsonable_encoder(shrink_docs(result.get("source_documents", [])))
240
+ return [(self.first_question, answer)], session_id, status, docs
241
+
242
+ def _chat(self, message, chat_history, session_id):
243
+ if not session_id:
244
+ return "", "", {"tokens": 0}, []
245
+ with get_openai_callback() as cb:
246
+ result = self.task.ask(message, session_id=session_id)
247
+ if not result:
248
+ raise gr.Error("Session expired. Please recreate a new problem analysis session.")
249
+
250
+ session_id = result["session_id"]
251
+ docs = jsonable_encoder(result.get("source_documents", []))
252
+
253
+ bot_message = result["answer"]
254
+ chat_history.append((message, bot_message))
255
+
256
+ status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
257
+
258
+ return "", chat_history, status, docs
259
+
260
+ def _clear(self):
261
+ return "", "", "", [], "", {"tokens": 0}, ["docs"], "", "", "", "", "", "", ""
webui/qa.py CHANGED
@@ -2,66 +2,84 @@ import gradio as gr
2
  from fastapi.encoders import jsonable_encoder
3
  from langchain.callbacks import get_openai_callback
4
 
5
- from edu_assistant.learning_tasks.qa import QaTask
6
- from edu_assistant.utils.langchain_utils import load_vectorstore
7
-
8
- DEFAULT_INSTRUCTION = """Act as a c++ professional to answer student aged 5-10 questions. Answer properly and politely.
9
- Don't extend conversation multiple times. Only add one time saying."""
10
-
11
-
12
- task = QaTask(
13
- instruction=DEFAULT_INSTRUCTION,
14
- knowledge=load_vectorstore("example").as_retriever(),
15
- )
16
-
17
-
18
- def respond(message, chat_history, session_id):
19
- with get_openai_callback() as cb:
20
- if session_id:
21
- result = task.ask(message, session_id=session_id)
22
- else:
23
- result = task.ask(message)
24
-
25
- session_id = result["session_id"]
26
- docs = jsonable_encoder(result.get("source_documents", []))
27
-
28
- bot_message = result["answer"]
29
- chat_history.append((message, bot_message))
30
-
31
- status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
32
- return "", chat_history, session_id, status, docs
33
-
34
-
35
- def recreate(instruction):
36
- global task
37
- task = QaTask(instruction=instruction, knowledge=load_vectorstore("example").as_retriever())
38
-
39
-
40
- def clear(msg, chatbot, session_id, telemetry, docs):
41
- return "", "", "", '{"tokens":0}', '["docs"]'
42
-
43
-
44
- with gr.Blocks() as qa_ui:
45
- with gr.Row():
46
- with gr.Column(scale=6):
47
- instruction = gr.Textbox(label="Instruction", value=DEFAULT_INSTRUCTION, interactive=False)
48
- with gr.Column(scale=1):
49
- apply = gr.Button(value="更换Prompt")
50
- clear_btn = gr.Button(value="清空")
51
-
52
- with gr.Row():
53
- with gr.Column(scale=6):
54
  with gr.Row():
55
- chatbot = gr.Chatbot(height=500)
56
- with gr.Row():
57
- msg = gr.Textbox()
58
- with gr.Column(scale=1):
59
- with gr.Row():
60
- session_id = gr.Textbox(label="Session", interactive=False, value="")
61
- with gr.Row():
62
- telemetry = gr.JSON(value="""{"tokens":0}""", label="Telemetry")
63
- with gr.Row():
64
- docs = gr.JSON(value="""["docs"]""", label="Docs")
65
-
66
- clear_btn.click(clear, [msg, chatbot, session_id, telemetry, docs], [msg, chatbot, session_id, telemetry, docs])
67
- msg.submit(respond, [msg, chatbot, session_id], [msg, chatbot, session_id, telemetry, docs])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from fastapi.encoders import jsonable_encoder
3
  from langchain.callbacks import get_openai_callback
4
 
5
+ from edu_assistant.learning_tasks.qa import DEFAULT_INSTRUCTION, QaTask
6
+ from edu_assistant.utils.langchain_utils import load_vectorstore, shrink_docs
7
+
8
+
9
+ class QaUI:
10
+ def __init__(
11
+ self, *, instruction: str = DEFAULT_INSTRUCTION, enable_gpt4: bool = False, knowledge_name: str = "example"
12
+ ):
13
+ self._init_task(instruction, knowledge_name, enable_gpt4)
14
+ self._init_ui()
15
+
16
+ def ui_render(self):
17
+ self.ui.render()
18
+
19
+ def ui_reload(
20
+ self,
21
+ *,
22
+ instruction: str = DEFAULT_INSTRUCTION,
23
+ knowledge_name: str = "example",
24
+ enable_gpt4: bool = False,
25
+ refresh: bool = True,
26
+ ):
27
+ self._init_task(instruction, knowledge_name, enable_gpt4)
28
+
29
+ if refresh:
30
+ self.ui_render()
31
+
32
+ def get_instruction(self):
33
+ return self.instruction
34
+
35
+ def _init_task(self, instruction, knowledge_name, enable_gpt4):
36
+ self.instruction = instruction
37
+ self.knowledge = knowledge_name
38
+ self.enable_gpt4 = enable_gpt4
39
+ self.task = QaTask(
40
+ instruction=instruction,
41
+ knowledge=load_vectorstore(knowledge_name).as_retriever(),
42
+ enable_gpt4=enable_gpt4,
43
+ )
44
+
45
+ def _init_ui(self):
46
+ with gr.Blocks() as ui:
 
 
 
 
 
 
 
47
  with gr.Row():
48
+ with gr.Column(scale=6):
49
+ with gr.Row():
50
+ chatbot = gr.Chatbot(height=500, label="聊天记录")
51
+ with gr.Row():
52
+ msg = gr.Textbox(show_label=False)
53
+ with gr.Column(scale=1):
54
+ with gr.Row():
55
+ clear_button = gr.Button(value="清空")
56
+ with gr.Row():
57
+ session_id = gr.Textbox(label="Session", interactive=False, value="")
58
+ with gr.Row():
59
+ status = gr.JSON(value="""{"tokens":0}""", label="Status")
60
+ with gr.Row():
61
+ docs = gr.JSON(value="""["docs"]""", label="Docs")
62
+
63
+ clear_button.click(self._clear, [], [msg, chatbot, session_id, status, docs])
64
+ msg.submit(self._respond, [msg, chatbot, session_id], [msg, chatbot, session_id, status, docs])
65
+
66
+ self.ui = ui
67
+
68
+ def _respond(self, message, chat_history, session_id):
69
+ with get_openai_callback() as cb:
70
+ if session_id:
71
+ result = self.task.ask(message, session_id=session_id)
72
+ else:
73
+ result = self.task.ask(message)
74
+
75
+ session_id = result["session_id"]
76
+ docs = jsonable_encoder(shrink_docs(result.get("source_documents", [])))
77
+
78
+ bot_message = result["answer"]
79
+ chat_history.append((message, bot_message))
80
+
81
+ status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
82
+ return "", chat_history, session_id, status, docs
83
+
84
+ def _clear(self):
85
+ return "", [], "", {"tokens": 0}, ["docs"]
webui/ui.py CHANGED
@@ -1,19 +1,107 @@
1
  import gradio as gr
 
 
2
 
3
  from edu_assistant import version
4
- from webui.coding_problem import coding_problem_ui
5
- from webui.qa import qa_ui
6
 
7
- with gr.Blocks() as ui:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  with gr.Row():
9
- gr.Markdown(f" v{version.VERSION}")
10
 
11
  with gr.Tab(label="答疑"):
12
- qa_ui.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- with gr.Tab(label="编程题"):
15
- coding_problem_ui.render()
16
 
17
  if __name__ == "__main__":
18
- ui.queue()
19
- ui.launch()
 
1
  import gradio as gr
2
+ import uvicorn
3
+ from fastapi import FastAPI
4
 
5
  from edu_assistant import version
6
+ from webui.coding_problem import CodingProblemUI
7
+ from webui.qa import QaUI
8
 
9
+ app = FastAPI()
10
+ demo = gr.Blocks(title="Codedog Edu Assistant", theme="gradio/soft")
11
+ qa_ui = QaUI()
12
+ cp_ui = CodingProblemUI()
13
+
14
+
15
+ def apply_cfg(
16
+ gpt4_flags: list[int],
17
+ qa_instruction: str,
18
+ cp_instruction: str,
19
+ cp_first_question: str,
20
+ qa_knowledge: str,
21
+ cp_knowledge: str,
22
+ ):
23
+ qa_ui.ui_reload(
24
+ instruction=qa_instruction,
25
+ knowledge_name=qa_knowledge,
26
+ enable_gpt4=0 in gpt4_flags,
27
+ )
28
+ cp_ui.ui_reload(
29
+ instruction=cp_instruction,
30
+ first_question=cp_first_question,
31
+ knowledge_name=cp_knowledge,
32
+ enable_gpt4=1 in gpt4_flags,
33
+ )
34
+ demo.render()
35
+ gr.update()
36
+ gr.Info("更新配置成功")
37
+
38
+
39
+ def default_cfg():
40
+ qa_ui.ui_reload()
41
+ cp_ui.ui_reload()
42
+ demo.render()
43
+ gr.update()
44
+ gr.Info("恢复默认配置成功")
45
+
46
+
47
+ def get_gpt4_flags():
48
+ result = []
49
+ if qa_ui.enable_gpt4:
50
+ result.append("答疑")
51
+ if cp_ui.enable_gpt4:
52
+ result.append("做题")
53
+ return result
54
+
55
+
56
+ with demo:
57
  with gr.Row():
58
+ gr.Markdown(f"# Codedog Edu Assistant v{version.VERSION}")
59
 
60
  with gr.Tab(label="答疑"):
61
+ qa_ui.ui_render()
62
+
63
+ with gr.Tab(label="做题"):
64
+ cp_ui.ui_render()
65
+
66
+ with gr.Tab(label="设置"):
67
+ with gr.Row():
68
+ gr.Markdown("## Prompt 设置")
69
+ with gr.Row():
70
+ qa_instruction = gr.Textbox(
71
+ label="答疑指示Prompt", lines=5, max_lines=20, value=qa_ui.get_instruction, interactive=True
72
+ )
73
+ with gr.Row():
74
+ cp_instruction = gr.Textbox(
75
+ label="做题指示Prompt", lines=5, max_lines=20, value=cp_ui.get_instruction, interactive=True
76
+ )
77
+ with gr.Row():
78
+ cp_first_question = gr.Textbox(
79
+ label="判题Prompt", lines=5, max_lines=20, value=cp_ui.get_first_question, interactive=True
80
+ )
81
+ with gr.Row():
82
+ with gr.Column(scale=1):
83
+ gr.Markdown("## Open AI 设置")
84
+ with gr.Column(scale=2):
85
+ gpt4_flags = gr.CheckboxGroup(
86
+ value=get_gpt4_flags, choices=["答疑", "做题"], label="启用GPT4", type="index", interactive=True
87
+ )
88
+
89
+ with gr.Row():
90
+ gr.Markdown("## 知识库设置")
91
+ qa_knowledge = gr.Textbox(value=qa_ui.knowledge, label="答疑知识库", interactive=True)
92
+ cp_knowledge = gr.Textbox(value=cp_ui.knowledge, label="做题知识库", interactive=True)
93
+
94
+ with gr.Row():
95
+ default_btn = gr.Button(value="恢复默认配置", interactive=True, scale=1)
96
+ apply_btn = gr.Button(value="更新配置", interactive=True, variant="primary", scale=1)
97
+
98
+ default_btn.click(default_cfg, [], [])
99
+ apply_btn.click(
100
+ apply_cfg, [gpt4_flags, qa_instruction, cp_instruction, cp_first_question, qa_knowledge, cp_knowledge], []
101
+ )
102
 
103
+ demo.queue()
104
+ app = gr.mount_gradio_app(app, demo, path="/")
105
 
106
  if __name__ == "__main__":
107
+ uvicorn.run(app, port=7860)