import os
import json
import io
import tempfile
from datetime import datetime

import streamlit as st
import pandas as pd
from huggingface_hub import HfApi, create_repo, hf_hub_download, list_repo_files

from storage.storage_interface import StorageProvider
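
# Storage layout used by this provider inside the dataset repo:
#   reports/<report_id>.json   one JSON file per report (form data + machine-readable output)
#   reports_index.json         lightweight metadata index, kept newest-first
#   reports.parquet            flat Parquet mirror of all reports for bulk access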


class HuggingFaceStorageProvider(StorageProvider):
    """
    A storage provider that works directly with the Hugging Face Hub API,
    without using the datasets library or local caching.
    """

    def __init__(self, hf_token=None, repo_id=None):
        """Initialize the provider with a Hub token and target repo ID."""
        self.hf_token = hf_token or self._get_token()
        self.repo_id = repo_id or self._get_repo_id()
        self.initialized = False
        self.api = None
        self.parquet_file = "reports.parquet"

    def _get_token(self):
        """Get the HF token from environment variables, Streamlit secrets, or a Space fallback."""
        # 1. Environment variable takes precedence.
        token = os.getenv("HF_TOKEN")
        if token:
            return token

        # 2. Streamlit secrets, checking common capitalizations. Accessing
        # st.secrets raises when no secrets file is configured, so guard it.
        try:
            for secret_name in ["hf_token", "HF_TOKEN", "Hf_Token"]:
                if secret_name in st.secrets:
                    return st.secrets[secret_name]
        except Exception:
            pass

        # 3. When running inside a Hugging Face Space, try the read-only token.
        if os.getenv("SPACE_ID"):
            token = os.getenv("HF_TOKEN_READ")
            if token:
                return token

        return None

    def _get_repo_id(self):
        """Get the repo ID from environment variables, secrets, or construct one from the Space name."""
        repo_id = os.getenv("HF_REPO_ID")
        if repo_id:
            return repo_id

        # Check Streamlit secrets under common capitalizations.
        try:
            for secret_name in ["hf_repo_id", "HF_REPO_ID", "Hf_Repo_Id"]:
                if secret_name in st.secrets:
                    return st.secrets[secret_name]
        except Exception:
            pass

        # Inside a Space, SPACE_ID is "<username>/<space-name>"; reuse the username.
        space_id = os.getenv("SPACE_ID")
        if space_id:
            username = space_id.split("/")[0]
            return f"{username}/ai-flaw-reports"

        return None

    def initialize(self):
        """Initialize the Hugging Face API client and ensure the dataset repo exists."""
        try:
            if not self.hf_token:
                # No token available; callers fall back to session-state storage.
                return False

            if not self.repo_id or self.repo_id == "default-user/ai-flaw-reports":
                # No usable repo configured.
                return False

            self.api = HfApi(token=self.hf_token)

            # Verify the dataset repo exists; create it (private) if it doesn't.
            try:
                self.api.repo_info(
                    repo_id=self.repo_id,
                    repo_type="dataset"
                )
            except Exception:
                try:
                    create_repo(
                        repo_id=self.repo_id,
                        token=self.hf_token,
                        repo_type="dataset",
                        private=True
                    )
                except Exception:
                    return False

            self.initialized = True
            return True

        except Exception:
            return False

    def save_report(self, form_data):
        """Save a report to the Hugging Face repository (or session state as a fallback)."""
        # Ensure every report has an ID.
        report_id = form_data.get("Report ID")
        if not report_id:
            report_id = f"report-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
            form_data["Report ID"] = report_id

        # If the Hub is unavailable, keep the report in Streamlit session state.
        if not self.initialized:
            if not self.initialize():
                session_key = f"report_{report_id}"
                st.session_state[session_key] = {
                    "form_data": form_data,
                    "timestamp": datetime.now().isoformat()
                }
                return f"session_state:{session_key}", None

        # Defined before the try block so the exception handler can reference it.
        machine_readable_output = None
        try:
            # Generate the machine-readable representation; tolerate failures.
            try:
                from form.data.schema import generate_machine_readable_output
                machine_readable_output = generate_machine_readable_output(form_data)
            except Exception:
                machine_readable_output = None

            report_data = {
                "form_data": form_data,
                "machine_readable": machine_readable_output,
                "timestamp": datetime.now().isoformat()
            }
            report_json = json.dumps(report_data, indent=2)
            report_path = f"reports/{report_id}.json"

            # Make sure the reports/ directory exists in the repo.
            files = list_repo_files(
                repo_id=self.repo_id,
                repo_type="dataset",
                token=self.hf_token
            )
            if not any(f.startswith("reports/") for f in files):
                self.api.upload_file(
                    path_or_fileobj=io.BytesIO(b""),
                    path_in_repo="reports/.gitkeep",
                    repo_id=self.repo_id,
                    repo_type="dataset",
                    commit_message="Create reports directory"
                )

            # Upload the report itself.
            self.api.upload_file(
                path_or_fileobj=io.BytesIO(report_json.encode()),
                path_in_repo=report_path,
                repo_id=self.repo_id,
                repo_type="dataset",
                commit_message=f"Add/update report {report_id}"
            )

            # Keep the JSON index and the Parquet mirror in sync.
            self._update_index_file(report_id, form_data)
            self._update_parquet_file(report_id, form_data, machine_readable_output)

            return f"huggingface:{self.repo_id}/{report_path}", machine_readable_output

        except Exception:
            # On any Hub error, fall back to session state so the report isn't lost.
            session_key = f"report_{report_id}"
            st.session_state[session_key] = {
                "form_data": form_data,
                "timestamp": datetime.now().isoformat(),
                "machine_readable": machine_readable_output
            }
            return f"session_state:{session_key}", machine_readable_output

    def _update_index_file(self, report_id, form_data):
        """Update the JSON index file with the new report's metadata."""
        index_path = "reports_index.json"

        # Load the existing index; start fresh if it doesn't exist yet.
        index_data = []
        try:
            existing_index = hf_hub_download(
                repo_id=self.repo_id,
                filename=index_path,
                repo_type="dataset",
                token=self.hf_token
            )
            with open(existing_index, "r") as f:
                index_data = json.load(f)
        except Exception:
            pass

        # Replace any previous entry for this report, then append the new one.
        index_data = [r for r in index_data if r.get("report_id") != report_id]
        index_data.append({
            "report_id": report_id,
            "report_status": form_data.get("Report Status", "Unknown"),
            "report_types": form_data.get("Report Types", []),
            "reporter_id": form_data.get("Reporter ID", "Anonymous"),
            "submission_timestamp": datetime.now().isoformat(),
            "file_path": f"reports/{report_id}.json"
        })

        # Newest first.
        index_data.sort(key=lambda x: x.get("submission_timestamp", ""), reverse=True)

        index_json = json.dumps(index_data, indent=2)
        self.api.upload_file(
            path_or_fileobj=io.BytesIO(index_json.encode()),
            path_in_repo=index_path,
            repo_id=self.repo_id,
            repo_type="dataset",
            commit_message=f"Update index with report {report_id}"
        )

    def _update_parquet_file(self, report_id, form_data, machine_readable_output):
        """Update the Parquet file with the new report data."""
        try:
            # Flatten the report into one row; nested structures are stored as JSON strings.
            report_row = {
                "report_id": report_id,
                "report_status": form_data.get("Report Status", "Unknown"),
                "report_types": json.dumps(form_data.get("Report Types", [])),
                "reporter_id": form_data.get("Reporter ID", "Anonymous"),
                "submission_timestamp": datetime.now().isoformat(),
                "form_data": json.dumps(form_data),
                "machine_readable": json.dumps(machine_readable_output) if machine_readable_output else ""
            }
            new_df = pd.DataFrame([report_row])

            try:
                with tempfile.TemporaryDirectory() as tmp_dir:
                    parquet_path = os.path.join(tmp_dir, self.parquet_file)

                    # Merge with the existing Parquet file if there is one.
                    try:
                        downloaded_file = hf_hub_download(
                            repo_id=self.repo_id,
                            filename=self.parquet_file,
                            repo_type="dataset",
                            token=self.hf_token,
                            local_dir=tmp_dir,
                            local_dir_use_symlinks=False
                        )
                        existing_df = pd.read_parquet(downloaded_file)
                        # Drop any previous row for this report before appending.
                        existing_df = existing_df[existing_df["report_id"] != report_id]
                        updated_df = pd.concat([existing_df, new_df], ignore_index=True)
                    except Exception:
                        # No existing file (or it couldn't be read): start a new one.
                        updated_df = new_df

                    updated_df.to_parquet(parquet_path, index=False)

                    with open(parquet_path, "rb") as f:
                        self.api.upload_file(
                            path_or_fileobj=f,
                            path_in_repo=self.parquet_file,
                            repo_id=self.repo_id,
                            repo_type="dataset",
                            commit_message=f"Update Parquet file with report {report_id}"
                        )

            except Exception:
                # If the merge path failed entirely, upload a fresh file with just this report.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    parquet_path = os.path.join(tmp_dir, self.parquet_file)
                    new_df.to_parquet(parquet_path, index=False)

                    with open(parquet_path, "rb") as f:
                        self.api.upload_file(
                            path_or_fileobj=f,
                            path_in_repo=self.parquet_file,
                            repo_id=self.repo_id,
                            repo_type="dataset",
                            commit_message=f"Create new Parquet file with report {report_id}"
                        )

        except Exception:
            # Let save_report's exception handler deal with the failure.
            raise

    def get_report(self, report_id):
        """Retrieve a report from the Hugging Face repository (or session state as a fallback)."""
        if not self.initialized:
            if not self.initialize():
                session_key = f"report_{report_id}"
                if session_key in st.session_state:
                    return st.session_state[session_key]
                return None

        try:
            report_path = f"reports/{report_id}.json"
            downloaded_file = hf_hub_download(
                repo_id=self.repo_id,
                filename=report_path,
                repo_type="dataset",
                token=self.hf_token
            )
            with open(downloaded_file, "r") as f:
                report_data = json.load(f)
            return report_data

        except Exception:
            # Fall back to session state if the report isn't on the Hub.
            session_key = f"report_{report_id}"
            if session_key in st.session_state:
                return st.session_state[session_key]
            return None

    def update_report(self, report_id, form_data):
        """Update an existing report; save_report overwrites by the Report ID in form_data."""
        result, _ = self.save_report(form_data)
        return result.startswith("huggingface:")

    def list_reports(self, limit=100):
        """List reports, preferring the JSON index, then the Parquet file, then individual report files."""
        reports = []

        # Without Hub access, list whatever is held in session state.
        if not self.initialized:
            if not self.initialize():
                for key in st.session_state:
                    if key.startswith("report_"):
                        report_id = key.replace("report_", "")
                        data = st.session_state[key]
                        form_data = data.get("form_data", {})
                        reports.append({
                            "report_id": report_id,
                            "report_status": form_data.get("Report Status", "Unknown"),
                            "report_types": form_data.get("Report Types", []),
                            "reporter_id": form_data.get("Reporter ID", "Anonymous"),
                            "submission_timestamp": data.get("timestamp", "Unknown")
                        })
                return reports[:limit]

        try:
            try:
                # Fastest path: the pre-built JSON index.
                index_path = "reports_index.json"
                downloaded_index = hf_hub_download(
                    repo_id=self.repo_id,
                    filename=index_path,
                    repo_type="dataset",
                    token=self.hf_token
                )
                with open(downloaded_index, "r") as f:
                    reports = json.load(f)

            except Exception:
                # Second choice: read the Parquet mirror.
                try:
                    with tempfile.TemporaryDirectory() as tmp_dir:
                        downloaded_file = hf_hub_download(
                            repo_id=self.repo_id,
                            filename=self.parquet_file,
                            repo_type="dataset",
                            token=self.hf_token,
                            local_dir=tmp_dir,
                            local_dir_use_symlinks=False
                        )
                        df = pd.read_parquet(downloaded_file)

                        for _, row in df.iterrows():
                            # report_types is stored as a JSON string in the Parquet file.
                            if isinstance(row["report_types"], str):
                                report_types = json.loads(row["report_types"])
                            else:
                                report_types = row["report_types"]

                            reports.append({
                                "report_id": row["report_id"],
                                "report_status": row["report_status"],
                                "report_types": report_types,
                                "reporter_id": row["reporter_id"],
                                "submission_timestamp": row["submission_timestamp"]
                            })

                except Exception:
                    # Last resort: enumerate and download the individual report files.
                    files = list_repo_files(
                        repo_id=self.repo_id,
                        repo_type="dataset",
                        token=self.hf_token
                    )
                    report_files = [f for f in files if f.startswith("reports/") and f.endswith(".json")]

                    for file_path in report_files[:limit]:
                        report_id = file_path.replace("reports/", "").replace(".json", "")
                        downloaded_file = hf_hub_download(
                            repo_id=self.repo_id,
                            filename=file_path,
                            repo_type="dataset",
                            token=self.hf_token
                        )
                        with open(downloaded_file, "r") as f:
                            report_data = json.load(f)

                        form_data = report_data.get("form_data", {})
                        reports.append({
                            "report_id": report_id,
                            "report_status": form_data.get("Report Status", "Unknown"),
                            "report_types": form_data.get("Report Types", []),
                            "reporter_id": form_data.get("Reporter ID", "Anonymous"),
                            "submission_timestamp": report_data.get("timestamp", "Unknown")
                        })

            # Merge in any session-state reports that never reached the Hub.
            for key in st.session_state:
                if key.startswith("report_"):
                    report_id = key.replace("report_", "")
                    if any(r.get("report_id") == report_id for r in reports):
                        continue

                    data = st.session_state[key]
                    form_data = data.get("form_data", {})
                    reports.append({
                        "report_id": report_id,
                        "report_status": form_data.get("Report Status", "Unknown"),
                        "report_types": form_data.get("Report Types", []),
                        "reporter_id": form_data.get("Reporter ID", "Anonymous"),
                        "submission_timestamp": data.get("timestamp", "Unknown")
                    })

            # Newest first.
            reports.sort(key=lambda x: x.get("submission_timestamp", ""), reverse=True)
            return reports[:limit]

        except Exception:
            # Total failure: fall back to session-state reports only.
            reports = []
            for key in st.session_state:
                if key.startswith("report_"):
                    report_id = key.replace("report_", "")
                    data = st.session_state[key]
                    form_data = data.get("form_data", {})
                    reports.append({
                        "report_id": report_id,
                        "report_status": form_data.get("Report Status", "Unknown"),
                        "report_types": form_data.get("Report Types", []),
                        "reporter_id": form_data.get("Reporter ID", "Anonymous"),
                        "submission_timestamp": data.get("timestamp", "Unknown")
                    })
            return reports[:limit]

    def query_reports(self, query):
        """
        Query reports (simplified implementation).

        Args:
            query (str): Query string to filter reports.

        Returns:
            list: All reports; the query string is not yet applied as a filter.
        """
        all_reports = self.list_reports(limit=1000)
        return all_reports
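

# Minimal usage sketch (not part of the provider). Assumes HF_TOKEN and
# HF_REPO_ID are set in the environment (or Streamlit secrets), and that the
# form fields match the keys used above; run inside a Streamlit app so the
# session-state fallback is available. The report ID below is hypothetical.
if __name__ == "__main__":
    provider = HuggingFaceStorageProvider()

    location, _ = provider.save_report({
        "Report ID": "report-demo-001",  # omit to auto-generate a timestamped ID
        "Report Status": "Submitted",
        "Report Types": ["Flaw"],
        "Reporter ID": "demo-user",
    })
    print(f"Saved to: {location}")

    # Retrieve the same report and list recent ones.
    report = provider.get_report("report-demo-001")
    print(report["form_data"] if report else "Report not found")

    for entry in provider.list_reports(limit=5):
        print(entry["report_id"], entry["submission_timestamp"])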