import os

cache_dir = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)

from gradual.models import GradualInput, GradualOutput
# from gradual.computations import compute_gradual_semantics
from gradual.computations import compute_gradual_space
from aba.aba_builder import prepare_aba_plus_framework, build_aba_framework_from_text
from relations.predict_bert import predict_relation
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
import torch
import pandas as pd
from pathlib import Path
import asyncio
import json
import io
from aba.models import (
    RuleDTO,
    FrameworkSnapshot,
    TransformationStep,
    ABAApiResponseModel,
    ABAPlusDTO,
    MetaInfo,
)
from copy import deepcopy
from datetime import datetime

# -------------------- Config -------------------- #
ABA_EXAMPLES_DIR = Path("./aba/examples")
SAMPLES_DIR = Path("./relations/examples/samples")
GRADUAL_EXAMPLES_DIR = Path("./gradual/examples")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "edgar-demeude/bert-argument"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.to(device)

# -------------------- App -------------------- #
app = FastAPI(title="Argument Mining API")

origins = ["http://localhost:3000", "http://127.0.0.1:3000"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -------------------- Endpoints -------------------- #
def root():
    return {"message": "Argument Mining API is running..."}


# --- Predictions --- #
def predict_text(arg1: str = Form(...), arg2: str = Form(...)):
    """Predict relation between two text arguments using BERT."""
    result = predict_relation(arg1, arg2, model, tokenizer, device)
    return {"arg1": arg1, "arg2": arg2, "relation": result}


async def predict_csv_stream(file: UploadFile):
    """Stream CSV predictions progressively using SSE."""
    content = await file.read()
    df = pd.read_csv(io.StringIO(content.decode("utf-8")), quotechar='"')
    if len(df) > 250:
        df = df.head(250)

    async def event_generator():
        total = len(df)
        completed = 0
        for _, row in df.iterrows():
            try:
                result = predict_relation(
                    row["parent"], row["child"], model, tokenizer, device)
                completed += 1
                payload = {
                    "parent": row["parent"],
                    "child": row["child"],
                    "relation": result,
                    "progress": completed / total
                }
                yield f"data: {json.dumps(payload)}\n\n"
                # Yield control to the event loop so the chunk is flushed immediately
                await asyncio.sleep(0)
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e), 'parent': row.get('parent'), 'child': row.get('child')})}\n\n"
                await asyncio.sleep(0)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
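
# Consumption sketch (illustrative only): event_generator emits Server-Sent Events,
# one "data: <json>\n\n" frame per CSV row (the CSV needs "parent" and "child" columns).
# The route path and host below are hypothetical, not taken from this file.
#
#   import json, requests
#   with requests.post("http://localhost:8000/predict-csv",  # hypothetical route
#                      files={"file": open("pairs.csv", "rb")}, stream=True) as r:
#       for line in r.iter_lines():
#           if line.startswith(b"data: "):
#               print(json.loads(line[len(b"data: "):]))  # parent, child, relation, progress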


def list_samples():
    files = [f for f in os.listdir(SAMPLES_DIR) if f.endswith(".csv")]
    return {"samples": files}


def get_sample(filename: str):
    file_path = os.path.join(SAMPLES_DIR, filename)
    if not os.path.exists(file_path):
        return {"error": "Sample not found"}
    return FileResponse(file_path, media_type="text/csv")


# --- ABA --- #
def _make_snapshot(fw) -> FrameworkSnapshot:
    return FrameworkSnapshot(
        language=[str(l) for l in sorted(fw.language, key=str)],
        assumptions=[str(a) for a in sorted(fw.assumptions, key=str)],
        rules=[
            RuleDTO(
                id=r.rule_name,
                head=str(r.head),
                body=[str(b) for b in sorted(r.body, key=str)],
            )
            for r in sorted(fw.rules, key=lambda r: r.rule_name)
        ],
        contraries=[
            (str(c.contraried_literal), str(c.contrary_attacker))
            for c in sorted(fw.contraries, key=str)
        ],
        preferences={
            str(k): [str(v) for v in sorted(vals, key=str)]
            for k, vals in (fw.preferences or {}).items()
        } if getattr(fw, "preferences", None) else None,
    )


def _format_set(s) -> str:
    # s may be a Python set/frozenset of Literal or strings.
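    # e.g. _format_set(frozenset({"b", "a"})) -> "{a,b}" (items stringified, sorted, comma-joined)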
    if isinstance(s, str):
        # s is already a formatted string like "{a,b}"
        return s
    try:
        items = sorted([str(x) for x in s], key=str)
    except Exception:
        # fallback for anything that is not iterable
        return str(s)
    return "{" + ",".join(items) + "}"


async def _process_aba_framework(
    text: str,
    enable_aba_plus: bool = False,
) -> dict:
    """
    Core processing logic for ABA frameworks.

    Args:
        text: The uploaded file content as text
        enable_aba_plus: If True, compute ABA+ elements

    Returns:
        Complete response with before/after snapshots and all computations
    """
    # === 1. Build original framework ===
    base_framework = build_aba_framework_from_text(text)
    base_framework.generate_arguments()
    base_framework.generate_attacks()
    original_snapshot = _make_snapshot(base_framework)

    # --- Classical (argument-level) data ---
    original_arguments = [str(arg) for arg in sorted(base_framework.arguments, key=str)]
    original_attacks = [str(att) for att in sorted(base_framework.attacks, key=str)]
    original_reverse_attacks = []

    # === 2. Transform framework ===
    transformed_framework = deepcopy(base_framework).transform_aba()
    transformations = _detect_transformations(base_framework, transformed_framework)

    # --- Initialize containers ---
    original_assumption_sets = []
    final_assumption_sets = []
    original_aba_plus_attacks = []
    final_aba_plus_attacks = []
    original_reverse_attacks = []
    final_reverse_attacks = []
    warnings = []

    # === 3. ABA+ computations ===
    if enable_aba_plus:
        # --- ABA+ on original framework ---
        fw_plus_original = prepare_aba_plus_framework(deepcopy(base_framework))
        fw_plus_original.generate_arguments()
        fw_plus_original.generate_attacks()
        fw_plus_original.make_aba_plus()

        original_assumption_sets = sorted(
            [_format_set(s) for s in getattr(fw_plus_original, "assumption_combinations", [])],
            key=lambda x: (len(x), x),
        )
        original_aba_plus_attacks = [
            f"{_format_set(src)} → {_format_set(dst)}"
            for (src, dst) in sorted(
                getattr(fw_plus_original, "normal_attacks", []),
                key=lambda p: (str(p[0]), str(p[1])),
            )
        ]
        original_reverse_attacks = [
            f"{_format_set(src)} → {_format_set(dst)}"
            for (src, dst) in sorted(
                getattr(fw_plus_original, "reverse_attacks", []),
                key=lambda p: (str(p[0]), str(p[1])),
            )
        ]

        # --- Ensure transformed framework is consistent before ABA+ ---
        transformed_framework.generate_arguments()
        transformed_framework.generate_attacks()

        # --- Compute ABA+ on transformed framework ---
        fw_plus_transformed = prepare_aba_plus_framework(deepcopy(transformed_framework))
        fw_plus_transformed.generate_arguments()
        fw_plus_transformed.generate_attacks()
        fw_plus_transformed.make_aba_plus()

        final_assumption_sets = sorted(
            [_format_set(s) for s in getattr(fw_plus_transformed, "assumption_combinations", [])],
            key=lambda x: (len(x), x),
        )

        # Debug sanity checks
        print("DEBUG: fw_plus_transformed.assumptions =", getattr(fw_plus_transformed, "assumptions", []))
        print("DEBUG: fw_plus_transformed.normal_attacks =", getattr(fw_plus_transformed, "normal_attacks", []))

        final_aba_plus_attacks = [
            f"{_format_set(src)} → {_format_set(dst)}"
            for (src, dst) in sorted(
                getattr(fw_plus_transformed, "normal_attacks", []),
                key=lambda p: (str(p[0]), str(p[1])),
            )
        ]
        final_reverse_attacks = [
            f"{_format_set(src)} → {_format_set(dst)}"
            for (src, dst) in sorted(
                getattr(fw_plus_transformed, "reverse_attacks", []),
                key=lambda p: (str(p[0]), str(p[1])),
            )
        ]

        warnings = _validate_aba_plus_framework(fw_plus_transformed)
    else:
        warnings = _validate_framework(transformed_framework)

    # === 4. Classical ABA computations (arguments + attacks) ===
    base_framework.generate_arguments()
    base_framework.generate_attacks()
    transformed_framework.generate_arguments()
    transformed_framework.generate_attacks()

    original_arguments = [str(arg) for arg in sorted(base_framework.arguments, key=str)]
    original_arguments_attacks = [str(att) for att in sorted(base_framework.attacks, key=str)]
    final_arguments = [str(arg) for arg in sorted(transformed_framework.arguments, key=str)]
    final_arguments_attacks = [str(att) for att in sorted(transformed_framework.attacks, key=str)]

    # === 5. Snapshots ===
    original_snapshot = _make_snapshot(base_framework)
    final_snapshot = _make_snapshot(transformed_framework)
    # === 6. Build response ===
    before_state = {
        "framework": original_snapshot.dict(),
        "arguments": original_arguments,
        "arguments_attacks": original_arguments_attacks,
        "argument_attacks": original_arguments_attacks,  # same as arguments_attacks for classical ABA
        "assumption_set_attacks": original_aba_plus_attacks if enable_aba_plus else [],
        "reverse_attacks": original_reverse_attacks if enable_aba_plus else [],
        "assumption_sets": original_assumption_sets if enable_aba_plus else [],
    }

    after_state = {
        "framework": final_snapshot.dict(),
        "arguments": final_arguments,
        "arguments_attacks": final_arguments_attacks,
        "argument_attacks": final_arguments_attacks,  # same as arguments_attacks for classical ABA
        "assumption_set_attacks": final_aba_plus_attacks if enable_aba_plus else [],
        "reverse_attacks": final_reverse_attacks if enable_aba_plus else [],
        "assumption_sets": final_assumption_sets if enable_aba_plus else [],
    }
    transformation_dicts = [_transform_to_dict(t) for t in transformations]

    response = {
        "meta": {
            "request_id": f"req-{datetime.utcnow().timestamp()}",
            "timestamp": datetime.utcnow().isoformat(),
            "transformed": any(t.get("applied", False) for t in transformation_dicts),
            "transformations_applied": [
                t.get("step") for t in transformation_dicts if t.get("applied", False)
            ],
            "warnings": warnings,
            "errors": [],
        },
        "before_transformation": before_state,
        "after_transformation": after_state,
        "transformations": transformation_dicts,
    }
    return response


def _detect_transformations(
    base_framework,
    transformed_framework,
) -> list:
    """
    Detect and describe which transformations were applied.
    """
    transformations = []

    if transformed_framework.language == base_framework.language and \
       transformed_framework.rules == base_framework.rules:
        # No transformation needed
        transformations.append({
            "step": "none",
            "applied": False,
            "reason": "The framework was already non-circular and atomic.",
            "description": "No transformation applied.",
            "result_snapshot": None,
        })
        return transformations

    # Determine transformation type
    was_circular = base_framework.is_aba_circular()
    was_atomic = base_framework.is_aba_atomic()

    step_name = "non_circular" if was_circular else "atomic"
    reason = "circular dependencies" if was_circular else "non-atomic rules"

    transformations.append({
        "step": step_name,
        "applied": True,
        "reason": f"The framework contained {reason}.",
        "description": f"Transformed into a {step_name.replace('_', '-')} version.",
        "result_snapshot": _make_snapshot(transformed_framework),
    })
    return transformations


def _transform_to_dict(t):
    """Convert TransformationStep to dict if needed."""
    if isinstance(t, dict):
        return t
    return {
        "step": t.step,
        "applied": t.applied,
        "reason": t.reason,
        "description": t.description,
        "result_snapshot": t.result_snapshot,
    }


def _validate_framework(framework) -> list[str]:
    """
    Validate framework and return any warnings.
    """
    warnings = []
    if hasattr(framework, "preferences") and framework.preferences:
        all_assumptions = {str(a) for a in framework.assumptions}
        pref_keys = {str(k) for k in framework.preferences.keys()}
        if not pref_keys.issubset(all_assumptions):
            warnings.append(
                "Incomplete preference relation: not all assumptions appear in the preference mapping."
            )
    return warnings


def _validate_aba_plus_framework(framework) -> list[str]:
    """
    Validate ABA+ framework and return any warnings.
    """
    return _validate_framework(framework)


async def aba_upload(file: UploadFile = File(...)):
    """
    Handle classical ABA framework generation.
    Returns: original & final frameworks with arguments and attacks (no ABA+ data)
    """
    content = await file.read()
    text = content.decode("utf-8")
    return await _process_aba_framework(text, enable_aba_plus=False)


async def aba_plus_upload(file: UploadFile = File(...)):
    """
    Handle ABA+ framework generation.
    Returns: original & final frameworks with arguments, attacks, AND reverse_attacks for both
    """
    content = await file.read()
    text = content.decode("utf-8")
    return await _process_aba_framework(text, enable_aba_plus=True)
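
# Upload sketch (illustrative only): both handlers expect a UTF-8 text file in the
# format understood by build_aba_framework_from_text. The route path and host below
# are hypothetical, not taken from this file.
#
#   import requests
#   with open("framework.txt", "rb") as f:
#       resp = requests.post("http://localhost:8000/aba-plus", files={"file": f})  # hypothetical route
#   print(resp.json()["after_transformation"]["assumption_sets"])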


def list_aba_examples():
    examples = [f.name for f in ABA_EXAMPLES_DIR.glob("*.txt")]
    return {"examples": examples}


def get_aba_example(filename: str):
    file_path = ABA_EXAMPLES_DIR / filename
    if not file_path.exists() or not file_path.is_file():
        return {"error": "File not found"}
    return FileResponse(file_path, media_type="text/plain", filename=filename)


# --- Gradual semantics --- #
# @app.post("/gradual", response_model=GradualOutput)
# def compute_gradual(input_data: GradualInput):
#     """API endpoint to compute Weighted h-Categorizer samples and convex hull."""
#     return compute_gradual_semantics(
#         A=input_data.A,
#         R=input_data.R,
#         n_samples=input_data.n_samples,
#         max_iter=input_data.max_iter
#     )
@app.post("/gradual", response_model=GradualOutput)
def compute_gradual(input_data: GradualInput):
    """
    API endpoint to compute Weighted h-Categorizer samples
    and their convex hull (acceptability degree space).
    """
    num_args, hull_volume, hull_area, hull_points, samples, axes = compute_gradual_space(
        num_args=input_data.num_args,
        R=input_data.R,
        n_samples=input_data.n_samples,
        axes=input_data.axes,
        controlled_args=input_data.controlled_args,
    )
    return GradualOutput(
        num_args=num_args,
        hull_volume=hull_volume,
        hull_area=hull_area,
        hull_points=hull_points,
        samples=samples,
        axes=axes,
    )
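
# Request sketch (illustrative only): "num_args" and "R" follow the example-file
# structure documented in list_gradual_examples below; "n_samples" is assumed to be
# an integer, and the shapes of "axes" / "controlled_args" are not shown in this
# file, so they are omitted here.
#
#   payload = {"num_args": 3, "R": [["A", "B"], ["B", "C"], ["C", "A"]], "n_samples": 500}
#   # e.g. requests.post("http://localhost:8000/gradual", json=payload)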


def list_gradual_examples():
    """
    List all available gradual semantics example files.
    Each example must be a JSON file with structure:
    {
        "num_args": 3,
        "R": [["A", "B"], ["B", "C"], ["C", "A"]]
    }
    """
    if not GRADUAL_EXAMPLES_DIR.exists():
        return {"examples": []}

    examples = []
    for file in GRADUAL_EXAMPLES_DIR.glob("*.json"):
        examples.append({
            "name": file.stem,
            "path": file.name,
            "content": None
        })
    return {"examples": examples}


def get_gradual_example(example_name: str):
    """
    Return the content of a specific gradual example file.
    Example: GET /gradual-examples/simple.json
    """
    file_path = GRADUAL_EXAMPLES_DIR / example_name
    if not file_path.exists():
        raise HTTPException(status_code=404, detail="Example not found")

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = json.load(f)
        return JSONResponse(content=content)
    except json.JSONDecodeError:
        raise HTTPException(
            status_code=400, detail="Invalid JSON format in example file")