# Source: I-am-agent / my_modelscope_agent / tools / code_interpreter_jupyter.py
# Last commit: "first" (09321b6) by jianuo; file size 11.7 kB.
import asyncio
import atexit
import base64
import glob
import io
import os
import queue
import re
import shutil
import signal
import subprocess
import sys
import time
import traceback
import uuid
from pathlib import Path
from typing import Dict, Optional
import json
import matplotlib
import PIL.Image
from jupyter_client import BlockingKernelClient
from .tool import Tool
# Directory that receives the kernel launch script, the kernel connection
# file and any rendered images. Override with CODE_INTERPRETER_WORK_DIR.
WORK_DIR = os.environ.get('CODE_INTERPRETER_WORK_DIR', '/tmp/ci_workspace')

# Base URL prepended to image file names when images are exposed over HTTP.
# Override with CODE_INTERPRETER_STATIC_URL.
STATIC_URL = os.environ.get('CODE_INTERPRETER_STATIC_URL',
                            'http://127.0.0.1:7866/static')

# Source text of the tiny script that boots an ipykernel instance.
LAUNCH_KERNEL_PY = ('\n'
                    'from ipykernel import kernelapp as app\n'
                    'app.launch_new_instance()\n')

# Python file whose contents are executed in a freshly started kernel.
INIT_CODE_FILE = str(
    Path(__file__).absolute().parent.joinpath(
        'code_interpreter_utils', 'code_interpreter_init_kernel.py'))

# Font file shipped next to this module; used for the matplotlib CJK fix.
ALIB_FONT_FILE = str(
    Path(__file__).absolute().parent.joinpath(
        'code_interpreter_utils', 'AlibabaPuHuiTi-3-45-Light.ttf'))

