| import re |
| import sys |
| from collections import defaultdict |
| from contextlib import nullcontext |
| from io import StringIO |
| from multiprocessing import Process, Queue |
| from typing import List, Optional, Type, Union |
|
|
| from filelock import FileLock |
| from timeout_decorator import timeout as tm |
|
|
| from ..schema import ActionReturn, ActionStatusCode |
| from .base_action import BaseAction |
| from .parser import BaseParser, JsonParser |
|
|
|
|
| class IPythonProcess(Process): |
|
|
| def __init__(self, |
| in_q: Queue, |
| out_q: Queue, |
| timeout: int = 20, |
| ci_lock: str = None, |
| daemon: bool = True): |
| super().__init__(daemon=daemon) |
| self.in_q = in_q |
| self.out_q = out_q |
| self.timeout = timeout |
| self.session_id2shell = defaultdict(self.create_shell) |
| self.ci_lock = FileLock( |
| ci_lock) if ci_lock else nullcontext() |
| self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m') |
|
|
| def run(self): |
| while True: |
| msg = self.in_q.get() |
| if msg == 'reset': |
| for session_id, shell in self.session_id2shell.items(): |
| with self.ci_lock: |
| try: |
| shell.reset(new_session=False) |
| |
| except Exception: |
| self.session_id2shell[ |
| session_id] = self.create_shell() |
| self.out_q.put('ok') |
| elif isinstance(msg, tuple) and len(msg) == 3: |
| i, session_id, code = msg |
| res = self.exec(session_id, code) |
| self.out_q.put((i, session_id, res)) |
|
|
| def exec(self, session_id, code): |
| try: |
| shell = self.session_id2shell[session_id] |
| with StringIO() as io: |
| old_stdout = sys.stdout |
| sys.stdout = io |
| if self.timeout is False or self.timeout < 0: |
| shell.run_cell(self.extract_code(code)) |
| else: |
| tm(self.timeout)(shell.run_cell)(self.extract_code(code)) |
| sys.stdout = old_stdout |
| output = self._highlighting.sub('', io.getvalue().strip()) |
| output = re.sub(r'^Out\[\d+\]: ', '', output) |
| if 'Error' in output or 'Traceback' in output: |
| output = output.lstrip('-').strip() |
| if output.startswith('TimeoutError'): |
| output = 'The code interpreter encountered a timeout error.' |
| return {'status': 'FAILURE', 'msg': output, 'code': code} |
| return {'status': 'SUCCESS', 'value': output, 'code': code} |
| except Exception as e: |
| return {'status': 'FAILURE', 'msg': str(e), 'code': code} |
|
|
| @staticmethod |
| def create_shell(enable_history: bool = False, in_memory: bool = True): |
| from IPython import InteractiveShell |
| from traitlets.config import Config |
|
|
| c = Config() |
| c.HistoryManager.enabled = enable_history |
| if in_memory: |
| c.HistoryManager.hist_file = ':memory:' |
| shell = InteractiveShell(config=c) |
| return shell |
|
|
| @staticmethod |
| def extract_code(text: str) -> str: |
| """Extract Python code from markup languages. |
| |
| Args: |
| text (:class:`str`): Markdown-formatted text |
| |
| Returns: |
| :class:`str`: Python code |
| """ |
| import json5 |
|
|
| |
| triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) |
| |
| single_match = re.search(r'`([^`]*)`', text, re.DOTALL) |
| if triple_match: |
| text = triple_match.group(1) |
| elif single_match: |
| text = single_match.group(1) |
| else: |
| try: |
| text = json5.loads(text)['code'] |
| except Exception: |
| pass |
| |
| return text |
|
|
|
|
| class IPythonInteractiveManager(BaseAction): |
| """An interactive IPython shell manager for code execution""" |
|
|
| def __init__( |
| self, |
| max_workers: int = 50, |
| timeout: int = 20, |
| ci_lock: str = None, |
| description: Optional[dict] = None, |
| parser: Type[BaseParser] = JsonParser, |
| ): |
| super().__init__(description, parser) |
| self.max_workers = max_workers |
| self.timeout = timeout |
| self.ci_lock = ci_lock |
| self.id2queue = defaultdict(Queue) |
| self.id2process = {} |
| self.out_queue = Queue() |
|
|
| def __call__(self, |
| commands: Union[str, List[str]], |
| session_ids: Union[int, List[int]] = None): |
| if isinstance(commands, list): |
| batch_size = len(commands) |
| is_batch = True |
| else: |
| batch_size = 1 |
| commands = [commands] |
| is_batch = False |
| if session_ids is None: |
| session_ids = range(batch_size) |
| elif isinstance(session_ids, int): |
| session_ids = [session_ids] |
| if len(session_ids) != batch_size or len(session_ids) != len( |
| set(session_ids)): |
| raise ValueError( |
| 'the size of `session_ids` must equal that of `commands`') |
| try: |
| exec_results = self.run_code_blocks([ |
| (session_id, command) |
| for session_id, command in zip(session_ids, commands) |
| ]) |
| except KeyboardInterrupt: |
| self.clear() |
| exit(1) |
| action_returns = [] |
| for result, code in zip(exec_results, commands): |
| action_return = ActionReturn({'command': code}, type=self.name) |
| if result['status'] == 'SUCCESS': |
| action_return.result = [ |
| dict(type='text', content=result['value']) |
| ] |
| action_return.state = ActionStatusCode.SUCCESS |
| else: |
| action_return.errmsg = result['msg'] |
| action_return.state = ActionStatusCode.API_ERROR |
| action_returns.append(action_return) |
| if not is_batch: |
| return action_returns[0] |
| return action_returns |
|
|
| def process_code(self, index, session_id, code): |
| ipy_id = session_id % self.max_workers |
| input_queue = self.id2queue[ipy_id] |
| proc = self.id2process.setdefault( |
| ipy_id, |
| IPythonProcess( |
| input_queue, |
| self.out_queue, |
| self.timeout, |
| self.ci_lock, |
| daemon=True)) |
| if not proc.is_alive(): |
| proc.start() |
| input_queue.put((index, session_id, code)) |
|
|
| def run_code_blocks(self, session_code_pairs): |
| size = len(session_code_pairs) |
| for index, (session_id, code) in enumerate(session_code_pairs): |
| self.process_code(index, session_id, code) |
| results = [] |
| while len(results) < size: |
| msg = self.out_queue.get() |
| if isinstance(msg, tuple) and len(msg) == 3: |
| index, _, result = msg |
| results.append((index, result)) |
| results.sort() |
| return [item[1] for item in results] |
|
|
| def clear(self): |
| self.id2queue.clear() |
| for proc in self.id2process.values(): |
| proc.terminate() |
| self.id2process.clear() |
| while not self.out_queue.empty(): |
| self.out_queue.get() |
|
|
| def reset(self): |
| cnt = 0 |
| for q in self.id2queue.values(): |
| q.put('reset') |
| cnt += 1 |
| while cnt > 0: |
| msg = self.out_queue.get() |
| if msg == 'ok': |
| cnt -= 1 |
|
|