# Source: Hugging Face Space (status: Running on A10G).
# File size: 2,804 bytes; commit 7803dd9.
import queue
import re
import threading
import time

from jupyter_client import KernelManager

from utils.const import *
class JupyterNotebook:
    """Run Python code in a private Jupyter kernel and collect its text output.

    The constructor starts a new kernel through ``KernelManager`` and runs
    ``TOOLS_CODE`` (presumably defined in ``utils.const``) in it, so those
    definitions are available to every later cell. Call ``close()`` (or use
    the instance as a context manager) to shut the kernel down.
    """

    # Strips ANSI colour/style escape sequences (e.g. "\x1b[31m") that kernel
    # tracebacks may contain.
    _ANSI_ESCAPE = re.compile(r"\x1b\[.*?m")

    def __init__(self):
        self.km = KernelManager()
        self.km.start_kernel()
        self.kc = self.km.client()
        # Open the channels up front and block until the kernel answers, so the
        # first execute request does not race kernel startup.
        self.kc.start_channels()
        try:
            self.kc.wait_for_ready(timeout=60)
        except RuntimeError:
            # Don't leak a kernel process that never became ready.
            self.close()
            raise
        _ = self.add_and_run(TOOLS_CODE)

    def clean_output(self, outputs):
        """Flatten raw kernel outputs into a single stripped string.

        Args:
            outputs: Items collected by ``add_and_run``:
                * ``dict`` MIME bundles: only the ``"text/plain"`` entry is
                  kept, and bundles without it are skipped;
                * ``str`` stream text: kept verbatim;
                * ``list`` traceback lines: joined with newlines, with ANSI
                  escape codes removed.
                Items of any other type are ignored.

        Returns:
            The kept pieces joined with newlines, with surrounding whitespace
            stripped.
        """
        pieces = []
        for item in outputs:
            if isinstance(item, dict):
                if "text/plain" in item:
                    pieces.append(item["text/plain"])
            elif isinstance(item, str):
                pieces.append(item)
            elif isinstance(item, list):
                pieces.append(self._ANSI_ESCAPE.sub("", "\n".join(item)))
        return "\n".join(pieces).strip()

    def add_and_run(self, code_string, timeout=20):
        """Execute ``code_string`` as one cell and return its output.

        Args:
            code_string: Source code to run in the kernel.
            timeout: Wall-clock budget in seconds for the whole cell
                (default 20).

        Returns:
            ``(output, error_flag)``:
                * ``output`` is the cleaned text (see ``clean_output``).
                * ``error_flag`` is ``False`` on success and ``True`` if the
                  kernel reported an error.
                * If the cell did not finish within ``timeout`` seconds,
                  ``error_flag`` is the string ``"Timeout"``, ``output`` is
                  ``"Execution timed out."``, and the kernel is interrupted.
        """
        outputs = []
        error_flag = False
        # Nothing here answers stdin requests, so make input() fail fast
        # instead of blocking until the timeout.
        msg_id = self.kc.execute(code_string, allow_stdin=False)
        deadline = time.monotonic() + timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return self._on_timeout()
            try:
                msg = self.kc.get_iopub_msg(timeout=remaining)
            except queue.Empty:
                return self._on_timeout()
            # The iopub channel carries messages for every request; ignore
            # those belonging to other cells (e.g. late output from a cell
            # that previously timed out).
            if msg.get("parent_header", {}).get("msg_id") != msg_id:
                continue
            msg_type = msg["header"]["msg_type"]
            content = msg["content"]
            if msg_type == "execute_result":
                outputs.append(content["data"])
            elif msg_type == "stream":
                outputs.append(content["text"])
            elif msg_type == "error":
                error_flag = True
                outputs.append(content["traceback"])
            elif msg_type == "status" and content["execution_state"] == "idle":
                # Kernel reports idle for this request: the cell has finished.
                break
        return self.clean_output(outputs), error_flag

    def _on_timeout(self):
        """Interrupt the running cell and build the timeout result tuple."""
        # Without an interrupt the kernel keeps executing the cell and every
        # subsequent execute request queues behind it.
        try:
            self.km.interrupt_kernel()
        except RuntimeError:
            # No running kernel to interrupt; still report the timeout.
            pass
        return self.clean_output(["Execution timed out."]), "Timeout"

    def close(self):
        """Stop the client's channels and shut down the kernel."""
        self.kc.stop_channels()
        self.km.shutdown_kernel()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()

    def __deepcopy__(self, memo):
        """Return a brand-new notebook instead of copying kernel state.

        The copy runs ``__init__`` again: it starts its own kernel and
        re-executes ``TOOLS_CODE``. Variables defined in this kernel are NOT
        carried over.
        """
        if id(self) in memo:
            return memo[id(self)]
        new_copy = type(self)()
        memo[id(self)] = new_copy
        return new_copy
|