File size: 2,472 Bytes
edd4fd8
ce76bed
 
e3677e5
 
ce76bed
edd4fd8
a505b19
ce76bed
 
 
 
 
a505b19
 
2d7820e
a505b19
 
 
 
 
 
 
 
 
 
 
 
 
 
edd4fd8
 
 
 
 
 
 
 
 
ce76bed
edd4fd8
 
 
ce76bed
edd4fd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce76bed
 
 
 
 
e3677e5
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import csv
from io import BytesIO
from pathlib import Path
from sys import stderr
from traceback import print_exception
from zipfile import BadZipFile, ZipFile

import requests
from yaml import safe_load

# Directory containing this module; used to locate files shipped next to it.
CURRENT_DIR = Path(__file__).parent

# Parsed contents of prompts.yaml, loaded once at import time. `get_prompt`
# indexes this by prompt key and calls `.format(...)` on the value.
# The file is decoded explicitly as UTF-8 so loading does not depend on the
# platform's default locale encoding.
_PROMPTS = safe_load(
    CURRENT_DIR.joinpath("prompts.yaml").read_text(encoding="utf-8")
)


def fetch_task_attachment(api_url: str, task_id: str) -> tuple[bytes, str]:
    """
    Download the attachment of a task from ``{api_url}/files/{task_id}``.

    Returns (file_bytes, content_type) with the content type lower-cased, or
    (b'', '') if the request fails or the response status is not 200.
    Follows any redirect the endpoint issues.
    """
    # Tolerate a trailing slash on the base URL so we never request "//files/".
    url = f"{api_url.rstrip('/')}/files/{task_id}"
    try:
        r = requests.get(url, timeout=15, allow_redirects=True)
    except requests.RequestException as e:
        # Best-effort: network problems are reported and treated as "no file".
        print(f"[DEBUG] GET {url} failed → {e}")
        return b"", ""
    if r.status_code != 200:
        print(f"[DEBUG] GET {url} → {r.status_code}")
        return b"", ""
    return r.content, r.headers.get("content-type", "").lower()


def sniff_excel_type(blob: bytes) -> str:
    """
    Return one of 'xlsx', 'xls', 'csv', or '' (unknown) given raw bytes.

    Detection order:
      1. ZIP container holding ``xl/workbook.xml``  -> 'xlsx'
      2. OLE Compound File signature                -> 'xls'
      3. Text whose first line contains a ',', ';' or tab and which
         ``csv.Sniffer`` accepts                    -> 'csv'

    Empty input and anything unrecognised yield ''.
    """
    # 1️⃣ XLSX / XLSM  (OOXML workbook inside a ZIP container)
    if blob[:4] == b"PK\x03\x04":
        try:
            with ZipFile(BytesIO(blob)) as zf:
                # "[Content_Types].xml" exists in every OOXML package (docx,
                # pptx, ...), so only the workbook part identifies a
                # spreadsheet.
                if "xl/workbook.xml" in zf.namelist():
                    return "xlsx"
                # A readable ZIP archive without a workbook is neither a
                # spreadsheet nor CSV; don't let the text heuristic below
                # guess on compressed bytes.
                return ""
        except BadZipFile:
            pass  # fall through: only the magic bytes matched

    # 2️⃣ Legacy XLS (OLE Compound File)
    # NOTE: this signature identifies the container only; the contents are
    # not inspected here.
    if blob[:8] == b"\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1":
        return "xls"

    # 3️⃣ Text-like -> CSV/TSV
    # errors="ignore" drops undecodable bytes, so decoding itself cannot fail.
    sample = blob[:1024].decode("utf-8", "ignore")
    lines = sample.splitlines()
    # `lines` is empty for empty input; guard before looking at the first line.
    if lines and any(sep in lines[0] for sep in (",", ";", "\t")):
        try:
            # Confirm via csv.Sniffer to avoid random text
            csv.Sniffer().sniff(sample)
        except csv.Error:
            return ""
        return "csv"

    return ""


def get_prompt(prompt_key: str, **kwargs: str) -> str:
    """
    Look up the prompt stored under `prompt_key` and return it with its
    placeholders substituted via `.format(**kwargs)`.
    """
    template = _PROMPTS[prompt_key]
    filled = template.format(**kwargs)
    return filled


def print_debug_trace(err: Exception, label: str = "") -> None:
    """
    Write a banner, the full stack trace of `err`, and a separator line to
    STDERR (so it shows up in HF logs).

    The banner is "[TRACE <label>]" when `label` is non-empty, otherwise
    "[TRACE]".
    """
    header = "[TRACE]"
    if label:
        header = f"[TRACE {label}]"
    separator = "-" * 60

    print(header, file=stderr)
    print_exception(type(err), err, err.__traceback__, file=stderr)
    print(separator, file=stderr)