Spaces:
Sleeping
Sleeping
| """ | |
| ClauseGuard API — HuggingFace Spaces Deployment | |
| Loads Legal-BERT from Hub, serves clause classification. | |
| """ | |
| import os | |
| import time | |
| import re | |
| from contextlib import asynccontextmanager | |
| from typing import Optional | |
| import httpx | |
| import numpy as np | |
| from fastapi import FastAPI, HTTPException, Depends, Header | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
# ─── Config ───
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-legal-bert")
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY", "")

# Category names. The list index is the category id used as the key of the
# regex PATTERNS table, so the order here must not change.
LABEL_NAMES = [
    "Limitation of liability", "Unilateral termination", "Unilateral change",
    "Content removal", "Contract by using", "Choice of law", "Jurisdiction", "Arbitration",
]

# Plain-language, user-facing explanation of each category.
LABEL_DESCRIPTIONS = {
    "Limitation of liability": "Company limits or excludes liability for losses, data breaches, or service failures.",
    "Unilateral termination": "Company can terminate your account at any time without reason.",
    "Unilateral change": "Company can change terms at any time without your consent.",
    "Content removal": "Company can delete your content without notice or justification.",
    "Contract by using": "You are bound to the contract simply by using the service.",
    "Choice of law": "Governing law may differ from your country, reducing your legal protections.",
    "Jurisdiction": "Disputes must be resolved in a jurisdiction that may disadvantage you.",
    "Arbitration": "Forces disputes to arbitration instead of court. You waive your right to sue.",
}

# Severity bucket per category; the analyze handler weights HIGH/MEDIUM/LOW
# when computing the overall risk score.
SEVERITY_MAP = {
    "Limitation of liability": "HIGH", "Unilateral termination": "HIGH", "Arbitration": "HIGH",
    "Unilateral change": "MEDIUM", "Content removal": "MEDIUM", "Choice of law": "MEDIUM",
    "Jurisdiction": "MEDIUM", "Contract by using": "LOW",
}

# User-facing legal provisions cited by the explain handler. Only cite
# provisions that are actually in force. Arbitration points to Annex 1(q) of
# the Unfair Terms Directive, which names compulsory arbitration explicitly.
LEGAL_BASIS = {
    "Arbitration": "EU Directive 93/13/EEC Art. 3 and Annex 1(q) — requiring disputes to go exclusively to arbitration.",
    "Unilateral change": "EU Directive 93/13/EEC Annex 1(j) — unilateral alteration.",
    "Content removal": "EU Digital Services Act Art. 17 — statement of reasons required.",
    "Jurisdiction": "EU Regulation 1215/2012 Art. 18 — consumer domicile prevails.",
    "Choice of law": "EU Regulation 593/2008 Art. 6 — consumer protection of habitual residence.",
    "Limitation of liability": "EU Directive 93/13/EEC Annex 1(a)(b) — excluding or limiting liability or the consumer's legal rights.",
    "Unilateral termination": "EU Directive 93/13/EEC Annex 1(f)(g) — termination without notice.",
    "Contract by using": "EU Directive 2011/83/EU Art. 8 — active consent required.",
}
# ─── ML Model ───
# Transformers text-classification pipeline. It stays None when loading
# fails, and classify_clause then uses the regex fallback.
classifier = None


def load_model():
    """Load the Legal-BERT classifier from the Hub into the module-level ``classifier``.

    Any failure (``transformers`` not installed, download or model errors, ...)
    is printed and swallowed so the API can still start. In that case
    ``classifier`` keeps its previous value (``None`` at startup).
    """
    global classifier
    try:
        # Lazy import: the app must still start when transformers is unavailable.
        from transformers import pipeline

        print(f"Loading model from Hub: {HUB_MODEL_ID}")
        # top_k=None: return scores for every label (filtered later by
        # classify_clause); device=-1: run on CPU.
        classifier = pipeline("text-classification", model=HUB_MODEL_ID, top_k=None, device=-1)
        print("Model loaded successfully")
    except Exception as e:
        print(f"Model load failed: {e} — using regex fallback")
