# NLL_Interface/interface.py
import pandas as pd
import json
import re
from json import loads, JSONDecodeError
import sys
import os
import ast
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
from doubao_service import DouBaoService
from PROMPT_TEMPLATE import prompt_template
from util.Embeddings import TextEmb3LargeEmbedding
from langchain_core.documents import Document
from FlagEmbedding import FlagReranker
from retriever import retriever
import time
# from bm25s import BM25, tokenize
import contextlib
import io
import gradio as gr
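# Module-level services: the DouBao chat client, the embedding model, and the
# vector store that holds the safeguard corpus.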
client = DouBaoService("DouBao128Pro")
embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
embedding = EmbeddingFunction(embeddingmodel)
safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
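# Optional BGE reranker, currently disabled; the retriever below is called
# with using_reranker=False.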
# reranker_model = FlagReranker(
# 'C://Users//Admin//Desktop//PDPO//NLL_LLM//model//bge-reranker-v2-m3',
# use_fp16=True,
# devices=["cpu"],
# )
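# Privacy Objectives offered in the UI checklist.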
OPTIONS = ['AI Governance',
'Data Accuracy',
'Data Minimization & Purpose Limitation',
'Data Retention',
'Data Security',
'Data Sharing',
'Individual Rights',
'Privacy by Design',
'Transparency']
def format_model_output(raw_output):
    """
    Post-process the model output:
    - convert escaped "\\n" sequences into real newlines
    - extract the content of ```json ... ``` fences and pretty-print it
      as indented JSON
    The cleaned string is then parsed as a Python literal, so the model is
    expected to return a dict-shaped string.
    """
    formatted = raw_output.replace('\\n', '\n')

    def replace_json(match):
        json_str = match.group(1).strip()
        try:
            json_obj = loads(json_str)
            return f"```json\n{json.dumps(json_obj, indent=2, ensure_ascii=False)}\n```"
        except JSONDecodeError:
            # Leave the fence untouched if its content is not valid JSON.
            return match.group(0)

    formatted = re.sub(r'```json\n?(.*?)\n?```', replace_json, formatted, flags=re.DOTALL)
    return ast.literal_eval(formatted)
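# Illustrative usage (hypothetical output string): ast.literal_eval accepts
# Python dict literals, so JSON-only tokens such as true/null would raise.
#   format_model_output("{'requirement': 'x', 'match': True}")
#   -> {'requirement': 'x', 'match': True}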
def model_predict(input_text, if_split_po, topk, selected_items):
    """
    selected_items: the Privacy Objectives chosen by the user,
    either ["All"] or a list of specific POs.
    """
    # Flatten the query into a single line before retrieval.
    requirement = input_text.replace("\t", "").replace("\n", "").replace("\r", "")
    PO = OPTIONS if "All" in selected_items else selected_items
    topk = int(topk) if topk else 10
final_result = retriever(
requirement,
PO,
safeguard_vector_store,
reranker_model=None,
using_reranker=False,
using_BM25=False,
using_chroma=True,
k=topk,
if_split_po=if_split_po
)
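    # Group the retrieved safeguards by Privacy Objective (tuple layout
    # assumed: score, safeguard number, description, PO name).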
mapping_safeguards = {}
for safeguard in final_result:
if safeguard[3] not in mapping_safeguards:
mapping_safeguards[safeguard[3]] = []
mapping_safeguards[safeguard[3]].append(
{
"Score": safeguard[0],
"Safeguard Number": safeguard[1],
"Safeguard Description": safeguard[2]
}
)
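    # Build the LLM prompt from the requirement and the grouped safeguards,
    # then ask the DouBao chat model for the final mapping.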
prompt = prompt_template(requirement, mapping_safeguards)
response = client.chat_complete(messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
])
# return {"requirement": requirement, "safeguards": mapping_safeguards}
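    # Lightweight debug logging of intermediate values.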
print("requirement:", requirement)
print("mapping safeguards:", mapping_safeguards)
print("response:", response)
return {"requirement": requirement, "safeguards": format_model_output(response)}
with gr.Blocks(title="New Law Landing") as demo:
gr.Markdown("## 🏙️ New Law Landing")
requirement = gr.Textbox(label="Input Requirements", placeholder="Example: Data Minimization Consent for incompatible purposes")
details = gr.Textbox(label="Input Details", placeholder="Example: Require consent for...")
    # Numeric input for the Top-K safeguard count.
topk = gr.Number(
label="Top K safeguards",
value=10,
precision=0,
minimum=1,
interactive=True
)
with gr.Row():
with gr.Column(scale=1):
if_split_po = gr.Checkbox(
label="If Split Privacy Objective",
value=True,
info="Recall K Safeguards for each Privacy Objective"
)
with gr.Column(scale=1):
all_checkbox = gr.Checkbox(
label="ALL Privacy Objective",
value=True,
info="No specific Privacy Objective is specified"
)
with gr.Column(scale=4):
PO_checklist = gr.CheckboxGroup(
label="Choose Privacy Objective",
choices=OPTIONS,
value=[],
interactive=True
)
submit_btn = gr.Button("Submit", variant="primary")
result_output = gr.JSON(label="Related safeguards", open=True)
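    # Keep the "ALL" checkbox and the specific-PO checklist mutually
    # exclusive: selecting on one side clears the other.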
def sync_checkboxes(selected_items, all_selected):
if len(selected_items) > 0:
return False
return all_selected
PO_checklist.change(
fn=sync_checkboxes,
inputs=[PO_checklist, all_checkbox],
outputs=all_checkbox
)
def sync_all(selected_all, current_selection):
if selected_all:
return []
return current_selection
all_checkbox.change(
fn=sync_all,
inputs=[all_checkbox, PO_checklist],
outputs=PO_checklist
)
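    # Combine the two text boxes into one query string and dispatch to
    # model_predict with the resolved PO selection.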
def process_inputs(requirement, details, topk, if_split_po, all_selected, PO_selected):
input_text = requirement + ": " + details
if all_selected:
return model_predict(input_text, if_split_po, int(topk), ["All"])
else:
return model_predict(input_text, if_split_po, int(topk), PO_selected)
submit_btn.click(
fn=process_inputs,
inputs=[requirement, details, topk, if_split_po, all_checkbox, PO_checklist],
outputs=[result_output]
)
if __name__ == "__main__":
demo.launch(share=True)