File size: 11,998 Bytes
65e2c9b
 
75be629
 
 
 
ecc2c28
 
 
e2db3bf
ecc2c28
 
 
 
 
 
e2db3bf
 
ecc2c28
 
020da56
 
e2db3bf
75be629
 
 
e2db3bf
75be629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4a5925
 
 
 
 
 
 
 
 
 
 
75be629
 
e4a5925
75be629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1788c95
 
 
75be629
 
 
 
 
 
 
 
 
 
 
1788c95
 
 
 
 
 
 
 
 
 
 
 
 
 
75be629
 
1788c95
75be629
 
 
1788c95
 
75be629
 
 
1788c95
 
 
 
 
 
 
75be629
 
 
 
 
1788c95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75be629
 
1788c95
 
 
 
 
 
 
 
 
 
 
75be629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4a5925
d0d1f61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75be629
 
 
 
 
 
 
 
 
 
 
 
 
8c05719
 
 
 
 
 
e4a5925
 
 
 
 
 
75be629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4a5925
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import os
import json
import tempfile
import subprocess
from pathlib import Path

# Point HOME at /tmp so everything that writes under $HOME lands in a
# writable location (presumably a read-only container FS — TODO confirm).
os.environ["HOME"] = "/tmp"
Path("/tmp").mkdir(parents=True, exist_ok=True)

# Also make sure every Streamlit-related path points at /tmp.
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["STREAMLIT_CACHE_DIR"] = "/tmp"
os.environ["STREAMLIT_GLOBAL_DATA_DIR"] = "/tmp"
os.environ["STREAMLIT_RUNTIME_DIR"] = "/tmp"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
Path("/tmp/.streamlit").mkdir(parents=True, exist_ok=True)

# Optional: disable usage-statistics collection.
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
os.environ.setdefault("STREAMLIT_SERVER_ENABLE_CORS", "false")
os.environ.setdefault("STREAMLIT_SERVER_ENABLE_XSRF_PROTECTION", "false")

import streamlit as st
import langextract as lx  # pip: langextract[openai]

# Original page setup.
st.set_page_config(page_title="LangmyOCR (Streamlit)", layout="wide")
st.title("LangmyOCR: OCRmyPDF + LangExtract (Streamlit Demo)")
st.caption("先 OCR,后(可选)结构化抽取与交互式复核。数据仅用于会话处理。")

# ---------------- Utilities ----------------
def has_bin(name: str) -> bool:
    """Return True if executable *name* is found on PATH.

    Fix: the original spawned ``bash -lc "command -v ..."``, which requires
    bash to be installed and pays a full shell startup per call.
    ``shutil.which`` does the same PATH lookup in-process, portably and
    without any shell.
    """
    import shutil  # local import so this fix is self-contained
    return shutil.which(name) is not None

