pagezyhf's picture
pagezyhf HF Staff
test
213e464
from datasets import load_dataset
import pandas as pd
import duckdb
import matplotlib.pyplot as plt
import seaborn as sns # Import Seaborn
import plotly.express as px # Added for Plotly
import plotly.graph_objects as go # Added for Plotly error figure
import gradio as gr
import os
from huggingface_hub import login
from datetime import datetime, timedelta
import sys # Added for error logging
# The Hugging Face access token has to be provided through the environment;
# refuse to start without it.
HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    raise ValueError("Please set the HF_TOKEN environment variable")

# Authenticate against the Hugging Face Hub with that token.
login(token=HF_TOKEN)

# Global Seaborn styling: "whitegrid" theme with the "notebook" context.
sns.set_theme(style="whitegrid")
sns.set_context("notebook")

# Fetch the dataset a single time at import and keep it as a pandas DataFrame.
try:
    dataset = load_dataset("reach-vb/trending-repos", split="models")
    df = dataset.to_pandas()
    # Expose the DataFrame to DuckDB under the name 'models' so that the SQL
    # used later in this file can simply say 'FROM models'.
    duckdb.register('models', df)
except Exception as load_error:
    # Report the failure, then let it propagate so startup aborts.
    print(f"Error loading dataset: {load_error}")
    raise
def get_retention_data(start_date: str, end_date: str) -> pd.DataFrame:
    """Compute day-over-day model retention from the registered ``models`` table.

    For every collection day between ``start_date`` and ``end_date``
    (inclusive, via SQL ``BETWEEN``) the query reports how many distinct
    models were collected that day, how many of those were also collected on
    the previous calendar day, and that share as a percentage.

    Args:
        start_date: Lower bound of the date range; bound to the query as-is
            (the UI asks for ``YYYY-MM-DD``).
        end_date: Upper bound of the date range, same format.

    Returns:
        A DataFrame with the columns ``collection_day``,
        ``total_models_today``, ``carried_over_models`` and
        ``percent_retained``. If the query fails, a one-row DataFrame with a
        single ``Error`` column holding the exception message is returned
        instead of raising.
    """
    query = """
    WITH model_presence AS (
        -- DISTINCT: a model counts at most once per calendar day, even if it
        -- was collected several times that day. Without it the daily totals
        -- are inflated and the self-join below multiplies matches.
        SELECT DISTINCT
            id AS model_id,
            collected_at::DATE AS collection_day
        FROM models
    ),
    daily_model_counts AS (
        SELECT
            collection_day,
            COUNT(*) AS total_models_today
        FROM model_presence
        GROUP BY collection_day
    ),
    retained_models AS (
        SELECT
            a.collection_day,
            COUNT(*) AS previously_existed_count
        FROM model_presence a
        JOIN model_presence b
            ON a.model_id = b.model_id
            AND a.collection_day = b.collection_day + INTERVAL '1 day'
        GROUP BY a.collection_day
    )
    SELECT
        d.collection_day,
        d.total_models_today,
        COALESCE(r.previously_existed_count, 0) AS carried_over_models,
        CASE
            WHEN d.total_models_today = 0 THEN NULL
            ELSE ROUND(COALESCE(r.previously_existed_count, 0) * 100.0 / d.total_models_today, 2)
        END AS percent_retained
    FROM daily_model_counts d
    LEFT JOIN retained_models r ON d.collection_day = r.collection_day
    WHERE d.collection_day BETWEEN ? AND ?
    ORDER BY d.collection_day
    """
    try:
        # The date range is filtered only in the final SELECT, so the day
        # before start_date is still available to the retention self-join.
        # The two strings are bound to the '?' placeholders, never
        # interpolated into the SQL text.
        result = duckdb.query(query, params=[start_date, end_date]).to_df()
    except Exception as e:
        # Log to standard error and hand the message to the caller as data,
        # so the UI can show it instead of crashing.
        print(f"Error in get_retention_data: {e}", file=sys.stderr)
        return pd.DataFrame({"Error": [str(e)]})

    # Debug logging of the query result.
    print("SQL Query Result:")
    print(result)
    return result
