# ChartGPT / tools.py
# Author: xiaofeifei
# Commit 26bc264: "init commit"
import ast
import os
import uuid

import plotly.graph_objects as go
from langchain.tools import Tool
# Every chart below is written under ./image (relative to the current working
# directory); create that folder once, when this module is imported.
os.makedirs(name='image', exist_ok=True)
def plot_bar(inputs: str) -> str:
    """Render a bar chart of a list of numbers and save it as a PNG.

    Bars are placed at x positions 0, 1, ..., len(values) - 1.

    Args:
        inputs: Agent-supplied text spelling a Python list literal of numbers,
            e.g. ``"[3, 1, 4]"``.

    Returns:
        Path of the written PNG, ``image/<8 hex chars>.png``.

    Raises:
        ValueError, SyntaxError: If ``inputs`` is not a valid Python literal.
    """
    # The text comes from an LLM, so parse it as a literal instead of eval(),
    # which would execute arbitrary Python expressions.
    values = ast.literal_eval(inputs)
    image_filename = os.path.join('image', f"{uuid.uuid4().hex[:8]}.png")
    fig = go.Figure(data=[go.Bar(x=list(range(len(values))), y=values)])
    fig.write_image(image_filename)
    print(f"\nProcessed PlotBarChart, Input arr: {values}, Output Text: {image_filename}")
    return image_filename
def plot_line(inputs: str) -> str:
    """Render a line chart of a list of numbers and save it as a PNG.

    Points are placed at x positions 0, 1, ..., len(values) - 1 and joined by
    line segments.

    Args:
        inputs: Agent-supplied text spelling a Python list literal of numbers,
            e.g. ``"[3, 1, 4]"``.

    Returns:
        Path of the written PNG, ``image/<8 hex chars>.png``.

    Raises:
        ValueError, SyntaxError: If ``inputs`` is not a valid Python literal.
    """
    # The text comes from an LLM, so parse it as a literal instead of eval(),
    # which would execute arbitrary Python expressions.
    values = ast.literal_eval(inputs)
    image_filename = os.path.join('image', f"{uuid.uuid4().hex[:8]}.png")
    # go.Line is deprecated in Plotly (it is not a trace type); a Scatter trace
    # in "lines" mode is the supported way to draw a line chart.
    fig = go.Figure(
        data=[go.Scatter(x=list(range(len(values))), y=values, mode='lines')]
    )
    fig.write_image(image_filename)
    print(f"\nProcessed PlotLineChart, Input arr: {values}, Output Text: {image_filename}")
    return image_filename
def plot_pie(inputs: str) -> str:
    """Render a pie chart from ``(labels, values)`` and save it as a PNG.

    Args:
        inputs: Agent-supplied text spelling a Python literal whose first
            element is the list of sector labels and whose second element is
            the list of sector values, e.g. ``"(['A', 'B'], [30, 70])"``.
            Any further elements are ignored.

    Returns:
        Path of the written PNG, ``image/<8 hex chars>.png``.

    Raises:
        ValueError, SyntaxError: If ``inputs`` is not a valid Python literal.
    """
    # The text comes from an LLM, so parse it as a literal instead of eval(),
    # which would execute arbitrary Python expressions.
    parsed = ast.literal_eval(inputs)
    labels = parsed[0]
    values = parsed[1]
    # Build the pie chart figure.
    fig = go.Figure(data=go.Pie(labels=labels, values=values))
    # Chart layout: a fixed title.
    fig.update_layout(title='Pie Chart')
    image_filename = os.path.join('image', f"{uuid.uuid4().hex[:8]}.png")
    fig.write_image(image_filename)
    print(f"\nProcessed PlotPieChart, Input labels :{labels}, Input arr: {values}, Output Text: {image_filename}")
    return image_filename
def plot_scatter(inputs: str) -> str:
    """Render a scatter plot from ``(x, y)`` and save it as a PNG.

    Args:
        inputs: Agent-supplied text spelling a Python literal whose first
            element is the list of x coordinates and whose second element is
            the list of y coordinates, e.g. ``"([1, 2, 3], [4, 5, 6])"``.
            Any further elements are ignored.

    Returns:
        Path of the written PNG, ``image/<8 hex chars>.png``.

    Raises:
        ValueError, SyntaxError: If ``inputs`` is not a valid Python literal.
    """
    # The text comes from an LLM, so parse it as a literal instead of eval(),
    # which would execute arbitrary Python expressions.
    parsed = ast.literal_eval(inputs)
    # Scatter data: abscissas first, ordinates second.
    x = parsed[0]
    y = parsed[1]
    # Build the scatter figure: markers only, fixed size and colour.
    fig = go.Figure(data=go.Scatter(
        x=x, y=y, mode='markers', marker=dict(size=10, color='blue')
    ))
    # Chart layout: fixed title and axis labels.
    fig.update_layout(title='Scatter Plot', xaxis_title='X-axis', yaxis_title='Y-axis')
    image_filename = os.path.join('image', f"{uuid.uuid4().hex[:8]}.png")
    fig.write_image(image_filename)
    print(f"\nProcessed PlotScatterChart, Input x :{x}, Input y: {y}, Output Text: {image_filename}")
    return image_filename
# LangChain tool wrapping plot_bar. The description is the text the agent's
# LLM reads when deciding whether and how to call the tool; return_direct=True
# makes the agent hand the tool's output (the image path) straight back.
bar_tool = Tool(
    name="PlotBarChart",
    func=plot_bar,
    description=(
        "useful when you want to draw bar chart. "
        "receives array as input. "
        "The input to this tool should be a number list, representing the array. "
    ),
    return_direct=True,
)
# LangChain tool wrapping plot_line. The description is the text the agent's
# LLM reads when deciding whether and how to call the tool; return_direct=True
# makes the agent hand the tool's output (the image path) straight back.
line_tool = Tool(
    name="PlotLineChart",
    func=plot_line,
    description=(
        "useful when you want to draw line chart. "
        "receives array as input. "
        "The input to this tool should be a number list, representing the array. "
    ),
    return_direct=True,
)
# LangChain tool wrapping plot_pie. The description is the prompt text the
# agent's LLM reads, so it must be accurate and well-formed: the adjacent
# string literals are concatenated, and each one ends with a space so the
# sentences do not run together.
pie_tool = Tool(
    name="PlotPieChart",
    func=plot_pie,
    description=(
        "useful when you want to draw pie chart. "
        "the input is a tuple, whose schema is (labels, values). "
        "'labels', the first element of the input tuple, is a list of strings, "
        "containing the names of each sector in the pie chart. "
        "'values', the second element of the input tuple, is a list of numbers, "
        "containing values corresponding to each sector."
    ),
    return_direct=True,
)
# LangChain tool wrapping plot_scatter. The description is the prompt text the
# agent's LLM reads, so it must be accurate and well-formed: the adjacent
# string literals are concatenated, and each one ends with a space so the
# sentences do not run together.
scatter_tool = Tool(
    name="PlotScatterChart",
    func=plot_scatter,
    description=(
        "useful when you want to draw scatter chart. "
        "the input is a tuple, whose schema is (x, y). "
        "'x', the first element of the input tuple, is a list containing "
        "the abscissa values of each point in the scatterplot. "
        "'y', the second element of the input tuple, is a list containing "
        "the ordinate values of each point in the scatterplot."
    ),
    return_direct=True,
)