| | import re |
| | import sys |
| | from collections import defaultdict |
| | from contextlib import nullcontext |
| | from io import StringIO |
| | from multiprocessing import Process, Queue |
| | from typing import List, Optional, Type, Union |
| |
|
| | from filelock import FileLock |
| | from timeout_decorator import timeout as tm |
| |
|
| | from ..schema import ActionReturn, ActionStatusCode |
| | from .base_action import BaseAction |
| | from .parser import BaseParser, JsonParser |
| |
|
| |
|
class IPythonProcess(Process):
    """Worker process that runs code inside per-session IPython shells.

    The process blocks on ``in_q`` and understands two kinds of messages:

    * the string ``'reset'``: reset every shell created so far, then put
      ``'ok'`` on ``out_q``;
    * a 3-tuple ``(index, session_id, code)``: execute ``code`` in the shell
      bound to ``session_id`` and put ``(index, session_id, result)`` on
      ``out_q``.

    Any other message is silently ignored.

    Args:
        in_q (:class:`Queue`): queue the worker reads requests from.
        out_q (:class:`Queue`): queue the worker writes replies to.
        timeout (:class:`int`): per-cell time limit in seconds, handed to
            ``timeout_decorator.timeout``. ``None``, ``False`` or a negative
            value disables the limit.
        ci_lock (:class:`str`): path of a lock file held while shells are
            reset. No locking is done when it is empty or ``None``.
        daemon (:class:`bool`): forwarded to :class:`Process`.
    """

    def __init__(self,
                 in_q: Queue,
                 out_q: Queue,
                 timeout: int = 20,
                 ci_lock: str = None,
                 daemon: bool = True):
        super().__init__(daemon=daemon)
        self.in_q = in_q
        self.out_q = out_q
        self.timeout = timeout
        # Shells are created lazily, one per session id, on first access.
        self.session_id2shell = defaultdict(self.create_shell)
        self.ci_lock = FileLock(ci_lock) if ci_lock else nullcontext()
        # Matches ANSI colour escape sequences such as ``\x1b[0;31m``.
        self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m')

    def run(self):
        """Serve requests from ``in_q`` in an endless loop."""
        while True:
            msg = self.in_q.get()
            if msg == 'reset':
                # Iterate over a snapshot because broken shells are replaced
                # in the mapping while looping.
                for session_id, shell in list(self.session_id2shell.items()):
                    with self.ci_lock:
                        try:
                            shell.reset(new_session=False)
                        except Exception:
                            # A shell that cannot be reset is thrown away and
                            # replaced with a fresh one.
                            self.session_id2shell[
                                session_id] = self.create_shell()
                self.out_q.put('ok')
            elif isinstance(msg, tuple) and len(msg) == 3:
                i, session_id, code = msg
                res = self.exec(session_id, code)
                self.out_q.put((i, session_id, res))

    def exec(self, session_id, code):
        """Run ``code`` in the shell of ``session_id`` and capture its output.

        Args:
            session_id: key of the shell to use; a shell is created on first
                use of a key.
            code (:class:`str`): raw request text; the code actually executed
                is whatever :meth:`extract_code` pulls out of it.

        Returns:
            :class:`dict`: ``{'status': 'SUCCESS', 'value': <stdout>,
            'code': code}`` on success, otherwise ``{'status': 'FAILURE',
            'msg': <message>, 'code': code}``. Captured output containing
            ``'Error'`` or ``'Traceback'`` is reported as a failure.
        """
        try:
            shell = self.session_id2shell[session_id]
            source = self.extract_code(code)
            unlimited = (self.timeout is None or self.timeout is False
                         or self.timeout < 0)
            run_cell = (
                shell.run_cell if unlimited else tm(self.timeout)(
                    shell.run_cell))
            with StringIO() as buffer:
                old_stdout = sys.stdout
                sys.stdout = buffer
                try:
                    run_cell(source)
                finally:
                    # Always restore stdout; otherwise a failure would leave
                    # the process printing into a closed buffer.
                    sys.stdout = old_stdout
                output = self._highlighting.sub('', buffer.getvalue().strip())
            # Drop the leading ``Out[n]: `` prompt of an expression result.
            output = re.sub(r'^Out\[\d+\]: ', '', output)
            if 'Error' in output or 'Traceback' in output:
                # Tracebacks start with a dashed separator line.
                output = output.lstrip('-').strip()
                if output.startswith('TimeoutError'):
                    output = 'The code interpreter encountered a timeout error.'
                return {'status': 'FAILURE', 'msg': output, 'code': code}
            return {'status': 'SUCCESS', 'value': output, 'code': code}
        except Exception as e:
            return {'status': 'FAILURE', 'msg': str(e), 'code': code}

    @staticmethod
    def create_shell(enable_history: bool = False, in_memory: bool = True):
        """Build a new IPython ``InteractiveShell``.

        Args:
            enable_history (:class:`bool`): value assigned to
                ``HistoryManager.enabled``.
            in_memory (:class:`bool`): when true, the history file is set to
                ``':memory:'``.

        Returns:
            the configured ``InteractiveShell`` instance.
        """
        # Imported lazily so IPython is only loaded when a shell is needed.
        from IPython import InteractiveShell
        from traitlets.config import Config

        c = Config()
        c.HistoryManager.enabled = enable_history
        if in_memory:
            c.HistoryManager.hist_file = ':memory:'
        return InteractiveShell(config=c)

    @staticmethod
    def extract_code(text: str) -> str:
        """Extract Python code from markup languages.

        The first matching rule wins:

        1. the body of the first triple-backtick fenced block;
        2. the content of the first single-backtick span;
        3. the ``'code'`` entry of ``text`` parsed as JSON;
        4. ``text`` itself, unchanged.

        Args:
            text (:class:`str`): Markdown-formatted text

        Returns:
            :class:`str`: Python code
        """
        # Fenced block: skip the info string up to the first newline.
        triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
        if triple_match:
            return triple_match.group(1)
        # Inline span delimited by single backticks.
        single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
        if single_match:
            return single_match.group(1)
        # The JSON parser is only needed on this path, so import it lazily
        # and fall back to the standard library when json5 is not installed.
        try:
            import json5 as json_lib
        except ImportError:
            import json as json_lib
        try:
            return json_lib.loads(text)['code']
        except Exception:
            # Best effort: anything that is not a JSON object with a 'code'
            # entry is treated as plain code.
            return text