# ─── Regex fallback ───
# Keyword regexes per category id (the id indexes LABEL_NAMES), listed in id
# order so that enumerate() assigns ids 0..7. They are matched against the
# lower-cased clause text when the ML classifier is unavailable.
PATTERNS = dict(enumerate([
    # 0 — Limitation of liability
    [r"not liable", r"shall not be (liable|responsible)", r"in no event.*liable",
     r"limitation of liability", r"without warranty", r"disclaim"],
    # 1 — Unilateral termination
    [r"terminat.*at any time", r"suspend.*account.*without",
     r"we may (terminat|suspend|discontinu)", r"right to (terminat|suspend)"],
    # 2 — Unilateral change
    [r"sole discretion", r"reserves? the right to (modify|change|update|amend)",
     r"at any time.*without (prior )?notice", r"we may (modify|change|update)"],
    # 3 — Content removal
    [r"remove.*content.*without", r"right to remove", r"we may.*remove"],
    # 4 — Contract by using
    [r"by (using|accessing).*you agree", r"continued use.*constitutes? acceptance"],
    # 5 — Choice of law
    [r"governed by.*laws? of", r"shall be governed", r"laws of the state of"],
    # 6 — Jurisdiction
    [r"exclusive jurisdiction", r"courts? of.*(california|delaware|new york|ireland|england)",
     r"submit to.*jurisdiction"],
    # 7 — Arbitration
    [r"arbitrat", r"binding arbitration", r"waive.*right.*court", r"class action waiver"],
]))
def classify_clause(text: str, threshold: float = 0.5) -> list[dict]:
    """Return the unfair-clause categories detected in ``text``.

    Uses the ML ``classifier`` when it is loaded. If it is not loaded, or the
    pipeline call raises, the keyword regexes in PATTERNS are used instead.

    Args:
        text: a single clause.
        threshold: minimum model score (exclusive) for a label to be
            reported. Applies to the ML path only; defaults to 0.5.

    Returns:
        One dict per detected category with keys ``name``, ``severity``,
        ``description`` and ``confidence``. Regex matches always get
        confidence 0.7. Labels the model emits that are not in
        LABEL_DESCRIPTIONS are dropped.
    """
    if classifier is not None:
        try:
            preds = classifier(text, truncation=True, max_length=512)
            # The pipeline may return either [[{label, score}, ...]] or a flat
            # [{label, score}, ...] for a single input; handle both shapes.
            items = preds[0] if isinstance(preds[0], list) else preds
            return [
                {
                    "name": p["label"],
                    "severity": SEVERITY_MAP.get(p["label"], "MEDIUM"),
                    "description": LABEL_DESCRIPTIONS.get(p["label"], ""),
                    "confidence": round(p["score"], 3),
                }
                for p in items
                if p["score"] > threshold and p["label"] in LABEL_DESCRIPTIONS
            ]
        except Exception as e:
            # Best-effort: keep serving via the regex fallback, but make the
            # model failure visible instead of silently swallowing it.
            print(f"Model inference failed: {e} — using regex fallback")

    text_lower = text.lower()
    results = []
    for lid, pats in PATTERNS.items():
        # One hit per category is enough; any() stops at the first match.
        if any(re.search(p, text_lower) for p in pats):
            name = LABEL_NAMES[lid]
            results.append({
                "name": name,
                "severity": SEVERITY_MAP[name],
                "description": LABEL_DESCRIPTIONS[name],
                "confidence": 0.7,
            })
    return results
# ─── Auth (simplified for HF Spaces — no Supabase dependency required) ───
async def get_optional_user(authorization: Optional[str] = Header(None)) -> Optional[dict]:
    """FastAPI dependency for optional authentication. Currently a stub.

    No token is parsed or validated. Every request is treated as anonymous,
    so this returns ``None`` whether or not an ``Authorization`` header was
    sent.
    """
    # TODO: validate the bearer token in `authorization` and return the
    # user's claims. Until that exists, never trust the header: returning
    # None keeps every caller on the anonymous path.
    return None
# ─── Supabase helpers ───
async def supabase_insert(table: str, data: dict, timeout: float = 5.0) -> None:
    """Insert one row into ``table`` through Supabase's REST API (``/rest/v1/<table>``).

    Does nothing when SUPABASE_URL or SUPABASE_SERVICE_KEY is not configured.
    The service-role key is sent both as ``apikey`` and as a bearer token, and
    a minimal response body is requested.

    Args:
        table: target table name, appended to the REST path as-is.
        data: JSON-serialisable row to insert.
        timeout: request timeout in seconds. 5.0 matches httpx's default.

    An HTTP error status (4xx/5xx) is printed rather than raised, so a failed
    write does not break the caller. Transport failures (connection errors,
    timeouts) still propagate as httpx exceptions.
    """
    if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
        return
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(
            f"{SUPABASE_URL}/rest/v1/{table}",
            json=data,
            headers={
                "apikey": SUPABASE_SERVICE_KEY,
                "Authorization": f"Bearer {SUPABASE_SERVICE_KEY}",
                "Content-Type": "application/json",
                "Prefer": "return=minimal",
            },
        )
        if resp.is_error:
            # Log rather than raise: inserts are best-effort, but failures must be visible.
            print(f"Supabase insert into {table!r} failed: HTTP {resp.status_code} {resp.text[:200]}")
# ─── Models ───
# Request body for clause analysis.
class AnalyzeRequest(BaseModel):
    # Clause texts to classify: at least 1 and at most 500 items (list-length
    # bounds under pydantic v2). No per-clause length limit is applied here.
    clauses: list[str] = Field(..., min_length=1, max_length=500)
    # Optional URL the clauses were taken from; the analyze handler in this
    # file does not read it.
    source_url: Optional[str] = None
