# New Law Landing — Gradio app that retrieves privacy safeguards for a requirement
# and asks an LLM to map them to the input. (Removed stray "Spaces / Sleeping"
# page-scrape residue that was not valid Python.)
import pandas as pd | |
import json | |
import re | |
from json import loads, JSONDecodeError | |
import sys | |
import os | |
import ast | |
from util.vector_base import EmbeddingFunction, get_or_create_vector_base | |
from doubao_service import DouBaoService | |
from PROMPT_TEMPLATE import prompt_template | |
from util.Embeddings import TextEmb3LargeEmbedding | |
from langchain_core.documents import Document | |
from FlagEmbedding import FlagReranker | |
from retriever import retriever | |
import time | |
# from bm25s import BM25, tokenize | |
import contextlib | |
import io | |
import gradio as gr | |
import time | |
# ---- Service / model initialisation --------------------------------------
# LLM client used to map retrieved safeguards to the requirement.
client = DouBaoService("DouBao128Pro")
# Embedding model is rate-limited to 58 queries per minute.
embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
embedding = EmbeddingFunction(embeddingmodel)
# Vector store holding the safeguard corpus (created on first run).
safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)

# Local reranker is currently disabled; retrieval runs without reranking.
# reranker_model = FlagReranker(
#     'C://Users//Admin//Desktop//PDPO//NLL_LLM//model//bge-reranker-v2-m3',
#     use_fp16=True,
#     devices=["cpu"],
# )

# Privacy objectives the user may restrict retrieval to.
OPTIONS = [
    'AI Governance',
    'Data Accuracy',
    'Data Minimization & Purpose Limitation',
    'Data Retention',
    'Data Security',
    'Data Sharing',
    'Individual Rights',
    'Privacy by Design',
    'Transparency',
]
def format_model_output(raw_output):
    """Normalize a model response string and parse it into a Python object.

    Steps:
    - Convert literal ``\\n`` escape sequences into real newlines.
    - Pretty-print the contents of any ```json ... ``` fenced blocks
      (blocks that are not valid JSON are left untouched).
    - Parse the resulting text as a Python literal — the model is expected
      to answer with a dict/list literal.

    Returns:
        The parsed Python object, or the normalized string unchanged when
        it is not a valid Python literal (previously this case raised and
        crashed the Gradio callback).
    """
    formatted = raw_output.replace('\\n', '\n')

    def _prettify_json(match):
        # Re-serialize fenced JSON with 2-space indentation; keep the
        # original fence verbatim if it does not parse.
        json_str = match.group(1).strip()
        try:
            json_obj = loads(json_str)
        except JSONDecodeError:
            return match.group(0)
        return f"```json\n{json.dumps(json_obj, indent=2, ensure_ascii=False)}\n```"

    formatted = re.sub(r'```json\n?(.*?)\n?```', _prettify_json, formatted, flags=re.DOTALL)
    try:
        return ast.literal_eval(formatted)
    except (ValueError, SyntaxError):
        # Not a Python literal — return the cleaned-up text instead of raising.
        return formatted
def model_predict(input_text, if_split_po, topk, selected_items):
    """Retrieve safeguards relevant to *input_text* and have the LLM map them.

    Args:
        input_text: requirement text; tabs/newlines are stripped before use.
        if_split_po: when True, recall top-k safeguards per privacy objective.
        topk: number of safeguards to recall (falsy values fall back to 10).
        selected_items: chosen privacy objectives, or ["All"] to use every
            objective in OPTIONS.

    Returns:
        dict with the cleaned "requirement" and the parsed LLM "safeguards".
    """
    requirement = input_text.replace("\t", "").replace("\n", "").replace("\r", "")
    # "All" expands to the full objective list.
    PO = OPTIONS if "All" in selected_items else selected_items
    # Gradio may hand over a float/str/None; coerce and default to 10.
    topk = int(topk) if topk else 10
    final_result = retriever(
        requirement,
        PO,
        safeguard_vector_store,
        reranker_model=None,
        using_reranker=False,
        using_BM25=False,
        using_chroma=True,
        k=topk,
        if_split_po=if_split_po
    )
    # Group retrieved rows by privacy objective.
    # Row layout (by index): 0=score, 1=safeguard number, 2=description,
    # 3=privacy objective — assumed from the indexing below; TODO confirm
    # against retriever().
    mapping_safeguards = {}
    for safeguard in final_result:
        mapping_safeguards.setdefault(safeguard[3], []).append(
            {
                "Score": safeguard[0],
                "Safeguard Number": safeguard[1],
                "Safeguard Description": safeguard[2],
            }
        )
    prompt = prompt_template(requirement, mapping_safeguards)
    response = client.chat_complete(messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ])
    # Debug trace of the full round-trip.
    print("requirement:", requirement)
    print("mapping safeguards:", mapping_safeguards)
    print("response:", response)
    return {"requirement": requirement, "safeguards": format_model_output(response)}
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="New Law Landing") as demo:
    gr.Markdown("## 🏙️ New Law Landing")

    requirement = gr.Textbox(label="Input Requirements", placeholder="Example: Data Minimization Consent for incompatible purposes")
    details = gr.Textbox(label="Input Details", placeholder="Example: Require consent for...")

    # Integer spinner: how many safeguards to recall.
    topk = gr.Number(
        label="Top K safeguards",
        value=10,
        precision=0,
        minimum=1,
        interactive=True,
    )

    with gr.Row():
        with gr.Column(scale=1):
            if_split_po = gr.Checkbox(
                label="If Split Privacy Objective",
                value=True,
                info="Recall K Safeguards for each Privacy Objective",
            )
        with gr.Column(scale=1):
            all_checkbox = gr.Checkbox(
                label="ALL Privacy Objective",
                value=True,
                info="No specific Privacy Objective is specified",
            )
        with gr.Column(scale=4):
            PO_checklist = gr.CheckboxGroup(
                label="Choose Privacy Objective",
                choices=OPTIONS,
                value=[],
                interactive=True,
            )

    submit_btn = gr.Button("Submit", variant="primary")
    result_output = gr.JSON(label="Related safeguards", open=True)

    def sync_checkboxes(selected_items, all_selected):
        # Picking any specific objective unticks "ALL"; otherwise leave it.
        return False if selected_items else all_selected

    PO_checklist.change(
        fn=sync_checkboxes,
        inputs=[PO_checklist, all_checkbox],
        outputs=all_checkbox,
    )

    def sync_all(selected_all, current_selection):
        # Ticking "ALL" clears any specific selection.
        return [] if selected_all else current_selection

    all_checkbox.change(
        fn=sync_all,
        inputs=[all_checkbox, PO_checklist],
        outputs=PO_checklist,
    )

    def process_inputs(requirement, details, topk, if_split_po, all_selected, PO_selected):
        # Combine the two text fields into a single query string.
        input_text = requirement + ": " + details
        chosen = ["All"] if all_selected else PO_selected
        return model_predict(input_text, if_split_po, int(topk), chosen)

    submit_btn.click(
        fn=process_inputs,
        inputs=[requirement, details, topk, if_split_po, all_checkbox, PO_checklist],
        outputs=[result_output],
    )


if __name__ == "__main__":
    # share=True also publishes a temporary public Gradio link.
    demo.launch(share=True)