# Source: sheami/tests/generate_test_data.py
# Author: vikramvasudevan — "Upload folder using huggingface_hub" (commit 2fcea48, verified)
"""
scripts/generate_test_data.py
Generates realistic test data for Sheami using your modules.db.SheamiDB API.
Behavior:
- Creates N users (default 100)
- Each user: 3-5 patients (default range; configurable via min_patients/max_patients)
- Each patient: 2-6 reports
- Each report: 3-6 tests drawn from TEST_POOL
- For each patient we write trends (per test) using add_or_update_trend
- For each patient we write a final report using add_final_report
Usage:
pip install faker pymongo python-dotenv
MONGODB_URI="mongodb+srv://<user>:<pass>@cluster0.xxxxx.mongodb.net" \
MONGODB_DB="sheami" \
python scripts/generate_test_data.py --num-users 100
The script CALLS THESE EXACT methods on your SheamiDB:
- add_user(email, name)
- add_patient(user_id, name, dob, gender)
- add_report(patient_id, file_name, parsed_data)
- add_or_update_trend(patient_id, test_name, trend_data)
- add_final_report(patient_id, summary, recommendations, trend_snapshots)
"""
import argparse
import random
from collections import defaultdict
from datetime import datetime, timedelta
import os
from faker import Faker
from dotenv import load_dotenv
# Load environment variables from a .env file (python-dotenv) before importing the
# DB wrapper; the CLI defaults read MONGODB_URI / MONGODB_DB via os.getenv.
# NOTE(review): keep this above `from modules.db import SheamiDB` in case
# modules.db reads configuration at import time — not verifiable from here.
load_dotenv()
# import your DB wrapper
from modules.db import SheamiDB
# ---------- Config & test pool ----------
# Shared Faker instance used for user/patient names and e-mails; generate_test_data
# calls Faker.seed(seed) when a seed is supplied.
faker = Faker()
# Lab tests that generated reports draw from:
#   name -> (low, high, unit, reference_range)
# `low`/`high` bound the uniformly sampled value and are often wider than the
# printed reference range, so some generated results fall outside it.
# make_test_values emits floats rounded to 2 decimals when either bound is a float,
# and ints otherwise, so write bounds as floats for fractional tests.
# Key order matters for seeded runs: tests are sampled from list(TEST_POOL.items()).
TEST_POOL = {
    "Hemoglobin": (11.0, 17.5, "g/dL", "11.0-17.5"),
    "Glucose (Fasting)": (60, 130, "mg/dL", "70-99 fasting"),
    "Total Cholesterol": (120, 300, "mg/dL", "<200 desirable"),
    "Triglycerides": (40, 300, "mg/dL", "<150 normal"),
    "HDL": (30, 90, "mg/dL", ">40 desirable"),
    "LDL": (50, 200, "mg/dL", "<100 ideal"),
    "Creatinine": (0.5, 1.8, "mg/dL", "0.5-1.2"),
    "Urea (BUN)": (7, 30, "mg/dL", "7-20"),
    "Sodium": (130, 150, "mmol/L", "135-145"),
    "Potassium": (3.2, 5.2, "mmol/L", "3.5-5.0"),
    "ALT": (7, 55, "U/L", "<45"),
    "AST": (8, 48, "U/L", "<40"),
}
def random_date_between(start_year=2019):
    """Pick a random calendar day between 1 January of ``start_year`` and today.

    The result is a naive ``datetime`` at midnight. The upper bound is the number of
    whole days elapsed since the start, and both ends are inclusive.

    Raises:
        ValueError: (from ``random.randint``) when ``start_year`` lies in the future.
    """
    window_start = datetime(start_year, 1, 1)
    span_days = (datetime.now() - window_start).days
    offset = timedelta(days=random.randint(0, span_days))
    return window_start + offset
