# Author: Zekun Wu — commit 0c50b28 ("add")
# File-view header (raw | history | blame | contribute | delete): "No virus", 4.83 kB
import streamlit as st
import pandas as pd
from model_util import get_mode_from_hf, calculate_flops_in_hugging_space
def calculate_flops_architecture(num_params, num_tokens, *, threshold_flops=10 ** 25):
    """
    Estimate training compute from model size and training-set size.

    Uses the common ``6 * N * D`` approximation of total training FLOPs.

    Args:
        num_params: Number of model parameters (N).
        num_tokens: Number of training tokens (D).
        threshold_flops: Compute threshold the estimate is compared against.
            Defaults to 10**25 FLOPs.

    Returns:
        dict with ``"total_flops"`` (the estimate) and ``"meets_criteria"``
        (True only when the estimate is strictly greater than the threshold).
    """
    total_flops = 6 * num_params * num_tokens
    return {
        "total_flops": total_flops,
        # Strictly greater-than: an estimate exactly at the threshold does not qualify.
        "meets_criteria": total_flops > threshold_flops,
    }
def calculate_flops_gpu(gpu_hours, power_consumption_w, flops_per_gpu_s, *,
                        utilization=1, threshold_flops=10 ** 25):
    """
    Estimate training compute and energy use from accumulated GPU time.

    Args:
        gpu_hours: Total GPU hours spent on training.
        power_consumption_w: Power draw of a single GPU, in watts.
        flops_per_gpu_s: Throughput of a single GPU, in FLOPs per second.
        utilization: Fraction of ``flops_per_gpu_s`` actually sustained, in
            (0, 1]. The default of 1 assumes the full throughput for the whole
            run, which is an upper bound when ``flops_per_gpu_s`` is a peak
            figure.
        threshold_flops: Compute threshold the estimate is compared against.
            Defaults to 10**25 FLOPs.

    Returns:
        dict with ``"total_energy_wh"`` (watts x hours), ``"total_flops"`` and
        ``"meets_criteria"`` (True only when ``total_flops`` is strictly
        greater than the threshold).

    Raises:
        ValueError: If ``utilization`` is not in (0, 1].
    """
    if not 0 < utilization <= 1:
        raise ValueError(f"utilization must be in (0, 1], got {utilization!r}")

    # Energy assumes every GPU draws power_consumption_w for all gpu_hours,
    # independent of utilization.
    total_energy_wh = gpu_hours * power_consumption_w
    # 3600 converts GPU hours to GPU seconds to match the per-second throughput.
    total_flops = gpu_hours * flops_per_gpu_s * 3600 * utilization
    return {
        "total_energy_wh": total_energy_wh,
        "total_flops": total_flops,
        "meets_criteria": total_flops > threshold_flops,
    }
def calculate_flops_hf(model_name, input_shape, access_token, bp_factor, *, threshold_flops=10 ** 25):
    """
    Calculate FLOPs for a Hugging Face model using the ``model_util`` helpers.

    The model is loaded with ``get_mode_from_hf`` (``library="auto"``) and its
    FLOPs are measured by ``calculate_flops_in_hugging_space`` for
    ``input_shape``.

    NOTE(review): the ``'Forward+Backward FLOPs'`` value appears to describe
    one forward+backward pass over ``input_shape``, not cumulative training
    compute, so comparing it to a training-compute threshold likely
    understates the total. Confirm against ``model_util``.

    Args:
        model_name: Model repository id, e.g. ``"tiiuae/falcon-40b"``.
        input_shape: Input shape forwarded to the helper; the UI passes
            ``(batch_size, max_seq_length)``.
        access_token: Access token forwarded to both helpers.
        bp_factor: Backward-pass factor forwarded to the helper.
        threshold_flops: Compute threshold the result is compared against.
            Defaults to 10**25 FLOPs.

    Returns:
        dict with ``"total_flops"`` (``'Forward+Backward FLOPs'`` of the first
        entry of the helper's data), ``"meets_criteria"`` (strictly greater
        than the threshold), ``"dataframe"`` (the helper's data, unchanged) and
        ``"return_print"`` (the helper's textual report).
    """
    import logging

    model = get_mode_from_hf(model_name=model_name, library="auto", access_token=access_token)
    data, return_print = calculate_flops_in_hugging_space(
        model_name=model_name,
        empty_model=model,
        access_token=access_token,
        input_shape=input_shape,
        bp_factor=bp_factor,
    )
    # Debug-level log instead of an unconditional print of the whole result.
    logging.getLogger(__name__).debug("Data: %s", data)

    total_flops = data[0]['Forward+Backward FLOPs']
    return {
        "total_flops": total_flops,
        "meets_criteria": total_flops > threshold_flops,
        "dataframe": data,
        "return_print": return_print,
    }
