feat: :sparkles: Configuration (#22)
Now we can configure the prompt, GPT-4 usage, and the knowledge base from the Gradio UI. Hooray!
- edu_assistant/learning_tasks/coding_problem.py +101 -22
- edu_assistant/learning_tasks/qa.py +53 -37
- edu_assistant/utils/langchain_utils.py +17 -0
- tests/learning_tasks/test_coding_problem.py +1 -0
- tests/learning_tasks/test_qa.py +6 -4
- webui/coding_problem.py +236 -172
- webui/qa.py +80 -62
- webui/ui.py +97 -9
edu_assistant/learning_tasks/coding_problem.py
CHANGED
@@ -7,7 +7,11 @@ from pydantic import BaseModel, Field
 from pydantic_redis import Model, Store
 
 from edu_assistant.learning_tasks.base import BaseTask
-from edu_assistant.utils.langchain_utils import escape_for_prompt
+from edu_assistant.utils.langchain_utils import (
+    escape_for_prompt,
+    load_gpt4_llm,
+    load_llm,
+)
 from edu_assistant.utils.redis_utils import get_redis_config
 
 TEMPLATE = """The following is a friendly conversation between a human and an ai.
@@ -29,10 +33,40 @@ Student's code:
 ---
 
 Current conversation:
-{{history}}
+{{chat_history}}
 Human: {{input}}
 AI:"""
 
+KNOWLEDGE_TEMPLATE = """The following is a friendly conversation between a human and an ai.
+The ai is talkative and provides lots of specific details from its context.
+If the ai does not know the answer to a question, it truthfully says it does not know.
+The ai act following below instructions:
+---
+{instruction}
+---
+
+The coding problem:
+---
+{problem}
+---
+
+Student's code:
+```
+{answer}
+```
+
+Extra Information might be helpful for you:
+---
+{{context}}
+---
+
+Current conversation:
+{{chat_history}}
+Human: {{question}}
+AI:
+"""
+
+
 DEFAULT_INSTRUCTION = """Act as a c++ professional to check student's code.
 The code is written by a student aged 5-10 and mostly like to buggy or bad performanced.
 """
@@ -59,14 +93,15 @@ class CodingProblem(Model):
     store.register_model(CodingProblem)
 
     def expr(self, lang=""):
-        expr = f"## Question\n\n
+        expr = f"## Question\n\n---\n{escape_for_prompt(self.question)}\n---\n\n"
         expr += (
-            f"""## Standard Answer (There might be others)\n\n```{lang
+            f"""## Standard Answer (There might be others)\n\n```{lang if lang else self.language}
+{escape_for_prompt(self.standard_answer)}\n```
 """
             if self.standard_answer
             else ""
         )
-        expr += f"## Analysis\n\n
+        expr += f"## Analysis\n\n---\n{escape_for_prompt(self.analysis)}\n---\n\n" if self.analysis else ""
         expr += "## Extra\n\n" + escape_for_prompt("".join(self.extra)) + "\n"
         return expr
 
@@ -88,14 +123,35 @@ class CodingAnswer(BaseModel):
 
 
 class CodingProblemAnalysis(BaseTask):
+    HISTORY_KEY = "chat_history"
+
+    def __init__(
+        self,
+        instruction: str = DEFAULT_INSTRUCTION,
+        first_question: str = DEFAULT_FIRST_QUESTION,
+        lang: str = "",
+        knowledge: BaseRetriever = None,
+        enable_gpt4: bool = False,
+    ):
         self.instruction = instruction
+        self.first_question = first_question
+        self.lang = lang
+        self.enable_gpt4 = enable_gpt4
+        # TODO: load threshold key from implement. value from config
+        self.vectordbkwargs = {"score_threshold": 0.9}  # Qdrant cosine. higher is better.
+
+        if knowledge:
+            self._input_key = "question"
+            self._output_key = "answer"
+        else:
+            self._input_key = "input"
+            self._output_key = "response"
+
         self._session_store = {}
         self._knowledge = knowledge
 
+        self._init_llm()
+
     @staticmethod
     def build_coding_problem(question: str, standard_answer: str = "", analysis: str = "", extra: list[str] = None):
         extra = [] if extra is None else extra
@@ -120,7 +176,13 @@ class CodingProblemAnalysis(BaseTask):
         session_id = self._create_session_id()
         self._session_store[session_id] = chain
 
+        args = {self._input_key: first_question if first_question else self.first_question, self.HISTORY_KEY: ""}
+
+        # TODO: ConversationalRetrievalChain should support vectordbkwargs
+        # if self._knowledge:
+        #     args["vectordbkwargs"] = self.vectordbkwargs
+
+        result = chain(args)
 
         result["session_id"] = session_id
 
@@ -143,35 +205,52 @@ class CodingProblemAnalysis(BaseTask):
 
         chain = self._session_store[session_id]
 
+        args = {self._input_key: question}
+
+        # if self._knowledge:
+        #     args["vectordbkwargs"] = self.vectordbkwargs
+
+        result = chain(args)
 
         result["session_id"] = session_id
 
         return result
 
+    def _init_llm(self):
+        self._main_llm = load_gpt4_llm() if self.enable_gpt4 else load_llm()
+        self._secondary_llm = load_llm()
+
     def _build_chain(self, problem: CodingProblem, answer: CodingAnswer) -> Chain:
-        prompt = PromptTemplate.from_template(
-            TEMPLATE.format(
-                instruction=self.instruction,
-                problem=problem.expr(lang=problem.language or self.lang or ""),
-                answer=answer.expr(lang=problem.language or self.lang or ""),
-            )
-        )
+        memory = ConversationBufferMemory(
+            memory_key=self.HISTORY_KEY, output_key=self._output_key, return_messages=True
+        )
 
         if not self._knowledge:
+            prompt = PromptTemplate.from_template(
+                TEMPLATE.format(
+                    instruction=self.instruction,
+                    problem=problem.expr(lang=problem.language or self.lang),
+                    answer=answer.expr(lang=problem.language or self.lang),
+                )
+            )
             return ConversationChain(
-                llm=load_llm(),
+                llm=self._main_llm,
                 memory=memory,
                 prompt=prompt,
            )
         else:
+            prompt = PromptTemplate.from_template(
+                KNOWLEDGE_TEMPLATE.format(
+                    instruction=self.instruction,
+                    problem=problem.expr(lang=problem.language or self.lang),
+                    answer=answer.expr(lang=problem.language or self.lang),
+                )
+            )
             return ConversationalRetrievalChain.from_llm(
-                llm=load_llm(),
+                llm=self._main_llm,
                 memory=memory,
                 retriever=self._knowledge,
-                condense_question_llm=load_llm(),
+                condense_question_llm=self._secondary_llm,
                 return_source_documents=True,
                 combine_docs_chain_kwargs={"prompt": prompt},
             )
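Note: with a knowledge base attached, the task now answers under the "answer" key and also returns "source_documents"; without one it keeps the old "response" key. A minimal usage sketch of the new constructor arguments (not part of the diff; assumes OpenAI credentials are configured):

    from edu_assistant.learning_tasks.coding_problem import CodingProblemAnalysis

    # No retriever: plain ConversationChain, with GPT-4 as the main model.
    task = CodingProblemAnalysis(enable_gpt4=True)

    problem = CodingProblemAnalysis.build_coding_problem(
        question="Print the sum of two integers.",
        standard_answer="print(a + b)",
    )
    answer = CodingProblemAnalysis.build_coding_answer(answer="print(a - b)")

    result = task.start_analysis(problem, answer)
    print(result["response"])  # "answer" instead, when a knowledge base is set

    # Follow-up turns reuse the stored session chain.
    task.ask("Why is my code wrong?", session_id=result["session_id"])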
edu_assistant/learning_tasks/qa.py
CHANGED
@@ -9,7 +9,7 @@ from langchain.memory import ConversationBufferMemory
 from langchain.schema import BaseRetriever
 
 from edu_assistant.learning_tasks.base import BaseTask
-from edu_assistant.utils.langchain_utils import load_llm
+from edu_assistant.utils.langchain_utils import load_gpt4_llm, load_llm
 
 TEMPLATE_CHAT = """The following is a friendly conversation between a human and an ai.
 The ai is talkative and provides lots of specific details from its context.
@@ -68,6 +68,9 @@ Useful context for you to answer the question:
 {{input}}
 """
 
+DEFAULT_INSTRUCTION = """Act as a c++ professional to answer student aged 5-10 questions. Answer properly and politely.
+Don't extend conversation multiple times. Only add one time saying."""
+
 
 class QaTask(BaseTask):
     _session_store: dict
@@ -76,7 +79,7 @@ class QaTask(BaseTask):
 
     HISTORY_KEY = "chat_history"
 
-    def __init__(self, instruction: str =
+    def __init__(self, instruction: str = DEFAULT_INSTRUCTION, knowledge: BaseRetriever = None, enable_gpt4=False):
         """Create a new QaTask service.
 
         Args:
@@ -88,6 +91,10 @@ class QaTask(BaseTask):
             If not set, will use internal memory to store chat history. Which will be lost after restart and might
             cost huge memory.
         """
+        self.enable_gpt4 = enable_gpt4
+        # TODO: load threshold key from implement. value from config
+        self.vectordbkwargs = {"score_threshold": 0.9}  # Qdrant cosine. higher is better.
+
         if knowledge:
             self._chat_prompt = PromptTemplate.from_template(TEMPLATE_CHAT_CONTEXT.format(instruction=instruction))
             self._once_prompt = PromptTemplate.from_template(TEMPLATE_ONCE_CONTEXT.format(instruction=instruction))
@@ -102,42 +109,9 @@ class QaTask(BaseTask):
         self._session_store = {}
         self._knowledge = knowledge
 
-        self._qa_once = self._build_once_chain()
-
-    def _build_once_chain(self):
-        if not self._knowledge:
-            return LLMChain(
-                llm=load_llm(),
-                prompt=self._once_prompt,
-            )
-        else:
-            return RetrievalQA.from_llm(
-                llm=load_llm(),
-                retriever=self._knowledge,
-                return_source_documents=True,
-                prompt=self._once_prompt,
-            )
+        self._init_llm()
 
-    def _build_chat_chain(self):
-        if not self._knowledge:
-            return ConversationChain(
-                llm=load_llm(),
-                memory=ConversationBufferMemory(
-                    memory_key=QaTask.HISTORY_KEY, output_key=self._output_key, return_messages=True
-                ),
-                prompt=self._chat_prompt,
-            )
-        else:
-            return ConversationalRetrievalChain.from_llm(
-                llm=load_llm(),
-                retriever=self._knowledge,
-                condense_question_llm=load_llm(),
-                return_source_documents=True,
-                combine_docs_chain_kwargs={"prompt": self._chat_prompt},
-                memory=ConversationBufferMemory(
-                    memory_key=QaTask.HISTORY_KEY, output_key=self._output_key, return_messages=True
-                ),
-            )
+        self._qa_once = self._build_once_chain()
 
     def ask(
         self,
@@ -176,6 +150,9 @@ class QaTask(BaseTask):
         if session_mem:
             chain.memory = session_mem
 
+        # TODO: ConversationalRetrievalChain should support vectordbkwargs
+        # if self._knowledge:
+        #     args["vectordbkwargs"] = self.vectordbkwargs
         result = chain(args)
 
         if session_id:
@@ -183,6 +160,45 @@ class QaTask(BaseTask):
 
         return result
 
+    def _init_llm(self):
+        self._main_llm = load_gpt4_llm() if self.enable_gpt4 else load_llm()
+        self._secondary_llm = load_llm()
+
+    def _build_once_chain(self):
+        if not self._knowledge:
+            return LLMChain(
+                llm=self._main_llm,
+                prompt=self._once_prompt,
+            )
+        else:
+            return RetrievalQA.from_llm(
+                llm=self._main_llm,
+                retriever=self._knowledge,
+                return_source_documents=True,
+                prompt=self._once_prompt,
+            )
+
+    def _build_chat_chain(self):
+        if not self._knowledge:
+            return ConversationChain(
+                llm=self._main_llm,
+                memory=ConversationBufferMemory(
+                    memory_key=QaTask.HISTORY_KEY, output_key=self._output_key, return_messages=True
+                ),
+                prompt=self._chat_prompt,
+            )
+        else:
+            return ConversationalRetrievalChain.from_llm(
+                llm=self._main_llm,
+                retriever=self._knowledge,
+                condense_question_llm=self._secondary_llm,
+                return_source_documents=True,
+                combine_docs_chain_kwargs={"prompt": self._chat_prompt},
+                memory=ConversationBufferMemory(
+                    memory_key=QaTask.HISTORY_KEY, output_key=self._output_key, return_messages=True
+                ),
+            )
+
     def _create_session_chain(self, session_id) -> ConversationChain:
         chain = self._build_chat_chain()
         self._session_store[session_id] = chain
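Note: `ask` keeps its old output keys, so existing callers keep working: "response" without a knowledge base, "answer" (plus "source_documents") with one. Rough usage sketch (not part of the diff; assumes credentials are configured):

    from edu_assistant.learning_tasks.qa import QaTask

    # Main LLM becomes GPT-4; the condense step (used with a knowledge base)
    # stays on the cheaper secondary LLM.
    qa = QaTask(enable_gpt4=True)

    result = qa.ask("What does a for loop do?", session=True)
    print(result["response"])

    # Later turns reuse the per-session chain and its memory.
    result = qa.ask("Show me an example.", session_id=result["session_id"])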
edu_assistant/utils/langchain_utils.py
CHANGED
@@ -4,6 +4,7 @@ from functools import lru_cache
 from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.chat_models.base import BaseChatModel
 from langchain.embeddings import OpenAIEmbeddings
+from langchain.schema import Document
 from langchain.vectorstores import Qdrant, VectorStore
 
 from edu_assistant.utils.qdrant_utils import load_qdrant_client
@@ -42,6 +43,11 @@ def load_gpt4_llm() -> BaseChatModel:
     return llm
 
 
+@lru_cache(maxsize=1)
+def load_gpt4_flag() -> bool:
+    return os.environ.get("CODEDOG_ENABLE_GPT4") is not None
+
+
 @lru_cache(maxsize=1)
 def load_embeddings():
     if os.environ.get("AZURE_OPENAI"):
@@ -84,3 +90,14 @@ def escape_for_prompt(text: str) -> str:
         str: escaped string.
     """
     return text.replace("{", "{{").replace("}", "}}")
+
+
+def shrink_docs(docs: list[Document], max_size=50):
+    """shrink source docs content size for display.
+
+    Args:
+        docs (dict): Retrieval Chain returned docs.
+    """
+    for doc in docs:
+        doc.page_content = doc.page_content[:max_size] + ".."
+    return docs
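Note: `load_gpt4_flag` treats any set `CODEDOG_ENABLE_GPT4` value as true (even an empty string), and `lru_cache` means the environment is read only once per process. Also, `shrink_docs` appends ".." unconditionally, even to documents already shorter than `max_size`. Behavior sketch (not part of the diff):

    from langchain.schema import Document

    from edu_assistant.utils.langchain_utils import shrink_docs

    docs = [Document(page_content="x" * 200), Document(page_content="short")]
    docs = shrink_docs(docs, max_size=50)
    assert docs[0].page_content == "x" * 50 + ".."
    assert docs[1].page_content == "short.."  # suffix added even when nothing was cut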
tests/learning_tasks/test_coding_problem.py
CHANGED
@@ -5,6 +5,7 @@ from edu_assistant.learning_tasks import CodingProblemAnalysis
 
 
 class TestCodingProblemAnalysis(TestCase):
+    @patch.object(CodingProblemAnalysis, "_init_llm", MagicMock())
     def setUp(self):
         self.analysis = CodingProblemAnalysis()
 
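Note the two patching styles used across these tests: passing `MagicMock()` as the third argument (as in `setUp` above) replaces the attribute with that fixed object and injects no extra parameter, while a bare `@patch.object(QaTask, "_init_llm")` (as in tests/learning_tasks/test_qa.py below) creates a fresh mock per call and passes it in as an argument. A minimal sketch of the difference (hypothetical helper functions, not from the diff):

    from unittest.mock import MagicMock, patch

    from edu_assistant.learning_tasks import CodingProblemAnalysis

    # Fixed replacement: nothing extra is injected into the function.
    @patch.object(CodingProblemAnalysis, "_init_llm", MagicMock())
    def build_task():
        return CodingProblemAnalysis()

    # Auto-created mock: injected as a decorator argument, so it can be asserted on.
    @patch.object(CodingProblemAnalysis, "_init_llm")
    def build_task_checked(mocked_init_llm):
        task = CodingProblemAnalysis()
        mocked_init_llm.assert_called_once()
        return task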
tests/learning_tasks/test_qa.py
CHANGED
@@ -6,8 +6,9 @@ from edu_assistant.learning_tasks import QaTask
 from edu_assistant.learning_tasks.qa import TEMPLATE_CHAT, TEMPLATE_ONCE
 
 
+@patch.object(QaTask, "_init_llm")
 @patch.object(QaTask, "_build_once_chain")
-def test_init_without_knowledge(mocked_build_once_chain):
+def test_init_without_knowledge(mocked_build_once_chain, mocked_init_llm):
     task = QaTask(instruction="test")
 
     assert task._chat_prompt == PromptTemplate.from_template(TEMPLATE_CHAT.format(instruction="test"))
@@ -16,9 +17,10 @@ def test_init_without_knowledge(mocked_build_once_chain):
     mocked_build_once_chain.assert_called_once()
 
 
+@patch.object(QaTask, "_init_llm")
 @patch.object(QaTask, "_build_once_chain")
 @patch.object(QaTask, "_create_session_chain")
-def test_ask_with_session(mocked_create_session_chain, mocked_build_once_chain):
+def test_ask_with_session(mocked_create_session_chain, mocked_build_once_chain, mocked_init_llm):
     mocked_chain = MagicMock(return_value={"response": "ok"})
     mocked_build_once_chain.return_value = mocked_chain
     mocked_create_session_chain.return_value = mocked_chain
@@ -30,15 +32,15 @@ def test_ask_with_session(mocked_create_session_chain, mocked_build_once_chain):
     result = task.ask("how are you?", session=True)
 
     mock_create_id.assert_called_once()
-    mocked_create_session_chain.assert_called_once_with(123)
     assert "session_id" in result
     assert result["session_id"] == 123
     assert "response" in result
     assert result["response"] == "ok"
 
 
+@patch.object(QaTask, "_init_llm")
 @patch.object(QaTask, "_build_once_chain")
-def test_ask_without_session(mocked_build_once_chain):
+def test_ask_without_session(mocked_build_once_chain, mocked_init_llm):
     mocked_llm = MagicMock()
     mocked_llm.run.return_value = {"result": "ok"}
     mocked_build_once_chain.return_value = mocked_llm
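For ad-hoc experiments the same stubbing works outside a test runner, in context-manager form (sketch, not from the diff):

    from unittest.mock import MagicMock, patch

    from edu_assistant.learning_tasks import QaTask

    with patch.object(QaTask, "_init_llm"), patch.object(QaTask, "_build_once_chain"):
        task = QaTask(instruction="test")  # no OpenAI client is ever constructed
        task._main_llm = MagicMock()       # hand-inject a stand-in LLM if a chain needs one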
webui/coding_problem.py
CHANGED
@@ -6,192 +6,256 @@ from langchain.callbacks import get_openai_callback
 
 from edu_assistant.learning_tasks.coding_problem import (
     DEFAULT_FIRST_QUESTION,
+    DEFAULT_INSTRUCTION,
     CodingProblem,
     CodingProblemAnalysis,
 )
+from edu_assistant.utils.langchain_utils import load_vectorstore, shrink_docs
 
 CodingProblem.enable_redis_orm()
-task = CodingProblemAnalysis()
-
-
-def get_problems() -> list[str]:
-    data = CodingProblem.select(columns=["title"])
-    if not data:
-        return []
-    titles = [problem_data["title"] for problem_data in data]
-    return titles
-
-
-def update_problems():
-    titles = get_problems()
-    gr.Info("更新题目列表成功")
-    return gr.Dropdown.update(choices=titles)
-
-
-def select_problem(title: str):
-    problem: CodingProblem = CodingProblem.select(ids=[title])[0]
-    return (
-        problem.expr(),
-        problem.title,
-        problem.language,
-        problem.question,
-        problem.analysis,
-        problem.standard_answer,
-        json.dumps(problem.extra, ensure_ascii=False, indent=4),
-    )
-
-
-def update_problem(title, language, problem, analysis, answer, extra):
-    # TODO: add language
-    try:
-        extra_data = json.loads(extra)
-    except json.JSONDecodeError:
-        extra_data = [extra]
-
-    CodingProblem.update(
-        title,
-        data={
-            "title": title,
-            "language": language,
-            "question": problem,
-            "analysis": analysis,
-            "standard_answer": answer,
-            "extra": extra_data,
-        },
-    )
-    gr.Info("更新题目成功")
-
-
-def delete_problem(title):
-    CodingProblem.delete(ids=[title])
-    gr.Info("删除题目成功")
-
-    return "", "", "", "", "", "", "", ""
-
-
-def analysis_problem(title, code, extra: str = ""):
-    problem = CodingProblem.select(ids=[title])[0]
-    answer = CodingProblemAnalysis.build_coding_answer(answer=code)
-
-    with get_openai_callback() as cb:
-        result = task.start_analysis(problem, answer)
-        status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
-
-    answer = result["response"]
-    session_id = result["session_id"]
-    docs = jsonable_encoder(result.get("source_documents", []))
-    return [(DEFAULT_FIRST_QUESTION, answer)], session_id, status, docs
-
-
-def chat(message, chat_history, session_id):
-    if not session_id:
-        return "", "", {"tokens": 0}, []
-    with get_openai_callback() as cb:
-        result = task.ask(message, session_id=session_id)
-        if not result:
-            raise gr.Error("Session expired. Please recreate a new problem analysis session.")
-
-    session_id = result["session_id"]
-    docs = jsonable_encoder(result.get("source_documents", []))
 
 
+class CodingProblemUI:
+    def __init__(
+        self,
+        *,
+        instruction: str = DEFAULT_INSTRUCTION,
+        first_question: str = DEFAULT_FIRST_QUESTION,
+        knowledge_name: str = "example",
+        enable_gpt4: bool = False,
+    ):
+        self._init_task(instruction, first_question, knowledge_name, enable_gpt4)
+        self._init_ui()
+
+    def ui_render(self):
+        self.ui.render()
+
+    def ui_reload(
+        self,
+        *,
+        instruction: str = DEFAULT_INSTRUCTION,
+        first_question: str = DEFAULT_FIRST_QUESTION,
+        knowledge_name: str = "example",
+        enable_gpt4: bool = False,
+        refresh: bool = True,
+    ):
+        self._init_task(instruction, first_question, knowledge_name, enable_gpt4)
+
+        if refresh:
+            self.ui_render()
+
+    def get_instruction(self):
+        return self.instruction
+
+    def get_first_question(self):
+        return self.first_question
+
+    def _init_task(self, instruction: str, first_question: str, knowledge_name: str, enable_gpt4: bool):
+        self.instruction = instruction
+        self.first_question = first_question
+        self.knowledge = knowledge_name
+        self.enable_gpt4 = enable_gpt4
+        self.task = CodingProblemAnalysis(
+            instruction=instruction,
+            first_question=first_question,
+            knowledge=load_vectorstore(knowledge_name).as_retriever(),
+            enable_gpt4=enable_gpt4,
+        )
+
+    def _init_ui(self):
+        self.ui = gr.Blocks()
+        with self.ui:
+            with gr.Row():
+                with gr.Column(scale=6):
+                    problem_selector = gr.Dropdown(choices=self._get_problems(), show_label=False, interactive=True)
+                with gr.Column(scale=1):
+                    refresh_btn = gr.Button(value="刷新")
+            with gr.Row():
+                with gr.Column(scale=6):
+                    with gr.Tab(label="错误代码分析"):
                         with gr.Row():
+                            with gr.Column(scale=3):
+                                with gr.Row():
+                                    problem_view = gr.Markdown(label="题目")
+                                with gr.Row():
+                                    code_view = gr.Textbox(label="代码", lines=10, interactive=True)
+                            with gr.Column(scale=3):
+                                with gr.Row():
+                                    chat_box = gr.Chatbot(height=500, label="聊天记录")
+                                with gr.Row():
+                                    chat_input = gr.Textbox(show_label=False)
+                            with gr.Column():
+                                with gr.Row():
+                                    analysis_btn = gr.Button(value="分析")
+                                with gr.Row():
+                                    clear = gr.ClearButton()
+                                with gr.Row():
+                                    session_id = gr.Textbox(label="Session", interactive=False, value="")
+                                with gr.Row():
+                                    status = gr.JSON(value={"tokens": 0}, label="Status")
+                                with gr.Row():
+                                    docs = gr.JSON(value=["docs"], label="Docs")
+
+                    with gr.Tab(label="题库管理"):
                         with gr.Row():
+                            with gr.Column(scale=6):
+                                with gr.Row():
+                                    title_edit = gr.Textbox(label="标题", interactive=True)
+                                with gr.Row():
+                                    language_edit = gr.Dropdown(
+                                        choices=["python", "cpp", "java"],
+                                        label="语言",
+                                        interactive=True,
+                                        allow_custom_value=True,
+                                    )
+                            with gr.Column(scale=1):
+                                manage_update = gr.Button(value="更新")
+                                manage_delete = gr.Button(value="删除", variant="stop")
                         with gr.Row():
+                            with gr.Column():
+                                problem_edit = gr.Textbox(label="题目", lines=10, max_lines=100, interactive=True)
+                            with gr.Column():
+                                analysis_edit = gr.Textbox(label="解析", lines=10, max_lines=100, interactive=True)
                         with gr.Row():
+                            answer_edit = gr.Textbox(label="标准答案", lines=10, max_lines=100, interactive=True)
                         with gr.Row():
+                            extra_edit = gr.Textbox(label="额外信息", lines=10, max_lines=100, interactive=True)
 
+            refresh_btn.click(self._update_problems, [], [problem_selector])
+            problem_selector.select(
+                self._select_problem,
+                [
+                    problem_selector,
+                ],
+                [problem_view, title_edit, language_edit, problem_edit, analysis_edit, answer_edit, extra_edit],
+            )
+            analysis_btn.click(
+                self._analysis_problem,
+                [problem_selector, code_view],
+                [chat_box, session_id, status, docs],
+            )
+            chat_input.submit(self._chat, [chat_input, chat_box, session_id], [chat_input, chat_box, status, docs])
+
+            manage_update.click(
+                self._update_problem,
+                [title_edit, language_edit, problem_edit, analysis_edit, answer_edit, extra_edit],
+                [],
+            )
+            manage_delete.click(
+                self._delete_problem,
+                [problem_selector],
+                [
+                    problem_selector,
+                    problem_view,
+                    title_edit,
+                    language_edit,
+                    problem_edit,
+                    analysis_edit,
+                    answer_edit,
+                    extra_edit,
+                ],
+            )
+            clear.click(
+                self._clear,
+                [],
+                [
+                    problem_selector,
+                    problem_view,
+                    code_view,
+                    chat_box,
+                    session_id,
+                    status,
+                    docs,
+                    problem_view,
+                    title_edit,
+                    language_edit,
+                    problem_edit,
+                    analysis_edit,
+                    answer_edit,
+                    extra_edit,
+                ],
+            )
+
+    def _get_problems(self) -> list[str]:
+        data = CodingProblem.select(columns=["title"])
+        if not data:
+            return []
+        titles = [problem_data["title"] for problem_data in data]
+        return titles
+
+    def _update_problems(self):
+        titles = self._get_problems()
+        gr.Info("更新题目列表成功")
+        return gr.Dropdown.update(choices=titles)
+
+    def _select_problem(self, title: str):
+        problem: CodingProblem = CodingProblem.select(ids=[title])[0]
+        return (
+            problem.expr(),
+            problem.title,
+            problem.language,
+            problem.question,
+            problem.analysis,
+            problem.standard_answer,
+            json.dumps(problem.extra, ensure_ascii=False, indent=4),
+        )
+
+    def _update_problem(self, title, language, problem, analysis, answer, extra):
+        # TODO: add language
+        try:
+            extra_data = json.loads(extra)
+        except json.JSONDecodeError:
+            extra_data = [extra]
+
+        CodingProblem.update(
+            title,
+            data={
+                "title": title,
+                "language": language,
+                "question": problem,
+                "analysis": analysis,
+                "standard_answer": answer,
+                "extra": extra_data,
+            },
+        )
+        gr.Info("更新题目成功")
+
+    def _delete_problem(self, title):
+        CodingProblem.delete(ids=[title])
+        gr.Info("删除题目成功")
+
+        return "", "", "", "", "", "", "", ""
+
+    def _analysis_problem(self, title, code, extra: str = ""):
+        problem = CodingProblem.select(ids=[title])[0]
+        answer = CodingProblemAnalysis.build_coding_answer(answer=code, extra=[extra])
+
+        with get_openai_callback() as cb:
+            result = self.task.start_analysis(problem, answer)
+            status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
+
+        answer = result["answer"]
+        session_id = result["session_id"]
+        docs = jsonable_encoder(shrink_docs(result.get("source_documents", [])))
+        return [(self.first_question, answer)], session_id, status, docs
+
+    def _chat(self, message, chat_history, session_id):
+        if not session_id:
+            return "", "", {"tokens": 0}, []
+        with get_openai_callback() as cb:
+            result = self.task.ask(message, session_id=session_id)
+            if not result:
+                raise gr.Error("Session expired. Please recreate a new problem analysis session.")
+
+        session_id = result["session_id"]
+        docs = jsonable_encoder(result.get("source_documents", []))
+
+        bot_message = result["answer"]
+        chat_history.append((message, bot_message))
+
+        status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
+
+        return "", chat_history, status, docs
+
+    def _clear(self):
+        return "", "", "", [], "", {"tokens": 0}, ["docs"], "", "", "", "", "", "", ""
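Note: the module-level `task` global is gone; each `CodingProblemUI` instance owns its task, which is what lets the new settings tab rebuild it at runtime. Sketch (the instruction text is hypothetical):

    ui = CodingProblemUI(knowledge_name="example", enable_gpt4=False)

    # Rebuild the underlying CodingProblemAnalysis with new settings.
    ui.ui_reload(
        instruction="Focus on off-by-one and boundary errors.",
        knowledge_name="example",
        enable_gpt4=True,
    )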
webui/qa.py
CHANGED
@@ -2,66 +2,84 @@ import gradio as gr
 from fastapi.encoders import jsonable_encoder
 from langchain.callbacks import get_openai_callback
 
-from edu_assistant.learning_tasks.qa import QaTask
-from edu_assistant.utils.langchain_utils import load_vectorstore
+from edu_assistant.learning_tasks.qa import DEFAULT_INSTRUCTION, QaTask
+from edu_assistant.utils.langchain_utils import load_vectorstore, shrink_docs
+
+
+class QaUI:
+    def __init__(
+        self, *, instruction: str = DEFAULT_INSTRUCTION, enable_gpt4: bool = False, knowledge_name: str = "example"
+    ):
+        self._init_task(instruction, knowledge_name, enable_gpt4)
+        self._init_ui()
+
+    def ui_render(self):
+        self.ui.render()
+
+    def ui_reload(
+        self,
+        *,
+        instruction: str = DEFAULT_INSTRUCTION,
+        knowledge_name: str = "example",
+        enable_gpt4: bool = False,
+        refresh: bool = True,
+    ):
+        self._init_task(instruction, knowledge_name, enable_gpt4)
+
+        if refresh:
+            self.ui_render()
+
+    def get_instruction(self):
+        return self.instruction
+
+    def _init_task(self, instruction, knowledge_name, enable_gpt4):
+        self.instruction = instruction
+        self.knowledge = knowledge_name
+        self.enable_gpt4 = enable_gpt4
+        self.task = QaTask(
+            instruction=instruction,
+            knowledge=load_vectorstore(knowledge_name).as_retriever(),
+            enable_gpt4=enable_gpt4,
+        )
+
+    def _init_ui(self):
+        with gr.Blocks() as ui:
-        instruction = gr.Textbox(label="Instruction", value=DEFAULT_INSTRUCTION, interactive=False)
-    with gr.Column(scale=1):
-        apply = gr.Button(value="更换Prompt")
-        clear_btn = gr.Button(value="清空")
-
-with gr.Row():
-    with gr.Column(scale=6):
             with gr.Row():
+                with gr.Column(scale=6):
+                    with gr.Row():
+                        chatbot = gr.Chatbot(height=500, label="聊天记录")
+                    with gr.Row():
+                        msg = gr.Textbox(show_label=False)
+                with gr.Column(scale=1):
+                    with gr.Row():
+                        clear_button = gr.Button(value="清空")
+                    with gr.Row():
+                        session_id = gr.Textbox(label="Session", interactive=False, value="")
+                    with gr.Row():
+                        status = gr.JSON(value="""{"tokens":0}""", label="Status")
+                    with gr.Row():
+                        docs = gr.JSON(value="""["docs"]""", label="Docs")
+
+            clear_button.click(self._clear, [], [msg, chatbot, session_id, status, docs])
+            msg.submit(self._respond, [msg, chatbot, session_id], [msg, chatbot, session_id, status, docs])
+
+        self.ui = ui
+
+    def _respond(self, message, chat_history, session_id):
+        with get_openai_callback() as cb:
+            if session_id:
+                result = self.task.ask(message, session_id=session_id)
+            else:
+                result = self.task.ask(message)
+
+            session_id = result["session_id"]
+            docs = jsonable_encoder(shrink_docs(result.get("source_documents", [])))
+
+            bot_message = result["answer"]
+            chat_history.append((message, bot_message))
+
+        status = {"tokens": cb.total_tokens, "cost": f"${cb.total_cost:.4f}"}
+        return "", chat_history, session_id, status, docs
+
+    def _clear(self):
+        return "", [], "", {"tokens": 0}, ["docs"]
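Note: each callback's return tuple maps positionally onto the outputs list it is wired to — `_respond` returns five values for `[msg, chatbot, session_id, status, docs]`, the leading "" clearing the textbox. The same contract in miniature (standalone sketch, not from the diff):

    import gradio as gr

    def echo(message, history):
        history = history + [(message, message.upper())]
        return "", history  # first value clears the textbox, second updates the chat

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        msg.submit(echo, [msg, chatbot], [msg, chatbot])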
webui/ui.py
CHANGED
@@ -1,19 +1,107 @@
 import gradio as gr
+import uvicorn
+from fastapi import FastAPI
 
 from edu_assistant import version
-from webui.coding_problem import
-from webui.qa import
+from webui.coding_problem import CodingProblemUI
+from webui.qa import QaUI
 
+app = FastAPI()
+demo = gr.Blocks(title="Codedog Edu Assistant", theme="gradio/soft")
+qa_ui = QaUI()
+cp_ui = CodingProblemUI()
+
+
+def apply_cfg(
+    gpt4_flags: list[int],
+    qa_instruction: str,
+    cp_instruction: str,
+    cp_first_question: str,
+    qa_knowledge: str,
+    cp_knowledge: str,
+):
+    qa_ui.ui_reload(
+        instruction=qa_instruction,
+        knowledge_name=qa_knowledge,
+        enable_gpt4=0 in gpt4_flags,
+    )
+    cp_ui.ui_reload(
+        instruction=cp_instruction,
+        first_question=cp_first_question,
+        knowledge_name=cp_knowledge,
+        enable_gpt4=1 in gpt4_flags,
+    )
+    demo.render()
+    gr.update()
+    gr.Info("更新配置成功")
+
+
+def default_cfg():
+    qa_ui.ui_reload()
+    cp_ui.ui_reload()
+    demo.render()
+    gr.update()
+    gr.Info("恢复默认配置成功")
+
+
+def get_gpt4_flags():
+    result = []
+    if qa_ui.enable_gpt4:
+        result.append("答疑")
+    if cp_ui.enable_gpt4:
+        result.append("做题")
+    return result
+
+
+with demo:
     with gr.Row():
-        gr.Markdown(f" v{version.VERSION}")
+        gr.Markdown(f"# Codedog Edu Assistant v{version.VERSION}")
 
     with gr.Tab(label="答疑"):
+        qa_ui.ui_render()
+
+    with gr.Tab(label="做题"):
+        cp_ui.ui_render()
+
+    with gr.Tab(label="设置"):
+        with gr.Row():
+            gr.Markdown("## Prompt 设置")
+        with gr.Row():
+            qa_instruction = gr.Textbox(
+                label="答疑指示Prompt", lines=5, max_lines=20, value=qa_ui.get_instruction, interactive=True
+            )
+        with gr.Row():
+            cp_instruction = gr.Textbox(
+                label="做题指示Prompt", lines=5, max_lines=20, value=cp_ui.get_instruction, interactive=True
+            )
+        with gr.Row():
+            cp_first_question = gr.Textbox(
+                label="判题Prompt", lines=5, max_lines=20, value=cp_ui.get_first_question, interactive=True
+            )
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("## Open AI 设置")
+            with gr.Column(scale=2):
+                gpt4_flags = gr.CheckboxGroup(
+                    value=get_gpt4_flags, choices=["答疑", "做题"], label="启用GPT4", type="index", interactive=True
+                )
+
+        with gr.Row():
+            gr.Markdown("## 知识库设置")
+            qa_knowledge = gr.Textbox(value=qa_ui.knowledge, label="答疑知识库", interactive=True)
+            cp_knowledge = gr.Textbox(value=cp_ui.knowledge, label="做题知识库", interactive=True)
+
+        with gr.Row():
+            default_btn = gr.Button(value="恢复默认配置", interactive=True, scale=1)
+            apply_btn = gr.Button(value="更新配置", interactive=True, variant="primary", scale=1)
+
+        default_btn.click(default_cfg, [], [])
+        apply_btn.click(
+            apply_cfg, [gpt4_flags, qa_instruction, cp_instruction, cp_first_question, qa_knowledge, cp_knowledge], []
+        )
 
+demo.queue()
+app = gr.mount_gradio_app(app, demo, path="/")
 
 if __name__ == "__main__":
-    ui.launch()
+    uvicorn.run(app, port=7860)
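Note: the entry point changed from `ui.launch()` to a FastAPI app with the Gradio Blocks mounted at "/", so it can also be served by the uvicorn CLI, e.g. `uvicorn webui.ui:app --port 7860` (module path assumed from the file location). Programmatic equivalent (sketch):

    import uvicorn

    from webui.ui import app  # FastAPI app with the Gradio Blocks mounted at "/"

    uvicorn.run(app, host="127.0.0.1", port=7860)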