def run_ocr(pdf_file, langs: str, rotate_pages: bool, deskew: bool, clean: bool,
            optimize_level: int, force_ocr: bool, skip_text: bool, export_sidecar: bool):
    """Run OCRmyPDF on an uploaded PDF and return the result paths.

    Args:
        pdf_file: Streamlit ``UploadedFile`` (or None when nothing uploaded).
        langs: Tesseract language spec, e.g. ``"eng+chi_sim"``.
        rotate_pages / deskew / clean / force_ocr / skip_text: map 1:1 to
            the corresponding ``ocrmypdf`` flags.
        optimize_level: value for ``--optimize``.
        export_sidecar: when True, also request a plain-text sidecar file.

    Returns:
        ``(out_pdf_path, sidecar_path, preview_text)``; all three are None
        on any failure. ``sidecar_path`` is None and ``preview_text`` empty
        when no sidecar was requested/produced.
    """
    if pdf_file is None:
        st.error("请先上传 PDF。")
        return None, None, None

    if not has_bin("ocrmypdf"):
        st.error("系统未检测到 ocrmypdf,可检查 Docker/依赖安装。")
        return None, None, None

    # Rewind before reading: Streamlit reruns may have consumed the upload
    # buffer already, which would make read() return b"".
    try:
        pdf_file.seek(0)
        pdf_content = pdf_file.read()
        if not pdf_content:
            st.error("PDF 文件内容为空。")
            return None, None, None
    except Exception as e:
        st.error(f"读取 PDF 文件失败:{e}")
        return None, None, None

    # Persist the upload to disk so ocrmypdf can read it as a regular file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        tmp.write(pdf_content)
        in_path = Path(tmp.name)

    work = Path(tempfile.mkdtemp(prefix="ocr_"))
    out_pdf = work / "output_ocr.pdf"
    sidecar = work / "out.txt"

    # Build the command left-to-right instead of the original
    # ``cmd.insert(1, ...)`` dance — ocrmypdf accepts options in any order
    # before the positional input/output arguments.
    cmd = ["ocrmypdf", "--output-type", "pdfa"]  # PDF/A: archive-friendly
    if rotate_pages:
        cmd.append("--rotate-pages")
    if deskew:
        cmd.append("--deskew")
    if clean:
        cmd.append("--clean")
    cmd += ["--optimize", str(optimize_level)]
    if skip_text:
        cmd.append("--skip-text")
    if force_ocr:
        cmd.append("--force-ocr")
    if export_sidecar:
        cmd += ["--sidecar", str(sidecar)]
    cmd += ["-l", langs, str(in_path), str(out_pdf)]

    try:
        with st.status("正在执行 OCR …", expanded=False) as s:
            proc = subprocess.run(cmd, capture_output=True, text=True)
            if proc.returncode != 0:
                s.update(label="OCR 失败", state="error")
                st.error(f"OCR 失败(退出码 {proc.returncode})")
                st.code(proc.stderr[-2000:], language="bash")
                return None, None, None
            s.update(label="OCR 完成", state="complete")
    finally:
        # Fix: the delete=False input temp file was never removed in the
        # original, leaking one copy of the uploaded PDF per run.
        in_path.unlink(missing_ok=True)

    preview = ""
    sidecar_path = None
    if export_sidecar and sidecar.exists():
        sidecar_path = str(sidecar)
        try:
            preview = sidecar.read_text(encoding="utf-8", errors="ignore")[:3000]
        except Exception:
            preview = "(sidecar 文本预览读取失败)"

    return str(out_pdf), sidecar_path, preview