def make_test_values(k, pool=None):
    """Return ``k`` random lab results matching the ``parsed_data["tests"]`` schema.

    Args:
        k: Number of distinct tests to draw. Values larger than the pool size are
            clamped to the pool size (every test is returned once) instead of raising.
        pool: Mapping ``name -> (low, high, unit, reference_range)`` to draw from.
            Defaults to the module-level ``TEST_POOL``.

    Returns:
        A list of dicts with keys ``name``, ``value``, ``unit`` and ``reference_range``.
        ``value`` is a float rounded to 2 decimals when either bound is a float,
        otherwise an int.

    Raises:
        ValueError: if ``k`` is negative (propagated from ``random.sample``).
    """
    if pool is None:
        pool = TEST_POOL
    candidates = list(pool.items())
    # random.sample raises when k exceeds the population size; clamp so a caller
    # asking for more tests than exist simply gets every test once.
    chosen = random.sample(candidates, k=min(k, len(candidates)))
    tests = []
    for name, (low, high, unit, ref) in chosen:
        # Float bounds -> 2-decimal float values; integer bounds -> integer values.
        if isinstance(low, float) or isinstance(high, float):
            value = round(random.uniform(low, high), 2)
        else:
            value = int(round(random.uniform(low, high)))
        tests.append({
            "name": name,
            "value": value,
            "unit": unit,
            "reference_range": ref,
        })
    return tests
def compute_direction(points):
    """Classify the most recent movement of a trend series.

    Only the last two entries are compared, using their ``"value"`` keys.
    Returns ``"increasing"``, ``"decreasing"``, or ``"stable"``. ``"stable"`` is
    also returned when fewer than two points exist or the values are equal.
    """
    if len(points) < 2:
        return "stable"
    previous, latest = points[-2]["value"], points[-1]["value"]
    if latest > previous:
        return "increasing"
    return "decreasing" if latest < previous else "stable"
# ---------- Generator function ----------
def _check_range(label, low, high):
    """Raise ValueError unless ``0 <= low <= high`` (validates a min/max knob pair)."""
    if low < 0 or low > high:
        raise ValueError(
            f"invalid {label} range: min={low}, max={high} (need 0 <= min <= max)"
        )


def _random_dob(min_age=18, max_age=85):
    """Return a naive datetime date of birth for someone aged min_age..max_age today."""
    now = datetime.now()
    age = random.randint(min_age, max_age)
    try:
        anniversary = now.replace(year=now.year - age)
    except ValueError:
        # Today is 29 February and the target year is not a leap year.
        anniversary = now.replace(year=now.year - age, day=28)
    # Step back 0-364 extra days: the person has turned `age` but not yet `age + 1`.
    # (Plain `365 * age` days ignores leap days and can yield someone aged age - 1.)
    return anniversary - timedelta(days=random.randint(0, 364))


def _build_final_report(patient_name, sorted_trends):
    """Return (summary, recommendations, trend_snapshots) from date-sorted trend points."""
    trend_snapshots = [
        {
            "test_name": test_name,
            "latest_value": points[-1]["value"],
            "direction": compute_direction(points),
        }
        for test_name, points in sorted_trends.items()
    ]
    summary = f"Auto-generated summary for {patient_name} ({len(trend_snapshots)} tests)"
    # Simple heuristic: if any test is trending up, recommend a follow-up.
    if any(ts["direction"] == "increasing" for ts in trend_snapshots):
        recommendations = ["Follow up for rising values"]
    else:
        recommendations = ["Continue routine monitoring"]
    return summary, recommendations, trend_snapshots


