# Author: engrrifatullah
# Last commit: "Update app.py" (4c0e4e6, verified)
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from groq import Groq
# Groq client setup.
#
# The API key is read from the GROQ_API_KEY environment variable (for example
# a Hugging Face Space secret) instead of being written into the source.
# Committing a real key leaks it. The previous version also overwrote
# GROQ_API_KEY with a placeholder and then looked up a different variable
# ("My_API_key"), so no real key ever reached the client.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Function to generate traffic optimization strategies from Groq
def generate_traffic_optimization(data: str):
response = client.chat.completions.create(
messages=[
{
"role": "user",
"content": f"Generate a detailed traffic flow optimization strategy for the following data. Include peak hours, vehicle type distributions, and actionable suggestions to improve flow: {data}",
}
],
model="llama3-8b-8192", # You can choose a different model if necessary
)
return response.choices[0].message.content
# Visualization function to generate the traffic flow chart
def generate_traffic_chart(df_filtered):
# Summarize traffic data by aggregating vehicle counts for each record
df_filtered['Total'] = df_filtered['CarCount'] + df_filtered['BikeCount'] + df_filtered['BusCount'] + df_filtered['TruckCount']
# Select relevant columns for the optimization request (you can adjust this as needed)
traffic_data_summary = df_filtered[['Time', 'Total', 'Traffic Situation', 'CarCount', 'BikeCount', 'BusCount', 'TruckCount']]
# Calculate some basic statistics
avg_car_count = df_filtered['CarCount'].mean()
avg_bike_count = df_filtered['BikeCount'].mean()
avg_bus_count = df_filtered['BusCount'].mean()
avg_truck_count = df_filtered['TruckCount'].mean()
peak_traffic_time = df_filtered.loc[df_filtered['Total'].idxmax()]['Time']
# Generate summary for Groq API input
summary_str = traffic_data_summary.head(10).to_string(index=False)
# Get the optimization strategy from Groq
optimization_strategy = generate_traffic_optimization(summary_str)
# Visualization of traffic flow data
time_labels = df_filtered['Time'].head(10)
car_counts = df_filtered['CarCount'].head(10)
bike_counts = df_filtered['BikeCount'].head(10)
bus_counts = df_filtered['BusCount'].head(10)
truck_counts = df_filtered['TruckCount'].head(10)
# Create the stacked bar chart for vehicle counts
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(time_labels, car_counts, label='Cars', color='blue')
ax.bar(time_labels, bike_counts, bottom=car_counts, label='Bikes', color='green')
ax.bar(time_labels, bus_counts, bottom=np.array(car_counts) + np.array(bike_counts), label='Buses', color='red')
ax.bar(time_labels, truck_counts, bottom=np.array(car_counts) + np.array(bike_counts) + np.array(bus_counts), label='Trucks', color='yellow')
# Customize chart
ax.set_xlabel('Time')
ax.set_ylabel('Vehicle Count')
ax.set_title('Traffic Flow by Vehicle Type')
ax.legend()
# Save the plot as a file
plt.xticks(rotation=45)
plt.tight_layout()
chart_path = "/tmp/traffic_chart.png"
plt.savefig(chart_path)
plt.close()
# Return additional insights
insights = f"""
Average Car Count: {avg_car_count:.2f}
Average Bike Count: {avg_bike_count:.2f}
Average Bus Count: {avg_bus_count:.2f}
Average Truck Count: {avg_truck_count:.2f}
Peak Traffic Time: {peak_traffic_time}
"""
return optimization_strategy + "\n\n" + insights, chart_path
# Function to process the uploaded file and run traffic optimization
def process_traffic_file(file):
# Load the dataset
df = pd.read_csv(file.name)
# Optionally, you may filter data for specific days or time intervals
# For example, let's filter the data for a specific day:
df_filtered = df[df['Day of the week'] == 'Monday']
# Generate traffic chart and optimization strategy
optimization_strategy, chart_path = generate_traffic_chart(df_filtered)
return optimization_strategy, chart_path
# Gradio interface
iface = gr.Interface(
fn=process_traffic_file,
inputs=gr.File(label="Upload CSV with Traffic Data"),
outputs=[gr.Textbox(label="Optimization Strategy and Insights"), gr.Image(label="Traffic Flow Chart")],
title="Traffic Flow Optimization",
description="Upload a CSV file with traffic data, and the app will generate traffic optimization strategies, provide insights, and visualize the traffic flow."
)
# Launch the app
iface.launch()