def _message_figure(message: str):
    """Return an otherwise empty Plotly figure showing ``message`` centered."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper", yref="paper",
        x=0.5, y=0.5, showarrow=False,
        font=dict(size=16)
    )
    return fig


def plot_retention_data(dataframe: pd.DataFrame):
    """Render the retention table as a Plotly bar chart.

    Args:
        dataframe: Either a table with ``collection_day`` and
            ``percent_retained`` columns, or a non-empty DataFrame with an
            ``Error`` column, which is treated as an error signal and shown
            as a message. The input is not modified.

    Returns:
        A Plotly figure: the bar chart with a dashed average line, or a
        figure containing only a text message when the input is an error
        signal, has no plottable rows, or plotting itself fails.
    """
    # Debug logging of the incoming data.
    print("DataFrame received by plot_retention_data (first 5 rows):")
    print(dataframe.head())
    print("\nData types in plot_retention_data before any conversion:")
    print(dataframe.dtypes)

    # A non-empty frame with an "Error" column is an error signal, not data.
    if "Error" in dataframe.columns and not dataframe.empty:
        error_message = dataframe['Error'].iloc[0]
        print(f"Error DataFrame received: {error_message}", file=sys.stderr)
        return _message_figure(f"Error from data generation: {error_message}")

    try:
        if 'percent_retained' not in dataframe.columns:
            raise ValueError("'percent_retained' column is missing from the DataFrame.")
        if 'collection_day' not in dataframe.columns:
            raise ValueError("'collection_day' column is missing from the DataFrame.")

        # Work on a copy so the caller's DataFrame is left untouched.
        plot_df = dataframe.copy()

        # Plotly needs a numeric y column and a datetime x column.
        # Values that cannot be parsed as numbers become NaN.
        plot_df['percent_retained'] = pd.to_numeric(plot_df['percent_retained'], errors='coerce')
        plot_df['collection_day'] = pd.to_datetime(plot_df['collection_day'])

        # Drop rows that have no usable value in either plotted column.
        plot_df = plot_df.dropna(subset=['percent_retained', 'collection_day'])

        print("\n'percent_retained' column after pd.to_numeric (first 5 values):")
        print(plot_df['percent_retained'].head())
        print("'percent_retained' dtype after pd.to_numeric:", plot_df['percent_retained'].dtype)
        print("\n'collection_day' column after pd.to_datetime (first 5 values):")
        print(plot_df['collection_day'].head())
        print("'collection_day' dtype after pd.to_datetime:", plot_df['collection_day'].dtype)

        if plot_df.empty:
            return _message_figure("No data available to plot after processing.")

        fig = px.bar(
            plot_df,
            x='collection_day',
            y='percent_retained',
            title='Previous Day Top 200 Trending Model Retention %',
            labels={'collection_day': 'Date', 'percent_retained': 'Retention Rate (%)'},
            text='percent_retained'  # Shown on the bars via texttemplate below
        )

        # Value labels inside the bars plus a custom hover box.
        fig.update_traces(
            texttemplate='%{text:.2f}%',
            textposition='inside',
            insidetextanchor='middle',
            textfont_color='white',
            textfont_size=10,
            hovertemplate='<b>Date</b>: %{x|%Y-%m-%d}<br>' +
                          '<b>Retention</b>: %{y:.2f}%<extra></extra>'
        )

        # Dashed reference line at the mean retention. plot_df is known to be
        # non-empty and NaN-free in this column at this point.
        average_retention = plot_df['percent_retained'].mean()
        fig.add_hline(
            y=average_retention,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Average: {average_retention:.2f}%",
            annotation_position="bottom right"
        )

        fig.update_xaxes(tickangle=45)
        fig.update_layout(
            title_x=0.5,  # Center the title
            xaxis_title="Date",
            yaxis_title="Retention Rate (%)",
            plot_bgcolor='white',
            bargap=0.2
        )
        return fig
    except Exception as e:
        # Any failure while preparing or drawing is shown in the plot area
        # rather than propagating into the UI layer.
        print(f"Error during plot_retention_data: {e}", file=sys.stderr)
        return _message_figure(f"Plotting Error: {str(e)}")
def interface_fn(start_date, end_date):
    """Gradio callback: query retention for the date range, then plot it.

    Both arguments are forwarded unchanged to ``get_retention_data``; its
    result is passed to ``plot_retention_data`` and that return value is
    returned.
    """
    return plot_retention_data(get_retention_data(start_date, end_date))
# Default the two date inputs to the full range covered by the dataset.
# NOTE(review): datetime.fromisoformat needs str input, so this presumably
# relies on 'collected_at' holding ISO-format strings; confirm.
_collected_at = df['collected_at']
min_date, max_date = (
    datetime.fromisoformat(stamp).date()
    for stamp in (_collected_at.min(), _collected_at.max())
)

# Free-text date inputs, pre-filled with the dataset's first and last day.
_start_box = gr.Textbox(
    label="Start Date (YYYY-MM-DD)",
    value=min_date.strftime("%Y-%m-%d"),
)
_end_box = gr.Textbox(
    label="End Date (YYYY-MM-DD)",
    value=max_date.strftime("%Y-%m-%d"),
)

# Wire the inputs to the query-and-plot callback and start the app.
iface = gr.Interface(
    fn=interface_fn,
    inputs=[_start_box, _end_box],
    outputs=gr.Plot(label="Model Retention Visualization"),
    title="Model Retention Analysis",
    description="Visualize model retention rates over time. Enter dates in YYYY-MM-DD format.",
)
iface.launch()