| |
|
| |
|
class IPythonInteractiveManager(BaseAction):
    """An interactive IPython shell manager for code execution.

    Commands are dispatched to :class:`IPythonProcess` workers. A session is
    served by the worker with id ``session_id % max_workers``, and every
    worker replies on one shared output queue.

    Args:
        max_workers (:class:`int`): modulus used to map session ids onto
            workers, i.e. the maximum number of worker processes.
        timeout (:class:`int`): per-cell time limit forwarded to each worker.
        ci_lock (:class:`str`): lock file path forwarded to each worker.
        description (:class:`Optional[dict]`): forwarded to the base class.
        parser (:class:`Type[BaseParser]`): forwarded to the base class.
    """

    def __init__(
        self,
        max_workers: int = 50,
        timeout: int = 20,
        ci_lock: str = None,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)
        self.max_workers = max_workers
        self.timeout = timeout
        self.ci_lock = ci_lock
        # One input queue per worker id, created on first use.
        self.id2queue = defaultdict(Queue)
        self.id2process = {}
        # Shared by all workers for their replies.
        self.out_queue = Queue()

    def __call__(self,
                 commands: Union[str, List[str]],
                 session_ids: Union[int, List[int]] = None):
        """Execute one command or a batch of commands.

        Args:
            commands (:class:`Union[str, List[str]]`): a single command or a
                list of commands.
            session_ids (:class:`Union[int, List[int]]`): session of each
                command. Defaults to ``0 .. len(commands) - 1``.

        Returns:
            a single ``ActionReturn`` when ``commands`` is not a list,
            otherwise a list of ``ActionReturn`` in the order of ``commands``.

        Raises:
            ValueError: if the number of session ids differs from the number
                of commands, or if a session id is repeated.
        """
        if isinstance(commands, list):
            batch_size = len(commands)
            is_batch = True
        else:
            batch_size = 1
            commands = [commands]
            is_batch = False
        if session_ids is None:
            session_ids = range(batch_size)
        elif isinstance(session_ids, int):
            session_ids = [session_ids]
        if len(session_ids) != batch_size:
            raise ValueError(
                'the size of `session_ids` must equal that of `commands`')
        if len(set(session_ids)) != batch_size:
            raise ValueError('`session_ids` must not contain duplicates')
        try:
            exec_results = self.run_code_blocks(list(zip(session_ids, commands)))
        except KeyboardInterrupt:
            # Tear the workers down before leaving on Ctrl-C.
            self.clear()
            sys.exit(1)
        action_returns = []
        for result, code in zip(exec_results, commands):
            action_return = ActionReturn({'command': code}, type=self.name)
            if result['status'] == 'SUCCESS':
                action_return.result = [
                    dict(type='text', content=result['value'])
                ]
                action_return.state = ActionStatusCode.SUCCESS
            else:
                action_return.errmsg = result['msg']
                action_return.state = ActionStatusCode.API_ERROR
            action_returns.append(action_return)
        if not is_batch:
            return action_returns[0]
        return action_returns

    def process_code(self, index, session_id, code):
        """Queue ``code`` for the worker that owns ``session_id``.

        A worker is created and started when none exists for the slot yet,
        or when the existing one is no longer alive (a process object cannot
        be started a second time).

        Args:
            index (:class:`int`): position of the command in its batch; the
                worker echoes it back with the result.
            session_id (:class:`int`): session the code belongs to.
            code (:class:`str`): command text to execute.
        """
        ipy_id = session_id % self.max_workers
        input_queue = self.id2queue[ipy_id]
        proc = self.id2process.get(ipy_id)
        if proc is None or not proc.is_alive():
            proc = IPythonProcess(
                input_queue,
                self.out_queue,
                self.timeout,
                self.ci_lock,
                daemon=True)
            self.id2process[ipy_id] = proc
            proc.start()
        input_queue.put((index, session_id, code))

    def run_code_blocks(self, session_code_pairs):
        """Run a batch of ``(session_id, code)`` pairs and collect results.

        Args:
            session_code_pairs: sequence of ``(session_id, code)`` tuples.

        Returns:
            :class:`list`: the workers' result dicts, ordered like the input.
        """
        size = len(session_code_pairs)
        for index, (session_id, code) in enumerate(session_code_pairs):
            self.process_code(index, session_id, code)
        results = []
        while len(results) < size:
            msg = self.out_queue.get()
            # Anything that is not a result triple is discarded.
            if isinstance(msg, tuple) and len(msg) == 3:
                index, _, result = msg
                results.append((index, result))
        # Sort on the index only so result dicts are never compared.
        results.sort(key=lambda item: item[0])
        return [result for _, result in results]

    def clear(self):
        """Terminate all workers and drop the queued input and output."""
        self.id2queue.clear()
        for proc in self.id2process.values():
            proc.terminate()
        self.id2process.clear()
        while not self.out_queue.empty():
            self.out_queue.get()

    def reset(self):
        """Ask every live worker to reset its shells and wait for all acks.

        Workers that are not alive are skipped: they would never answer, and
        waiting for them would block forever.
        """
        cnt = 0
        for ipy_id, q in self.id2queue.items():
            proc = self.id2process.get(ipy_id)
            if proc is not None and proc.is_alive():
                q.put('reset')
                cnt += 1
        while cnt > 0:
            msg = self.out_queue.get()
            if msg == 'ok':
                cnt -= 1
| |
|