Mikasa06's picture
Update app.py
5fa87d3 verified
from smolagents import CodeAgent,DuckDuckGoSearchTool, HfApiModel,load_tool,tool
import datetime
import requests
import pytz
import yaml
from tools.final_answer import FinalAnswerTool
import matplotlib.pyplot as plt
import io
from PIL import Image as PILImage
from IPython.display import display, Image
import numpy as np
from Gradio_UI import GradioUI
# Below is an example of a tool that does nothing. Amaze us with your creativity !
@tool
def my_custom_tool(arg1: str, arg2: int) -> str:  # it's important to specify the return type
    # Keep this format for the description / args / args description but feel free to modify the tool
    """A tool that does nothing yet
    Args:
        arg1: the first argument
        arg2: the second argument
    """
    # Placeholder behaviour: both arguments are accepted but ignored.
    placeholder_reply = "What magic will you build ?"
    return placeholder_reply
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.
    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').
    """
    try:
        # Resolve the zone name, take "now" in that zone, then format it.
        zone = pytz.timezone(timezone)
        now_in_zone = datetime.datetime.now(zone)
        local_time = now_in_zone.strftime("%Y-%m-%d %H:%M:%S")
        return f"The current local time in {timezone} is: {local_time}"
    except Exception as e:
        # Any failure (for example an unrecognised zone name) is reported
        # back as text instead of being raised to the caller.
        return f"Error fetching time for timezone '{timezone}': {str(e)}"
# Bar chart tool
@tool
def generate_bar_chart(
    x_values: list[str], y_values: list[int], title: str = 'Bar Chart',
    x_label: str = 'X-Axis', y_label: str = 'Y-Axis', show_labels: bool = False
) -> PILImage:
    """Generates a bar chart from the provided x and y values and returns a PILImage object.
    Args:
        x_values: A list of string values for the x-axis.
        y_values: A list of numerical values for the y-axis.
        title: Title for the bar plot.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        show_labels: Whether to display value labels on top of the bars.
    Returns:
        A PIL Image object containing the generated bar chart, or None if the chart could not be generated.
    """
    # Stays None until a figure exists so the cleanup below knows whether
    # there is anything to close.
    fig = None
    try:
        if len(x_values) != len(y_values):
            raise ValueError("x_values and y_values must have the same length.")
        # Work on an explicit figure/axes pair rather than pyplot's implicit
        # "current figure", so this call only ever touches its own figure.
        fig, ax = plt.subplots(figsize=(8, 6))
        bars = ax.bar(x_values, y_values, color=plt.cm.Paired.colors, edgecolor='black')
        if show_labels:
            for bar in bars:
                height = bar.get_height()
                # Centre the value horizontally and sit it just above the bar.
                ax.text(bar.get_x() + bar.get_width() / 2, height,
                        f'{height}', ha='center', va='bottom',
                        fontsize=10, fontweight='bold')
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        # Save the plot to a BytesIO buffer instead of a file
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
        img_buffer.seek(0)  # Move to the beginning of the buffer
        return PILImage.open(img_buffer)  # Return the image object
    except Exception as e:
        # Best-effort tool: report the problem on stdout and return None.
        print(f"Error generating bar chart: {str(e)}")
        return None
    finally:
        # Close the figure on every path (success or failure) so a failed
        # plot/save does not leave the figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
# Scatter plot tool
@tool
def generate_scatter_plot_with_labels(
    x_values: list[float], y_values: list[float], labels: list[str],
    title: str = 'Scatter Plot', x_label: str = 'X-Axis', y_label: str = 'Y-Axis'
) -> PILImage:
    """Generates a scatter plot from the provided x and y values and labels each point. Returns a PILImage object.
    Args:
        x_values: A list of numerical values for the x-axis.
        y_values: A list of numerical values for the y-axis.
        labels: A list of labels corresponding to each point.
        title: Title for the scatter plot.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
    Returns:
        A PIL Image object containing the scatter plot, or None if the plot could not be generated.
    """
    # Stays None until a figure exists so the cleanup below knows whether
    # there is anything to close.
    fig = None
    try:
        if len(x_values) != len(y_values) or len(x_values) != len(labels):
            raise ValueError("x_values, y_values, and labels must all have the same length.")
        # Explicit figure/axes pair instead of pyplot's implicit current figure.
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.scatter(x_values, y_values, color='blue', edgecolors='black', marker='o')
        # The three lists were validated to be equally long, so zip pairs
        # every point with its label.
        for x, y, label in zip(x_values, y_values, labels):
            ax.text(x, y, label, fontsize=9, ha='right', color='red')
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        # Render into memory rather than to a file on disk.
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
        img_buffer.seek(0)
        return PILImage.open(img_buffer)
    except Exception as e:
        # Best-effort tool: report the problem on stdout and return None.
        print(f"Error generating scatter plot with labels: {str(e)}")
        return None
    finally:
        # Close the figure on every path so a failed plot/save does not
        # leave the figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
# Line plot tool
@tool
def generate_line_plot(
    x_values: list[float], y_values: list[float],
    title: str = 'Line Plot', x_label: str = 'X-Axis', y_label: str = 'Y-Axis'
) -> PILImage:
    """Generates a line plot from the provided x and y values. Returns a PILImage object.
    Args:
        x_values: A list of numerical values for the x-axis.
        y_values: A list of numerical values for the y-axis.
        title: Title for the line plot.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
    Returns:
        A PIL Image object containing the line plot, or None if the plot could not be generated.
    """
    # Stays None until a figure exists so the cleanup below knows whether
    # there is anything to close.
    fig = None
    try:
        if len(x_values) != len(y_values):
            raise ValueError("x_values and y_values must have the same length.")
        # Explicit figure/axes pair instead of pyplot's implicit current figure.
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(x_values, y_values, marker='o', linestyle='-', color='b')
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.grid(True)
        # Render into memory rather than to a file on disk.
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
        img_buffer.seek(0)
        return PILImage.open(img_buffer)
    except Exception as e:
        # Best-effort tool: report the problem on stdout and return None.
        print(f"Error generating line plot: {str(e)}")
        return None
    finally:
        # Close the figure on every path so a failed plot/save does not
        # leave the figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
# Histogram tool
@tool
def generate_histogram(
    data: list[float], bins: int = 10,
    title: str = 'Histogram', x_label: str = 'Values', y_label: str = 'Frequency'
) -> PILImage:
    """Generates a histogram from the provided data and returns a PILImage object.
    Args:
        data: A list of numerical values.
        bins: Number of bins for the histogram.
        title: Title for the histogram.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
    Returns:
        A PIL Image object containing the histogram plot, or None if the plot could not be generated.
    """
    # Stays None until a figure exists so the cleanup below knows whether
    # there is anything to close.
    fig = None
    try:
        if not data:
            raise ValueError("Data list is empty.")
        # Explicit figure/axes pair instead of pyplot's implicit current figure.
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.hist(data, bins=bins, color='purple', edgecolor='black', alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        # Render into memory rather than to a file on disk.
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
        img_buffer.seek(0)
        return PILImage.open(img_buffer)
    except Exception as e:
        # Best-effort tool: report the problem on stdout and return None.
        print(f"Error generating histogram: {str(e)}")
        return None
    finally:
        # Close the figure on every path so a failed plot/save does not
        # leave the figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
# Pie chart tool
@tool
def generate_pie_chart(
    labels: list[str], values: list[float], title: str = 'Pie Chart'
) -> PILImage:
    """Generates a pie chart from the provided labels and values with dynamic colors. Returns a PILImage object.
    Args:
        labels: A list of category labels.
        values: A list of numerical values for each category.
        title: Title for the pie chart.
    Returns:
        A PIL Image object containing the pie chart, or None if the chart could not be generated.
    """
    # Stays None until a figure exists so the cleanup below knows whether
    # there is anything to close.
    fig = None
    try:
        if len(labels) != len(values):
            raise ValueError("Labels and values must have the same length.")
        if any(v < 0 for v in values):
            raise ValueError("Values must be non-negative.")
        # This check also rejects an empty list, so the division below never
        # sees a zero slice count.
        if sum(values) == 0:
            raise ValueError("Sum of values must be greater than zero.")
        cmap = plt.get_cmap('tab10')
        slice_count = len(values)
        # One colour per slice, sampled at evenly spaced positions in [0, 1).
        colors = [cmap(i / slice_count) for i in range(slice_count)]
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=140,
               colors=colors, textprops={'fontsize': 10})
        ax.set_title(title, fontsize=12)
        # Render into memory rather than to a file on disk.
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
        img_buffer.seek(0)
        return PILImage.open(img_buffer)
    except Exception as e:
        # Best-effort tool: report the problem on stdout and return None.
        print(f"Error generating pie chart: {str(e)}")
        return None
    finally:
        # Close the figure on every path so a failed plot/save does not
        # leave the figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
# Box plot tool
@tool
def generate_box_plot(
    data: list[list[float]], labels: list[str] = None,
    title: str = 'Box Plot', y_label: str = 'Values'
) -> PILImage:
    """Generates a box plot from the provided data and returns a PILImage object.
    Args:
        data: A list of numerical lists representing different categories.
        labels: A list of labels corresponding to each dataset (optional).
        title: Title for the box plot.
        y_label: Label for the y-axis.
    Returns:
        A PIL Image object containing the box plot, or None if the plot could not be generated.
    """
    # Stays None until a figure exists so the cleanup below knows whether
    # there is anything to close.
    fig = None
    try:
        # Every category must be a plain list containing only ints/floats.
        is_numeric_nested = all(
            isinstance(category, list)
            and all(isinstance(x, (int, float)) for x in category)
            for category in data
        )
        if not is_numeric_nested:
            raise ValueError("Data must be a list of numerical lists.")
        if labels and len(labels) != len(data):
            raise ValueError("Labels length must match the number of data categories.")
        # Explicit figure/axes pair instead of pyplot's implicit current figure.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.boxplot(data)
        if labels:
            # Tick positions are passed as 1..len(labels), one per label.
            ax.set_xticks(range(1, len(labels) + 1))
            ax.set_xticklabels(labels, rotation=20)
        ax.set_title(title)
        ax.set_ylabel(y_label)
        # Render into memory rather than to a file on disk.
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
        img_buffer.seek(0)
        return PILImage.open(img_buffer)
    except Exception as e:
        # Best-effort tool: report the problem on stdout and return None.
        print(f"Error generating box plot: {str(e)}")
        return None
    finally:
        # Close the figure on every path so a failed plot/save does not
        # leave the figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
final_answer = FinalAnswerTool()

# If the agent does not answer, the model is overloaded, please use another model or the following Hugging Face Endpoint that also contains qwen2.5 coder:
# model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    # model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud',
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',  # it is possible that this model may be overloaded
    custom_role_conversions=None,
)

# Import tool from Hub
# NOTE(review): this tool is loaded but is not included in the agent's
# `tools` list below; `trust_remote_code=True` presumably allows code from
# the Hub repository to run locally - confirm both are intended.
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)

# Read the prompt templates with an explicit encoding so parsing does not
# depend on the platform's default locale.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[  ## add your tools here (don't remove final answer)
        final_answer,
        generate_bar_chart,
        generate_scatter_plot_with_labels,
        generate_line_plot,
        generate_histogram,
        generate_pie_chart,
        generate_box_plot,
    ],
    max_steps=6,
    verbosity_level=2,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates
)

# Start the Gradio front end for the configured agent.
GradioUI(agent).launch()