ThomasSimonini's picture
Adding str accuracy
57059b9
raw history blame
No virus
1.86 kB
import pandas as pd
import requests
from tqdm.auto import tqdm
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.repocard import metadata_load
# Based on Omar Sanseviero's work
def make_clickable_model(model_name):
    """Return an HTML anchor that links to a model's Hugging Face page.

    The link target is ``https://huggingface.co/<model_name>`` and opens in a
    new tab. The visible text drops the leading ``user/`` namespace, so
    ``"user/model"`` is shown as ``model``. Any further ``/`` separators are
    shown as spaces, as before. Ids without a namespace (e.g. ``"gpt2"``) are
    shown in full rather than producing an empty, invisible link.

    Args:
        model_name: Hub model id, e.g. ``"user/model"``.

    Returns:
        An ``<a target="_blank" ...>`` HTML string.
    """
    _, sep, rest = model_name.partition('/')
    # No '/' means there is no namespace to strip: show the whole id.
    model_name_show = ' '.join(rest.split('/')) if sep else model_name
    link = "https://huggingface.co/" + model_name
    return f'<a target="_blank" href="{link}">{model_name_show}</a>'
def make_clickable_user(user_id):
    """Build an HTML link to a user's Hugging Face profile.

    The anchor opens in a new tab and uses ``user_id`` itself as its text.

    Args:
        user_id: Hub user or organisation name.

    Returns:
        An ``<a target="_blank" ...>`` HTML string.
    """
    profile_url = "https://huggingface.co/" + user_id
    return '<a target="_blank" href="{}">{}</a>'.format(profile_url, user_id)
def get_model_ids(rl_env):
    """List the ids of Hub models matched by ``list_models(filter=rl_env)``.

    Args:
        rl_env: value passed straight through as the ``filter`` argument of
            ``HfApi.list_models`` (presumably an environment tag; confirm
            against callers).

    Returns:
        A list of model id strings, in the order the API yields them.

    NOTE(review): newer huggingface_hub releases expose the id as ``.id``;
    confirm ``.modelId`` is available in the pinned version.
    """
    return [model_info.modelId for model_info in HfApi().list_models(filter=rl_env)]
def get_metadata(model_id):
    """Download a model's README.md and return its parsed card metadata.

    Args:
        model_id: Hub model id whose ``README.md`` should be fetched.

    Returns:
        Whatever ``metadata_load`` returns for the downloaded README, or
        ``None`` when the download fails with any
        ``requests.exceptions.HTTPError`` (e.g. a 404 because the repo has
        no README.md).
    """
    card_metadata = None
    try:
        local_readme = hf_hub_download(model_id, filename="README.md")
        card_metadata = metadata_load(local_readme)
    except requests.exceptions.HTTPError:
        # Deliberate best-effort: a missing/unreachable README means "no metadata".
        card_metadata = None
    return card_metadata
def parse_metrics_accuracy(meta):
    """Extract the first reported metric value from model-card metadata.

    Reads ``meta["model-index"][0]["results"][0]["metrics"][0]["value"]``.

    Args:
        meta: metadata mapping (as produced by ``get_metadata``), or ``None``
            when no metadata could be loaded.

    Returns:
        The raw metric value (often a ``"mean +/- std"`` string), or ``None``
        if ``meta`` is empty/``None`` or any level of the path is missing or
        empty.
    """
    # get_metadata() returns None on download failure; treat it as "no metric".
    if not meta or "model-index" not in meta:
        return None
    try:
        return meta["model-index"][0]["results"][0]["metrics"][0]["value"]
    except (KeyError, IndexError, TypeError):
        # Malformed or incomplete model-index: same outcome as having none.
        return None
def parse_rewards(accuracy):
    """Parse a ``"mean +/- std"`` reward value into ``(mean_reward, std_reward)``.

    ``accuracy`` is converted with ``str()`` first, so non-string values are
    accepted. The text is split on ``"+/-"``. Whitespace around either number
    is tolerated because ``float()`` strips it. The first two fields become
    the mean and the standard deviation.

    If ``accuracy`` is ``None``, contains no ``"+/-"`` separator (e.g. a bare
    number), or either field is not a valid float, both values fall back to
    the sentinel ``-1000.0``.

    NOTE(review): the former header comment ("We keep the worst case
    episode") suggests callers combine these values into a pessimistic
    score; that logic is not part of this function.

    Args:
        accuracy: reward value, typically a string such as
            ``"200.5 +/- 10.2"``, or ``None`` when no metric was found.

    Returns:
        Tuple ``(mean_reward, std_reward)`` of floats.
    """
    default_reward = -1000.0
    default_std = -1000.0
    if accuracy is None:
        return default_reward, default_std
    parsed = str(accuracy).split('+/-')
    if len(parsed) < 2:
        return default_reward, default_std
    try:
        mean_reward = float(parsed[0])
        std_reward = float(parsed[1])
    except ValueError:
        # Non-numeric fields are treated like a missing metric instead of crashing.
        return default_reward, default_std
    return mean_reward, std_reward