# Module-level PID -> kernel client mapping. NOTE(review): the class below
# keeps its own per-instance mapping and does not read this one.
_KERNEL_CLIENTS: Dict[int, BlockingKernelClient] = {}
class CodeInterpreterJupyter(Tool):
    """Code interpreter tool that runs Python code in a Jupyter kernel.

    Code is sent to an ipykernel subprocess through a
    ``BlockingKernelClient``; text output and images published by the kernel
    are merged into a single result string.

    Should not be used together with the other code interpreter tool.
    """
    description = '代码解释器,可用于执行Python代码。'
    name = 'code_interpreter'
    parameters: list = [{
        'name': 'code',
        'description': '待执行的代码',
        'required': True
    }]

    # Compiled once at class creation instead of on every `_escape_ansi` call.
    _ANSI_ESCAPE_RE = re.compile(
        r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')

    def __init__(self, cfg: Optional[dict] = None):
        """Start a kernel for the current process and run the init script.

        Args:
            cfg: Configuration forwarded to ``Tool.__init__``. ``None`` is
                treated as an empty dict. The keys ``timeout`` (default 30)
                and ``image_server`` (default False) are then read from
                ``self.cfg`` (presumably populated by ``Tool.__init__``).
        """
        # A fresh dict per call: a `cfg={}` default would be one object
        # shared by every instance created without arguments.
        super().__init__({} if cfg is None else cfg)
        self.timeout = self.cfg.get('timeout', 30)
        self.image_server = self.cfg.get('image_server', False)
        self.kernel_clients: Dict[int, BlockingKernelClient] = {}
        atexit.register(self._kill_kernels)

        pid: int = os.getpid()
        kc = self.kernel_clients.get(pid)
        if kc is None:
            self._fix_matplotlib_cjk_font_issue()
            kc = self._start_kernel(pid)
            # Register right away so the kernel is still shut down if the
            # init script below raises.
            self.kernel_clients[pid] = kc
            with open(INIT_CODE_FILE, encoding='utf-8') as fin:
                start_code = fin.read()
            # repr(...)[1:-1] yields the path with escapes but without the
            # surrounding quotes, ready to be placed in a string literal.
            start_code = start_code.replace('{{M6_FONT_PATH}}',
                                            repr(ALIB_FONT_FILE)[1:-1])
            print(self._execute_code(kc, start_code))
        self.kc = kc

    def __del__(self):
        """Shut down this instance's kernels, ignoring any failure."""
        # The previous implementation installed `_kill_kernels` as the
        # SIGTERM/SIGINT handler here, which neither killed anything nor
        # matched the (signum, frame) handler signature.
        try:
            self._kill_kernels()
        except Exception:
            # Finalizers may run during interpreter teardown, when modules
            # are already partially cleared; nothing useful can be done.
            pass

    def _start_kernel(self, pid) -> BlockingKernelClient:
        """Launch an ipykernel subprocess and return a connected client.

        Args:
            pid: Identifier embedded in the connection-file and launch-script
                names (the caller passes the current process id).

        Returns:
            A ``BlockingKernelClient`` with started channels.

        Raises:
            RuntimeError: If the kernel subprocess exits before it has
                written a complete connection file.
        """
        connection_file = os.path.join(WORK_DIR,
                                       f'kernel_connection_file_{pid}.json')
        launch_kernel_script = os.path.join(WORK_DIR,
                                            f'launch_kernel_{pid}.py')
        for f in (connection_file, launch_kernel_script):
            if os.path.exists(f):
                print(f'WARNING: {f} already exists')
                os.remove(f)

        os.makedirs(WORK_DIR, exist_ok=True)
        with open(launch_kernel_script, 'w') as fout:
            fout.write(LAUNCH_KERNEL_PY)

        # Only a small allow-list of variables is passed on to the kernel.
        available_envs = ('PATH', 'PYTHONPATH', 'LD_LIBRARY_PATH')
        envs = {k: os.environ[k] for k in available_envs if k in os.environ}

        args = [
            sys.executable,
            launch_kernel_script,
            '--IPKernelApp.connection_file',
            connection_file,
            '--matplotlib=inline',
            '--quiet',
        ]
        kernel_process = subprocess.Popen(args, env=envs, cwd=WORK_DIR)
        print(f"INFO: kernel process's PID = {kernel_process.pid}")

        # Wait for the kernel connection file to be completely written.
        while True:
            if os.path.isfile(connection_file):
                try:
                    with open(connection_file, 'r') as fp:
                        json.load(fp)
                    break
                except json.JSONDecodeError:
                    # The file may be partially written; retry below.
                    pass
            if kernel_process.poll() is not None:
                # Without this check a crashed kernel would make the loop
                # spin forever.
                raise RuntimeError(
                    'Kernel process exited with code '
                    f'{kernel_process.returncode} before writing '
                    f'{connection_file}')
            time.sleep(0.1)

        # Client
        kc = BlockingKernelClient(connection_file=connection_file)
        asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
        kc.load_connection_file()
        kc.start_channels()
        kc.wait_for_ready()
        return kc

    def _kill_kernels(self, *_signal_args):
        """Shut down and forget every kernel client owned by this instance.

        Extra positional arguments are accepted and ignored so the method
        can serve both as an ``atexit`` callback (no arguments) and as a
        signal handler (``signum, frame``).
        """
        # The attribute is missing when __init__ failed before creating it.
        clients = getattr(self, 'kernel_clients', None)
        if not clients:
            return
        for pid in list(clients):
            client = clients.pop(pid)
            try:
                client.shutdown()
            except Exception:
                # One failing kernel must not keep the others alive.
                traceback.print_exc()

    def _serve_image(self, image_base64: str, image_type: str) -> str:
        """Decode a base64 image, store it in WORK_DIR and return its location.

        Args:
            image_base64: Base64-encoded image payload.
            image_type: File extension / format name; ``'gif'`` payloads are
                written verbatim, anything else is re-saved through PIL.

        Returns:
            ``STATIC_URL/<file>`` when ``self.image_server`` is truthy,
            otherwise the local file path.
        """
        image_file = f'{uuid.uuid4()}.{image_type}'
        local_image_file = os.path.join(WORK_DIR, image_file)
        image_bytes = base64.b64decode(image_base64)

        if image_type == 'gif':
            with open(local_image_file, 'wb') as file:
                file.write(image_bytes)
        else:
            PIL.Image.open(io.BytesIO(image_bytes)).save(
                local_image_file, image_type)

        if self.image_server:
            return f'{STATIC_URL}/{image_file}'
        return local_image_file

    def _escape_ansi(self, line: str) -> str:
        """Return ``line`` with ANSI escape sequences removed."""
        return self._ANSI_ESCAPE_RE.sub('', line)

    def _fix_matplotlib_cjk_font_issue(self):
        """Copy the bundled font into matplotlib's ttf directory if missing.

        After copying, font-list cache files that do not mention the font
        are deleted. The whole step is best effort: failures are printed
        and otherwise ignored.
        """
        ttf_name = os.path.basename(ALIB_FONT_FILE)
        mpl_data_dir = os.path.abspath(
            os.path.join(matplotlib.matplotlib_fname(), os.path.pardir))
        local_ttf = os.path.join(mpl_data_dir, 'fonts', 'ttf', ttf_name)
        if os.path.exists(local_ttf):
            return
        try:
            shutil.copy(ALIB_FONT_FILE, local_ttf)
            font_list_cache = os.path.join(matplotlib.get_cachedir(),
                                           'fontlist-*.json')
            for cache_file in glob.glob(font_list_cache):
                with open(cache_file) as fin:
                    cache_content = fin.read()
                if ttf_name not in cache_content:
                    os.remove(cache_file)
        except Exception:
            # Best effort only, but make the failure visible instead of
            # formatting the traceback and throwing it away.
            traceback.print_exc()

    def _image_markdown(self, image_base64: str, image_type: str) -> str:
        """Store an image via ``_serve_image`` and return its markdown tag."""
        image_url = self._serve_image(image_base64, image_type)
        return f'![IMAGEGEN]({image_url})'

    def _parse_iopub_msg(self, msg):
        """Translate one IOPub message into ``(text, image, finished)``.

        ``finished`` is True only for a ``status`` message whose execution
        state is ``idle``. Unknown message types yield empty text and image.
        """
        text = ''
        image = ''
        finished = False
        msg_type = msg['msg_type']
        content = msg['content']

        if msg_type == 'status':
            finished = content.get('execution_state') == 'idle'
        elif msg_type == 'execute_result':
            data = content['data']
            text = data.get('text/plain', '')
            if 'image/png' in data:
                image = self._image_markdown(data['image/png'], 'png')
            elif 'text/html' in data:
                text += '\n' + data['text/html']
            elif 'image/gif' in data:
                image = self._image_markdown(data['image/gif'], 'gif')
        elif msg_type == 'display_data':
            data = content['data']
            if 'image/png' in data:
                image = self._image_markdown(data['image/png'], 'png')
            else:
                text = data.get('text/plain', '')
        elif msg_type == 'stream':
            # Covers both stdout and stderr streams.
            text = content['text']
        elif msg_type == 'error':
            text = self._escape_ansi('\n'.join(content['traceback']))
            if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                text = 'Timeout: Code execution exceeded the time limit.'
        return text, image, finished

    def _execute_code(self, kc: BlockingKernelClient, code: str) -> str:
        """Run ``code`` on the kernel and collect its output.

        Args:
            kc: Client connected to the kernel.
            code: Source code to execute.

        Returns:
            Text output and image markdown joined by newlines, or
            ``'The code executed successfully.'`` when nothing was produced.
        """
        kc.wait_for_ready()
        kc.execute(code)
        pieces = []
        while True:
            try:
                text, image, finished = self._parse_iopub_msg(
                    kc.get_iopub_msg())
            except queue.Empty:
                text = 'Timeout: Code execution exceeded the time limit.'
                image = ''
                finished = True
            except Exception:
                traceback.print_exc()
                text = 'The code interpreter encountered an unexpected error.'
                image = ''
                finished = True
            if text:
                pieces.append(f'\n{text}')
            if image:
                pieces.append(f'\n\n{image}')
            if finished:
                break
        result = ''.join(pieces).lstrip('\n')
        if not result:
            result += 'The code executed successfully.'
        return result

    def _local_call(self, *args, **kwargs):
        """Execute the code found in ``kwargs`` and return its output.

        The code is taken from ``kwargs['code']`` or extracted from
        ``kwargs['fallback']`` (see ``_handle_input_fallback``).

        Returns:
            ``''`` when there is no code to run, otherwise
            ``{'result': <output string>}``.
        """
        code = self._handle_input_fallback(**kwargs)
        # The fallback parser returns None when it finds nothing usable;
        # calling .strip() on that used to raise AttributeError.
        if not code or not code.strip():
            return ''

        if self.timeout:
            # _M6CountdownTimer is not defined in this file; presumably it
            # is provided by the kernel init script.
            code = f'_M6CountdownTimer.start({self.timeout})\n{code}'

        fixed_code = []
        for line in code.split('\n'):
            fixed_code.append(line)
            if line.startswith('sns.set_theme('):
                # Re-apply the font family right after seaborn's set_theme.
                fixed_code.append(
                    'plt.rcParams["font.family"] = _m6_font_prop.get_name()')
        fixed_code = '\n'.join(fixed_code)

        result = self._execute_code(self.kc, fixed_code)

        if self.timeout:
            self._execute_code(self.kc, '_M6CountdownTimer.cancel()')

        return {'result': result}

    def _handle_input_fallback(self, **kwargs):
        """Return the code to run, taken from ``code`` or parsed from text.

        A truthy ``kwargs['code']`` is returned as is. Otherwise, when
        ``kwargs['fallback']`` is truthy, the text between the first and the
        last triple-backtick fence is inspected; its first line is the
        language tag (surrounding whitespace and case are ignored):

        * ``py`` / ``python``: the remaining lines are the code.
        * ``json``: the remaining lines, with newlines removed, are parsed
          as JSON and the value under ``'code'`` is returned.

        Returns:
            The code string; the original (falsy) ``code`` value when there
            is no fallback text or the JSON payload is unusable; ``None``
            when the fallback text contains no supported code block.
        """
        code = kwargs.get('code', None)
        fallback = kwargs.get('fallback', None)

        if code:
            return code
        if not fallback:
            return code

        try:
            code_block = re.search(r'```([\s\S]+)```', fallback)
            if code_block:
                block_lines = code_block.group(1).split('\n')
                # Tolerate trailing spaces, '\r' and upper-case tags.
                language = block_lines[0].strip().lower()
                body = '\n'.join(block_lines[1:])
                if language in ('py', 'python'):
                    return body
                if language == 'json':
                    parameters = json.loads(body.replace('\n', ''))
                    return parameters['code']
        except (ValueError, KeyError, TypeError):
            # Invalid JSON, a payload without 'code', or a payload that is
            # not an object: treat as "no code found".
            return code
        return None