import pandas as pd
import json
import re
from json import loads, JSONDecodeError
import sys
import os
import ast
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
from doubao_service import DouBaoService
from PROMPT_TEMPLATE import prompt_template
from util.Embeddings import TextEmb3LargeEmbedding
from langchain_core.documents import Document
from FlagEmbedding import FlagReranker
from retriever import retriever
import time
# from bm25s import BM25, tokenize
import contextlib
import io

import gradio as gr

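# Wire up the chat model client, the embedding model (rate-limited via max_qpm),
# and the vector store holding the safeguard corpus.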
client = DouBaoService("DouBao128Pro")
embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
embedding = EmbeddingFunction(embeddingmodel)
safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)

# reranker_model = FlagReranker(
#     'C://Users//Admin//Desktop//PDPO//NLL_LLM//model//bge-reranker-v2-m3',
#     use_fp16=True,
#     devices=["cpu"],   
# )

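# The nine Privacy Objectives (POs) that safeguards are grouped under.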
OPTIONS = [
    'AI Governance',
    'Data Accuracy',
    'Data Minimization & Purpose Limitation',
    'Data Retention',
    'Data Security',
    'Data Sharing',
    'Individual Rights',
    'Privacy by Design',
    'Transparency',
]


def format_model_output(raw_output):
    """
    Post-process the raw model output:
    - turn escaped "\\n" sequences into real newlines;
    - pretty-print the contents of any ```json ... ``` block;
    - parse the result back into a Python object.
    """
    formatted = raw_output.replace('\\n', '\n')

    def replace_json(match):
        json_str = match.group(1).strip()
        try:
            json_obj = loads(json_str)
            return f"```json\n{json.dumps(json_obj, indent=2, ensure_ascii=False)}\n```"
        except JSONDecodeError:
            # Leave malformed JSON blocks untouched.
            return match.group(0)

    formatted = re.sub(r'```json\n?(.*?)\n?```', replace_json, formatted, flags=re.DOTALL)
    try:
        # The model is expected to answer with a Python-literal dict; fall back to
        # the formatted string if it returns plain prose instead.
        return ast.literal_eval(formatted)
    except (ValueError, SyntaxError):
        return formatted
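
# Illustrative example (hypothetical model response): format_model_output("{'x': 1}")
# would return the dict {'x': 1}.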

def model_predict(input_text, if_split_po, topk, selected_items):
    """
    selected_items: the Privacy Objectives chosen by the user,
    either ["All"] or a list of specific POs.
    """
    # Flatten the free-text requirement into a single line.
    requirement = input_text.replace("\t", "").replace("\n", "").replace("\r", "")
    if "All" in selected_items:
        PO = OPTIONS
    else:
        PO = selected_items
    topk = int(topk) if topk else 10  # default to the top 10 safeguards
    final_result = retriever(
        requirement,
        PO,
        safeguard_vector_store,
        reranker_model=None,
        using_reranker=False,
        using_BM25=False,
        using_chroma=True,
        k=topk,
        if_split_po=if_split_po,
    )
    # Group the retrieved hits by safeguard[3] (the Privacy Objective of each hit).
    mapping_safeguards = {}
    for safeguard in final_result:
        mapping_safeguards.setdefault(safeguard[3], []).append(
            {
                "Score": safeguard[0],
                "Safeguard Number": safeguard[1],
                "Safeguard Description": safeguard[2]
            }
        )
    prompt = prompt_template(requirement, mapping_safeguards)
    response = client.chat_complete(messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ])
    # return {"requirement": requirement, "safeguards": mapping_safeguards}
    print("requirement:", requirement)
    print("mapping safeguards:", mapping_safeguards)
    print("response:", response)
    return {"requirement": requirement, "safeguards": format_model_output(response)}
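
# Example call (illustrative values):
#   model_predict("Data Retention: delete data once the purpose is fulfilled", True, 10, ["All"])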

with gr.Blocks(title="New Law Landing") as demo:
    gr.Markdown("## 🏙️ New Law Landing")
    
    requirement = gr.Textbox(label="Input Requirements", placeholder="Example: Data Minimization Consent for incompatible purposes")
    details = gr.Textbox(label="Input Details", placeholder="Example: Require consent for...")
    
    # Top K as a whole-number input
    topk = gr.Number(
        label="Top K safeguards",
        value=10,
        precision=0,
        minimum=1,
        interactive=True
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            if_split_po = gr.Checkbox(
                label="Split by Privacy Objective",
                value=True,
                info="Recall the top K safeguards for each Privacy Objective separately"
            )
        with gr.Column(scale=1):
            all_checkbox = gr.Checkbox(
                label="All Privacy Objectives",
                value=True,
                info="No specific Privacy Objective selected"
            )
        with gr.Column(scale=4):
            PO_checklist = gr.CheckboxGroup(
                label="Choose Privacy Objective",
                choices=OPTIONS,
                value=[],  
                interactive=True
            )

    submit_btn = gr.Button("Submit", variant="primary")
    result_output = gr.JSON(label="Related safeguards", open=True)
    

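    # The next two callbacks keep the "All" checkbox and the specific-PO
    # checklist mutually exclusive: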
    def sync_checkboxes(selected_items, all_selected):
        # Untick "All" as soon as any specific Privacy Objective is chosen.
        if len(selected_items) > 0:
            return False
        return all_selected
    
    PO_checklist.change(
        fn=sync_checkboxes,
        inputs=[PO_checklist, all_checkbox],
        outputs=all_checkbox
    )
    
    def sync_all(selected_all, current_selection):
        # Ticking "All" clears any specific Privacy Objective selection.
        if selected_all:
            return []
        return current_selection
    
    all_checkbox.change(
        fn=sync_all,
        inputs=[all_checkbox, PO_checklist],
        outputs=PO_checklist
    )
 
    def process_inputs(requirement, details, topk, if_split_po, all_selected, PO_selected):
        # Join the requirement title and its details into a single query string.
        input_text = requirement + ": " + details
        if all_selected:
            return model_predict(input_text, if_split_po, int(topk), ["All"])
        return model_predict(input_text, if_split_po, int(topk), PO_selected)
    
    submit_btn.click(
        fn=process_inputs,
        inputs=[requirement, details, topk, if_split_po, all_checkbox, PO_checklist],
        outputs=[result_output]
    )

if __name__ == "__main__":
    demo.launch(share=True)