def run_extract(sidecar_text: str, provider: str, model_id: str, prompt: str):
    """Run LangExtract structured extraction over the OCR sidecar text.

    Args:
        sidecar_text: plain text produced by the OCR sidecar.
        provider: one of ``"None"``, ``"Gemini"``, ``"OpenAI"``.
        model_id: model identifier forwarded to LangExtract.
        prompt: user-supplied prompt from the UI.
            NOTE(review): currently ignored — a hard-coded strict prompt is
            sent instead; the parameter is kept for caller compatibility.

    Returns:
        ``(html_path, jsonl_path, status_message)``; the paths are None when
        extraction is skipped or fails.
    """
    if not sidecar_text:
        return None, None, "没有可供抽取的文本。"
    if provider == "None":
        return None, None, "未选择模型,跳过抽取。"

    # 1) Resolve the API key; enable fenced output uniformly by default.
    fence_output = True               # enabled for Gemini as well
    use_schema_constraints = False    # leave schema constraints off for now
    if provider == "Gemini":
        api_key = os.environ.get("LANGEXTRACT_API_KEY")
        if not api_key:
            return None, None, "未检测到 Gemini API Key(LANGEXTRACT_API_KEY)。"
    elif provider == "OpenAI":
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            return None, None, "未检测到 OpenAI API Key(OPENAI_API_KEY)。"
    else:
        return None, None, "未知的 provider。"

    # 2) Tightened instruction (legal-domain schema): demand a bare JSON
    #    array with no prose or markdown around it.
    strict_prompt = (
        "You are an information extraction engine. "
        "Extract legal entities, events, relationships, and evidence anchors from the input text. "
        "Return ONLY a JSON array, no prose, no markdown, no comments. "
        "Schema per item: {"
        "\"class\": one of [\"party\",\"event\",\"date\",\"relation\",\"evidence\"], "
        "\"text\": string (exact span), "
        "\"attributes\": object (key-value), "
        "\"source_hint\": string (optional page/line) "
        "}."
    )

    # 3) Minimal runnable few-shot example, close to the legal use case.
    examples = [
        lx.data.ExampleData(
            text="On 15 February 2022, Dr Gavin Soo completed a medicolegal report to Walker Law Group.",
            extractions=[
                lx.data.Extraction(
                    extraction_class="party",
                    extraction_text="Walker Law Group",
                    attributes={"role": "law_firm"},
                ),
                lx.data.Extraction(
                    extraction_class="event",
                    extraction_text="completed a medicolegal report",
                    attributes={"actor": "Dr Gavin Soo"},
                ),
                lx.data.Extraction(
                    extraction_class="date",
                    extraction_text="15 February 2022",
                    attributes={}
                ),
            ],
        )
    ]

    # 4) First attempt; on failure retry once with a harder prompt.
    work = Path(tempfile.mkdtemp(prefix="lx_"))
    jsonl_path = work / "extractions.jsonl"
    html_path = work / "review.html"
    # Fix: removed the unused raw_attempt1/raw_attempt2 paths — they were
    # allocated but never written anywhere.

    def _try_extract(prompt_text):
        # LangExtract exposes no raw-output parameter; truncate the input
        # and let any exception propagate to the retry logic below.
        return lx.extract(
            text_or_documents=sidecar_text[:15000],  # cap length up front to avoid over-long inputs
            prompt_description=prompt_text.strip(),
            examples=examples,
            model_id=model_id.strip(),
            api_key=api_key,
            fence_output=fence_output,
            use_schema_constraints=use_schema_constraints,
        )

    with st.status("正在进行结构化抽取 …", expanded=False) as s:
        try:
            result = _try_extract(strict_prompt)
        except Exception:
            # First attempt failed — most likely non-JSON output. Retry with
            # an even more insistent "only a JSON array" instruction.
            hard_prompt = strict_prompt + " Output must be a compact JSON array. Do not include any other text."
            try:
                result = _try_extract(hard_prompt)
            except Exception as e2:
                s.update(label="抽取失败", state="error")
                return None, None, f"LangExtract 抽取失败:{e2}"

        # Save results and render the review visualization.
        try:
            lx.io.save_annotated_documents([result], output_name=str(jsonl_path))
            html_content = lx.visualize(str(jsonl_path))
            html_path.write_text(html_content, encoding="utf-8")
        except Exception as e:
            s.update(label="可视化失败", state="error")
            return None, None, f"可视化失败:{e}"

        s.update(label="抽取完成", state="complete")

    return str(html_path), str(jsonl_path), "抽取成功。"


