rubra-v0.1-function-calling / postprocess.py
sanjay920's picture
update
29f7f08
raw
history blame
6.14 kB
import json
import logging
import re
import subprocess
import sys
import uuid
from typing import List
def install(package):
    """Install *package* into the current interpreter's environment with pip.

    Runs ``<sys.executable> -m pip install <package>`` as a child process and
    raises ``subprocess.CalledProcessError`` if pip exits non-zero.
    """
    pip_command = [sys.executable, "-m", "pip", "install"]
    pip_command.append(package)
    subprocess.check_call(pip_command)
# Import pythonmonkey, installing it on the fly if it is not available.
# NOTE: this runs at module import time, so importing this module may invoke pip.
try:
    import pythonmonkey
except ImportError:
    # Not installed: pip-install it into the running interpreter, then retry.
    install('pythonmonkey')
    import pythonmonkey
# Bind the `jsonrepair` function exported by the JS module loaded through
# pythonmonkey.require. It is used as a fallback to repair malformed JSON.
# NOTE(review): presumably this is the npm `jsonrepair` package and must be
# resolvable by pythonmonkey's require(); if it is not, this line presumably
# raises while the module is being imported. Confirm deployment ships it.
jsonrepair = pythonmonkey.require('jsonrepair').jsonrepair
def clean_command_string(command_str):
    r"""Normalize escaping in a string value taken from model-generated JSON.

    Three steps are applied in order:

    1. Remove every backslash that does not start a valid JSON escape, i.e. one
       not followed by ``"``, ``\``, ``/``, ``b``, ``f``, ``n``, ``r``, ``t``
       or ``u`` plus four hex digits.
    2. Turn backslash-quote pairs (``\"``) into plain ``"``.
    3. Strip one pair of surrounding double quotes, if the string is at least
       two characters long and both starts and ends with ``"``.

    Args:
        command_str: The string to clean.

    Returns:
        The cleaned string.
    """
    # json.loads rejects unknown escape sequences, so drop stray backslashes.
    cleaned_command = re.sub(r'\\(?!["\\/bfnrt]|u[a-fA-F0-9]{4})', '', command_str)
    cleaned_command = cleaned_command.replace('\\"', '"')
    # Require two characters: a lone '"' both starts and ends with '"', and
    # slicing [1:-1] would otherwise erase it instead of unwrapping anything.
    if (
        len(cleaned_command) >= 2
        and cleaned_command.startswith('"')
        and cleaned_command.endswith('"')
    ):
        cleaned_command = cleaned_command[1:-1]
    return cleaned_command
def parse_json_safely(json_str):
    """Decode *json_str* as JSON, falling back to a repair pass.

    Strict ``json.loads`` is attempted first. Only if that reports a
    ``json.JSONDecodeError`` is the text passed through ``jsonrepair`` and
    decoded again. If the repair or the second decode fails for any reason,
    the original *json_str* is returned unchanged.
    """
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass  # malformed: fall through to the repair attempt

    try:
        return json.loads(jsonrepair(json_str))
    except Exception:
        return json_str
def clean_json_object(obj):
    """Recursively clean the string leaves of a decoded JSON value.

    Dict values and list items are processed recursively; dict keys are left
    as they are. Each string is passed through ``clean_command_string``. If
    the cleaned text begins with ``{`` or ``[``, it is decoded with
    ``parse_json_safely``, and the decoded value is not cleaned again. Any
    other value is returned unchanged.
    """
    if isinstance(obj, str):
        text = clean_command_string(obj)
        if text.startswith(('{', '[')):
            return parse_json_safely(text)
        return text
    if isinstance(obj, list):
        return list(map(clean_json_object, obj))
    if isinstance(obj, dict):
        return {key: clean_json_object(value) for key, value in obj.items()}
    return obj
def extract_tool_calls(output_str):
    """Return the raw contents of every tool call in *output_str*.

    A tool call starts at the literal marker ``starttoolcall``. It runs
    non-greedily to the next ``endtoolcall`` marker, or to the end of the
    string (regex ``$``) when the closing marker is missing, e.g. in a
    truncated generation. The markers are not included in the result, and
    the contents are not stripped.

    Args:
        output_str: Raw model output.

    Returns:
        The captured strings in order of appearance; an empty list if none.
    """
    # re.findall already returns a fresh list of group-1 captures; re.DOTALL
    # lets a single call span multiple lines.
    return re.findall(r'starttoolcall(.*?)(?:endtoolcall|$)', output_str, re.DOTALL)
def extract_tool_calls_and_text(output_str):
    """Split *output_str* into an ordered list of text and tool-call segments.

    Tool calls are delimited by ``starttoolcall`` ... ``endtoolcall``. A call
    with no closing marker extends to the end of the string. Each call becomes
    ``{"tool_call": <inner text, stripped>, "type": "function"}``. Text before,
    between and after calls becomes ``{"text": <stripped text>, "type": "text"}``;
    stretches that are empty after stripping produce no segment.
    """
    segments_out = []
    cursor = 0

    def emit_text(chunk):
        # Only non-blank prose becomes a segment; an empty slice adds nothing.
        chunk = chunk.strip()
        if chunk:
            segments_out.append({"text": chunk, "type": "text"})

    call_pattern = r'(starttoolcall(.*?)(?:endtoolcall|$))'
    for found in re.finditer(call_pattern, output_str, re.DOTALL):
        emit_text(output_str[cursor:found.start(1)])
        segments_out.append({"tool_call": found.group(2).strip(), "type": "function"})
        cursor = found.end(1)
    # Whatever follows the final tool call (or the whole input if none).
    emit_text(output_str[cursor:])
    return segments_out