st.title("FLOPs Calculator for EU AI Act High Impact Capabilities")

# Choose calculation method
calculation_method = st.selectbox(
    "Choose Calculation Method:",
    ["Model Architecture", "GPU Hours and Type", "Hugging Face Model"],
)

if calculation_method == "Model Architecture":
    num_params = st.number_input("Number of Parameters (N):", min_value=0.0, value=float(7.0 * 10 ** 9), step=1.0)
    num_tokens = st.number_input("Number of Training Tokens (D):", min_value=0.0, value=float(1500 * 10 ** 9), step=1.0)
    if st.button("Calculate FLOPs (Model Architecture)"):
        result = calculate_flops_architecture(num_params, num_tokens)
        st.write(f"Total FLOPs: {result['total_flops']:.2e} FLOPs")
        st.write(f"Meets high impact capabilities criteria: {result['meets_criteria']}")

elif calculation_method == "GPU Hours and Type":
    # Per-GPU throughput (FLOPs/s) and power draw (W). The throughput values
    # are NVIDIA's published dense FP16/BF16 Tensor Core peaks, so results
    # assume 100% utilization and are an upper bound.
    gpu_types = {
        "Nvidia A100": {"flops_per_s": 312 * 10 ** 12, "power": 400},
        "Nvidia V100": {"flops_per_s": 125 * 10 ** 12, "power": 300},
        # Dense BF16 peak of the 700 W H100 SXM; 1,979 TFLOPs is the
        # with-sparsity figure and is not comparable to the other rows.
        "Nvidia H100": {"flops_per_s": 989 * 10 ** 12, "power": 700},
        "Nvidia T4": {"flops_per_s": 65 * 10 ** 12, "power": 70},
    }
    gpu_type = st.selectbox("Select GPU type:", list(gpu_types.keys()))
    flops_per_gpu_s = gpu_types[gpu_type]["flops_per_s"]
    power_consumption_w = gpu_types[gpu_type]["power"]
    gpu_hours = st.number_input("Total GPU hours used for training:", min_value=0.0, value=float(7.7 * 10 ** 6),
                                step=1.0)
    if st.button("Calculate FLOPs (GPU Hours and Type)"):
        result = calculate_flops_gpu(gpu_hours, power_consumption_w, flops_per_gpu_s)
        st.write(f"Total energy consumption: {result['total_energy_wh']:.2e} Wh")
        st.write(f"Total FLOPs: {result['total_flops']:.2e} FLOPs")
        st.write(f"Meets high impact capabilities criteria: {result['meets_criteria']}")

elif calculation_method == "Hugging Face Model":
    model_name = st.text_input("Model Name:", "tiiuae/falcon-40b")
    batch_size = st.number_input("Batch Size:", min_value=1, value=1)
    max_seq_length = st.number_input("Max Sequence Length:", min_value=1, value=128)
    input_shape = (batch_size, max_seq_length)
    # Masked input so the token is not displayed in clear text on screen.
    access_token = st.text_input("Hugging Face Access Token:", "", type="password")
    # Default 2.0 = backward pass costs about twice the forward pass, the same
    # assumption behind the 6 * N * D estimate of the architecture method.
    # NOTE(review): assumes model_util treats bp_factor as "backward as a
    # multiple of forward" — confirm.
    bp_factor = st.number_input("Backward Pass Factor (BP Factor):", min_value=0.0, value=2.0, step=0.1)
    if st.button("Calculate FLOPs (Hugging Face Model)"):
        try:
            result = calculate_flops_hf(model_name, input_shape, access_token, bp_factor)
        except Exception as exc:  # UI boundary: model download/analysis can fail in many ways.
            st.error(f"Could not calculate FLOPs for '{model_name}': {exc}")
        else:
            st.write(f"Total FLOPs: {result['total_flops']:.2e} FLOPs")
            st.write(f"Meets high impact capabilities criteria: {result['meets_criteria']}")
            st.write("Detailed FLOPs Data:")
            st.dataframe(result["dataframe"])
            st.text(result["return_print"])