# ---------------- UI ----------------
# Sidebar: one form bundles "upload + parameters + submit" so that a button
# rerun does not drop the file_uploader value.
with st.sidebar:
    st.header("参数")

    # Bundle upload + params + submit in a single form to avoid losing the
    # uploaded file across Streamlit reruns.
    with st.form("run_form", clear_on_submit=False):
        pdf = st.file_uploader("上传扫描 PDF", type=["pdf"], accept_multiple_files=False, key="pdf_uploader")

        langs = st.text_input("OCR 语言(Tesseract 语法)", value="eng+chi_sim")
        col_a, col_b, col_c = st.columns(3)
        with col_a:
            rotate_pages = st.checkbox("自动旋转校正", value=True)
        with col_b:
            deskew = st.checkbox("去偏斜", value=True)
        with col_c:
            clean = st.checkbox("清理底噪/污渍", value=True)

        optimize_level = st.select_slider("优化级别", options=[0,1,2], value=1)
        skip_text = st.checkbox("跳过已有文本层 (--skip-text)", value=True)
        force_ocr = st.checkbox("强制重做文本层 (--force-ocr) [谨慎]", value=False)
        export_sidecar = st.checkbox("导出 sidecar 文本", value=True)

        st.markdown("---")
        provider = st.selectbox("抽取提供方", ["None", "Gemini", "OpenAI"], index=0)
        model_id = st.text_input("模型 ID", value="gemini-2.5-flash")
        # NOTE(review): this prompt is collected but run_extract currently
        # uses its own hard-coded strict prompt instead — confirm intent.
        prompt = st.text_area(
            "抽取任务描述(建议按你的法律场景定制)",
            value=("Extract legal entities, events, relationships, and evidence anchors. "
                   "Return JSON objects with fields: {party, role, event, date, relation, citation, quote}. "
                   "Preserve exact source spans for traceability."),
            height=160,
        )

        submitted = st.form_submit_button("运行 OCR(+可选抽取)", type="primary")

    
# Two-column result layout: OCR outputs on the left, extraction on the right.
col1, col2 = st.columns([1,1])
with col1:
    st.subheader("OCR 结果")
    # Placeholders filled after a successful run (download buttons, preview).
    ocr_pdf_slot = st.empty()
    sidecar_slot = st.empty()
    preview_slot = st.empty()

with col2:
    st.subheader("抽取与复核")
    html_slot = st.empty()
    jsonl_slot = st.empty()
    status_slot = st.empty()

# Helper: show that the file was received (visible before submission, for
# user confirmation).
if "pdf_uploader" in st.session_state and st.session_state["pdf_uploader"]:
    st.sidebar.success(f"已选择:{st.session_state['pdf_uploader'].name} "
                       f"({st.session_state['pdf_uploader'].size/1024:.1f} KB)")

# Main pipeline, executed on form submission: OCR first, then optional
# structured extraction over the sidecar text.
if submitted:
    # Debug info: surface whether the upload survived the rerun.
    if pdf is None:
        st.error("PDF 为 None - 检查文件上传")
    else:
        st.info(f"PDF 文件信息:名称={pdf.name}, 大小={pdf.size} bytes")
    
    out_pdf, sidecar_path, preview = run_ocr(
        pdf, langs, rotate_pages, deskew, clean, optimize_level,
        force_ocr, skip_text, export_sidecar
    )
    # Offer the OCR'd PDF and (optionally) the sidecar text as downloads.
    if out_pdf:
        with open(out_pdf, "rb") as f:
            ocr_pdf_slot.download_button("下载 OCR 后 PDF", f, file_name="output_ocr.pdf")
    if sidecar_path:
        with open(sidecar_path, "rb") as f:
            sidecar_slot.download_button("下载 sidecar 文本", f, file_name="out.txt")
        preview_slot.text_area("sidecar 文本预览(前 3000 字)", value=preview, height=240)

    # Run extraction only when a sidecar exists and a provider was chosen.
    if sidecar_path and provider != "None":
        txt = Path(sidecar_path).read_text(encoding="utf-8", errors="ignore")
        html_path, jsonl_path, status = run_extract(txt, provider, model_id, prompt)
        status_slot.info(status)
        if html_path and Path(html_path).exists():
            html_content = Path(html_path).read_text(encoding="utf-8", errors="ignore")
            # Embedded review page rendered by lx.visualize.
            st.components.v1.html(html_content, height=650, scrolling=True)
        if jsonl_path and Path(jsonl_path).exists():
            with open(jsonl_path, "rb") as f:
                jsonl_slot.download_button("下载抽取结果 JSONL", f, file_name="extractions.jsonl")