def postprocess_output(output_str: str) -> List[dict]:
    """Convert raw model output into a list of text and function-call results.

    The output is split into segments with ``extract_tool_calls_and_text``.
    Each tool-call segment is decoded with ``parse_json_safely`` (with JSON
    repair as a fallback) and cleaned with ``clean_json_object``. A result
    that is a dict with both ``name`` and ``arguments`` becomes a
    ``"function"`` entry; dict-valued ``arguments`` are re-encoded as a JSON
    string. Every other segment, including tool calls that cannot be
    interpreted, becomes a ``"text"`` entry.

    Args:
        output_str: Raw text produced by the model.

    Returns:
        A list of dicts, each with a random 8-hex-character ``id``, a
        ``type`` of ``"function"`` or ``"text"``, and a ``function`` or
        ``text`` payload respectively.
    """
    logger = logging.getLogger(__name__)
    results = []
    for segment in extract_tool_calls_and_text(output_str):
        # Debug detail goes to logging rather than print() so callers' stdout
        # is not polluted on every call.
        logger.debug("processing segment: %r", segment)
        if segment['type'] == 'function':
            results.append(_tool_call_result(segment['tool_call']))
        else:
            results.append(_text_result(segment['text']))
    return results


def _text_result(text):
    """Build a ``"text"`` result entry carrying *text* and a fresh short id."""
    return {"id": uuid.uuid4().hex[:8], "text": text, "type": "text"}


def _tool_call_result(call):
    """Turn one raw tool-call string into a ``"function"`` or ``"text"`` entry.

    Falls back to a ``"text"`` entry holding the raw *call* when the decoded
    value is not a dict with ``name`` and ``arguments``, or when decoding or
    cleaning raises.
    """
    try:
        cleaned_call = clean_json_object(parse_json_safely(call))
        is_function_call = (
            isinstance(cleaned_call, dict)
            and 'name' in cleaned_call
            and 'arguments' in cleaned_call
        )
        if not is_function_call:
            return _text_result(call)
        if isinstance(cleaned_call['arguments'], dict):
            # Downstream consumers expect arguments as a JSON-encoded string.
            cleaned_call['arguments'] = json.dumps(cleaned_call['arguments'])
    except Exception:
        # Best effort: surface an uninterpretable call as text, but keep the
        # traceback available for debugging instead of discarding it.
        logging.getLogger(__name__).debug(
            "could not interpret tool call %r", call, exc_info=True
        )
        return _text_result(call)
    return {"id": uuid.uuid4().hex[:8], "function": cleaned_call, "type": "function"}
def json_to_markdown(json_obj):
    """Convert a list of postprocessed result items to a markdown string.

    ``"text"`` items contribute their ``text`` verbatim (empty if missing).
    ``"function"`` items contribute their ``function`` dict, pretty-printed
    with ``json.dumps(indent=2)`` inside a fenced ``json`` code block. Items
    of any other type are skipped. Pieces are separated by a blank line, and
    the final string is stripped of leading and trailing whitespace.
    """
    pieces = []
    for item in json_obj:
        item_type = item.get("type")
        if item_type == "text":
            pieces.append(item.get("text", "") + "\n\n")
        elif item_type == "function":
            rendered = json.dumps(item.get("function", {}), indent=2)
            pieces.append("```json\n" + rendered + "\n```\n\n")
    # Join once instead of repeated += on a string (quadratic in general).
    return "".join(pieces).strip()
if __name__ == "__main__":
    # Manual smoke test. The sample below contains two complete tool calls
    # followed by a third that is cut off mid-arguments (no endtoolcall).
    #
    # Alternative inputs (note: the funcA call is missing a closing brace):
    # sample = '''Some text before starttoolcall{"name": "funcA", "arguments": {"param1": 1}endtoolcall
    # More text starttoolcall{"name": "funcB", "arguments": {"param2": "test"}}endtoolcall'''
    # sample = '''starttoolcall{"name": "get_current_weather", "arguments": {"location": "San Francisco", "unit": "celsius"}}endtoolcall starttoolcall{"name": "get_current_weather", "arguments": {"location": "Tokyo", "unit": "celsius"}}endtoolcall okay great '''
    sample = '''starttoolcall{"name": "get_current_weather", "arguments": {"location": "San Francisco", "unit": "celsius"}}endtoolcall starttoolcall{"name": "get_current_weather", "arguments": {"location": "Tokyo", "unit": "celsius"}}endtoolcall starttoolcall{"name": "get_current_weather", "arguments": {"location": "Paris", "unit": '''
    processed = postprocess_output(sample)
    as_json = json.dumps(processed, indent=2)
    print(as_json)
    print("-----")
    as_markdown = json_to_markdown(processed)
    print(as_markdown)