# Response body built by the analyze handler.
class AnalyzeResponse(BaseModel):
    risk_score: int  # aggregate risk, capped to 0-100; higher is riskier
    grade: str  # one of "A", "B", "C", "D", "F", derived from risk_score
    total_clauses: int  # number of clauses submitted
    flagged_count: int  # clauses with at least one detected category
    results: list[dict]  # per clause: {"text": <clause>, "categories": [<category dicts>]}
    model: str  # "ml" when the transformer classifier is loaded, otherwise "regex"
    latency_ms: int  # time spent classifying and scoring, in milliseconds
# Request body for explaining a flagged clause.
class ExplainRequest(BaseModel):
    # The clause text, 10-2000 characters.
    clause: str = Field(..., min_length=10, max_length=2000)
    # Expected to be one of LABEL_NAMES. Other values are accepted; the
    # explain handler then returns generic fallback text.
    category: str
# Response body built by the explain handler.
class ExplainResponse(BaseModel):
    clause: str  # clause text, echoed from the request
    category: str  # category, echoed from the request
    explanation: str  # plain-language description of the category
    legal_basis: str  # cited legal provision(s) for the category
    recommendation: str  # generic advice shown to the user
# ─── App ───
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: load the classifier once at startup.

    ``load_model()`` runs synchronously before the app starts serving
    requests. Nothing is torn down on shutdown.
    """
    # FastAPI's `lifespan=` expects an async context-manager factory; a bare
    # async generator is only accepted through a deprecated compatibility path.
    load_model()
    yield
app = FastAPI(
    title="ClauseGuard API",
    description="AI-powered unfair clause detection. Send contract clauses, get risk scores.",
    version="1.0.0",
    lifespan=lifespan,
)

# Browser origins allowed to call the API, read from the comma-separated
# CORS_ALLOW_ORIGINS env var. The default "*" allows any origin.
_CORS_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    # Credentialed CORS (cookies) must not be combined with a wildcard origin
    # (FastAPI/Starlette CORS docs), so it is enabled only for an explicit
    # origin allow-list. Bearer tokens sent in the Authorization header do
    # not need it.
    allow_credentials="*" not in _CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def root():
    """Describe the service and report which classifier backend is active."""
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on `app` (presumably as GET "/").
    backend = "ml" if classifier else "regex"
    return {"name": "ClauseGuard API", "status": "running", "model": backend, "docs": "/docs"}
async def health():
    """Report service status and which classifier backend ("ml" or "regex") is active."""
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on `app`.
    payload = {"status": "ok"}
    payload["model"] = "ml" if classifier else "regex"
    return payload
async def analyze(req: AnalyzeRequest):
    """Classify every clause in the request and compute an overall risk grade.

    Each clause goes through ``classify_clause``. The response lists the
    categories found per clause plus aggregate statistics.

    Risk score: every detected category adds a severity weight (HIGH=20,
    MEDIUM=10, LOW=5). The summed weight is divided by the number of clauses,
    multiplied by 100, rounded and capped at 100. Grade: F >= 60, D >= 40,
    C >= 20, B >= 10, otherwise A.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on `app`.
    # NOTE(review): classification is synchronous and runs on the event-loop
    # thread, so a large batch blocks other requests until it finishes.
    start = time.perf_counter()  # monotonic clock, unaffected by wall-clock changes
    results = [{"text": c, "categories": classify_clause(c)} for c in req.clauses]
    flagged = [r for r in results if r["categories"]]

    sev = {"HIGH": 0, "MEDIUM": 0, "LOW": 0}
    for r in flagged:
        for c in r["categories"]:
            sev[c.get("severity", "LOW")] += 1

    total = len(req.clauses)
    # NOTE(review): the x100 after the per-clause average saturates quickly;
    # e.g. one HIGH hit among five clauses gives 400 -> capped to 100.
    # Confirm this scaling is intended.
    weighted = sev["HIGH"] * 20 + sev["MEDIUM"] * 10 + sev["LOW"] * 5
    risk = min(100, round(weighted / max(1, total) * 100))
    grade = "F" if risk >= 60 else "D" if risk >= 40 else "C" if risk >= 20 else "B" if risk >= 10 else "A"
    latency = int((time.perf_counter() - start) * 1000)

    return AnalyzeResponse(
        risk_score=risk,
        grade=grade,
        total_clauses=total,
        flagged_count=len(flagged),
        results=results,
        model="ml" if classifier else "regex",
        latency_ms=latency,
    )
async def explain(req: ExplainRequest):
    """Explain a flagged clause: category description, legal basis and advice.

    Categories missing from the lookup tables get generic fallback text
    rather than an error.
    """
    # NOTE(review): no route decorator is visible for this handler; confirm it
    # is registered on `app`.
    category = req.category
    return ExplainResponse(
        clause=req.clause,
        category=category,
        explanation=LABEL_DESCRIPTIONS.get(category, "Unknown category."),
        legal_basis=LEGAL_BASIS.get(category, "Consult local consumer protection laws."),
        recommendation="Review this clause carefully. Consider negotiating or seeking legal advice.",
    )