def generate_test_data(db_uri: str, db_name: str, num_users: int = 100,
                       min_patients=3, max_patients=5,
                       min_reports=2, max_reports=6,
                       min_tests=3, max_tests=6,
                       seed: int = None):
    """Populate a Sheami database with synthetic users, patients, reports, trends and final reports.

    Args:
        db_uri: MongoDB connection URI passed to ``SheamiDB``.
        db_name: Database name passed to ``SheamiDB``.
        num_users: Number of users to create.
        min_patients, max_patients: Inclusive bounds on patients per user.
        min_reports, max_reports: Inclusive bounds on reports per patient.
        min_tests, max_tests: Inclusive bounds on tests per report. ``max_tests``
            may not exceed ``len(TEST_POOL)``.
        seed: Optional seed for ``random`` and Faker. Dates are computed relative to
            ``datetime.now()``, so equal seeds only reproduce equal dates on the same day.

    Returns:
        Dict of created-record counts keyed by "users", "patients", "reports",
        "trends" and "final_reports".

    Raises:
        ValueError: if a min/max range is negative or inverted, or ``max_tests``
            exceeds the number of tests in ``TEST_POOL``.
    """
    # Validate before touching the DB. Otherwise a bad range only fails inside
    # random.randint / random.sample after some users have already been written.
    _check_range("patients", min_patients, max_patients)
    _check_range("reports", min_reports, max_reports)
    _check_range("tests", min_tests, max_tests)
    if max_tests > len(TEST_POOL):
        raise ValueError(
            f"max_tests={max_tests} exceeds the {len(TEST_POOL)} tests in TEST_POOL"
        )

    if seed is not None:
        random.seed(seed)
        Faker.seed(seed)
    db = SheamiDB(db_uri, db_name=db_name)
    counters = {"users": 0, "patients": 0, "reports": 0, "trends": 0, "final_reports": 0}

    for u_idx in range(num_users):
        # Name before e-mail: keeps the Faker draw order of seeded runs unchanged.
        user_name = faker.name()
        user_email = faker.unique.safe_email()
        user_id = db.add_user(email=user_email, name=user_name)
        counters["users"] += 1

        for _ in range(random.randint(min_patients, max_patients)):
            patient_name = faker.name()
            dob_str = _random_dob().strftime("%Y-%m-%d")
            gender = random.choice(["male", "female", "other"])
            patient_id = db.add_patient(user_id=user_id, name=patient_name,
                                        dob=dob_str, gender=gender)
            counters["patients"] += 1

            # Trend points collected per test name across this patient's reports.
            trends_map = defaultdict(list)
            for _ in range(random.randint(min_reports, max_reports)):
                report_date = random_date_between().strftime("%Y-%m-%d")
                tests = make_test_values(random.randint(min_tests, max_tests))
                parsed_data = {"tests": tests, "report_date": report_date}
                file_name = f"report_{report_date.replace('-', '')}_{random.randint(1000, 9999)}.pdf"
                db.add_report(patient_id=patient_id, file_name=file_name,
                              parsed_data=parsed_data)
                counters["reports"] += 1
                for t in tests:
                    trends_map[t["name"]].append({"date": report_date, "value": t["value"]})

            # ISO "YYYY-MM-DD" strings sort chronologically; sort once, reuse below.
            sorted_trends = {
                test_name: sorted(points, key=lambda p: p["date"])
                for test_name, points in trends_map.items()
            }
            for test_name, points in sorted_trends.items():
                db.add_or_update_trend(patient_id=patient_id, test_name=test_name,
                                       trend_data=points)
                counters["trends"] += 1

            summary, recommendations, trend_snapshots = _build_final_report(
                patient_name, sorted_trends
            )
            db.add_final_report(patient_id=patient_id,
                                summary=summary,
                                recommendations=recommendations,
                                trend_snapshots=trend_snapshots)
            counters["final_reports"] += 1

        # Occasional progress print.
        if (u_idx + 1) % 10 == 0 or (u_idx + 1) == num_users:
            print(f"Created {u_idx+1}/{num_users} users so far...")

    print("Generation complete. Summary:")
    for k, v in counters.items():
        print(f" {k}: {v}")
    return counters
# ---------- CLI ----------
if __name__ == "__main__":
    # Command-line entry point: parse options and run generate_test_data once.
    parser = argparse.ArgumentParser(description="Generate test data for Sheami (matches your db.py).")
    parser.add_argument("--num-users", type=int, default=100, help="Number of users to create")
    parser.add_argument("--db-uri", type=str, default=os.getenv("MONGODB_URI", "mongodb://localhost:27017"),
                        help="MongoDB connection URI")
    parser.add_argument("--db-name", type=str, default=os.getenv("MONGODB_DB", "sheami"),
                        help="Database name")
    parser.add_argument("--seed", type=int, default=None, help="Random seed (optional)")
    # Per-entity ranges; defaults mirror generate_test_data's own keyword defaults
    # so omitting them preserves the previous behavior.
    parser.add_argument("--min-patients", type=int, default=3, help="Minimum patients per user")
    parser.add_argument("--max-patients", type=int, default=5, help="Maximum patients per user")
    parser.add_argument("--min-reports", type=int, default=2, help="Minimum reports per patient")
    parser.add_argument("--max-reports", type=int, default=6, help="Maximum reports per patient")
    parser.add_argument("--min-tests", type=int, default=3, help="Minimum tests per report")
    parser.add_argument("--max-tests", type=int, default=6, help="Maximum tests per report")
    args = parser.parse_args()
    generate_test_data(db_uri=args.db_uri, db_name=args.db_name,
                       num_users=args.num_users,
                       min_patients=args.min_patients, max_patients=args.max_patients,
                       min_reports=args.min_reports, max_reports=args.max_reports,
                       min_tests=args.min_tests, max_tests=args.max_tests,
                       seed=args.seed)