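# Gradio app for the Grid 5.0 Information Security Track: checks an uploaded
# log file against uploaded compliance-rule documents using a quantized
# MPT-7B instruct model (via llm-rs), caching responses in SQLite keyed on a
# content hash of the inputs.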
import sqlite3
import gradio as gr
from hashlib import md5 as hash_algo
from re import match
from io import BytesIO
from pypdf import PdfReader
from llm_rs import AutoModel, SessionConfig, GenerationConfig
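
# Quantized (5-bit, ggjt) MPT-7B instruct model from the Hugging Face Hub,
# run on two CPU threads.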
repo_name = "rustformers/mpt-7b-ggml"
file_name = "mpt-7b-instruct-q5_1-ggjt.bin"
script_env = 'prod'

session_config = SessionConfig(threads=2, batch_size=2)
model = AutoModel.from_pretrained(repo_name, model_file=file_name, session_config=session_config, verbose=True)
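
# Streams the model's compliance analysis of the uploaded log file against the
# uploaded rule documents. A cached response is returned when the same inputs
# (by MD5 digest of their raw bytes) have been processed before.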
def process_stream(rules, log, temperature, top_p, top_k, max_new_tokens, seed):
    con = sqlite3.connect("history.db")
    cur = con.cursor()
    instruction = ''
    hashes = []
    if not isinstance(rules, list):
        rules = [rules]
    # Concatenate all rule documents and collect one hash per file.
    for rule in rules:
        data, file_hash = get_file_contents(rule)
        instruction += data + '\n'
        hashes.append(file_hash)
    # Sort so the cache key does not depend on upload order.
    hashes.sort()
    digest = hash_algo(''.join(hashes).encode()).hexdigest()
    # Find the highest numbered rule ("N.") so the query block can be appended
    # as rule N+1.
    largest = 0
    lines = instruction.split('\r\n')
    if len(lines) == 1:
        lines = instruction.split('\n')
    for line in lines:
        m = match(r'^(\d+)\.', line)
        if m is not None:
            num = int(m.group(1))
            if num > largest:
                largest = num
    instruction += str(largest + 1) + '. '
    query, file_hash = get_file_contents(log)
    digest = hash_algo((digest + file_hash).encode()).hexdigest()
    instruction = instruction.replace('\r\r\n', '\n')
    full_req = (
        "A conversation between a user and an LLM-based AI assistant. "
        "The assistant gives helpful and honest answers.\r\n\r\n"
        "Q: Read the rules stated below and check the queries for any violation. "
        "State the rules which are violated by a query (if any). Also suggest a "
        "possible remediation, if possible. Do not make any assumptions outside "
        "of the rules stated below.\r\n\r\n"
        + instruction
        + 'The queries are as follows:\r\n' + query + '\r\n \r\nA: '
    )
    full_req = full_req.replace('\r\n', '\n')
    prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{full_req}
### Response:
Answer:"""
    response = ""
    # Serve the cached answer if this exact rules+log combination was seen before.
    row = cur.execute('SELECT response FROM queries WHERE hexdigest = ?', [digest]).fetchone()
    if row is not None:
        response += "Cached Result:\n" + row[0]
        yield response
    else:
        if script_env != 'test':
            generation_config = GenerationConfig(seed=seed, temperature=temperature, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens)
            streamer = model.stream(prompt=prompt, generation_config=generation_config)
            for new_text in streamer:
                response += new_text
                yield response
        else:
            # Test mode: stream dummy tokens instead of invoking the model.
            num = 0
            while num < 100:
                response += " " + str(num)
                num += 1
                yield response
        cur.execute('INSERT INTO queries VALUES(?, ?)', (digest, response))
        con.commit()
    cur.close()
    con.close()
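
# NOTE: the "queries" table is assumed to already exist in history.db; a schema
# consistent with the SELECT/INSERT above (an assumption, not shown in this app)
# would be: CREATE TABLE queries(hexdigest TEXT PRIMARY KEY, response TEXT);

# Reads an uploaded file and returns (text, md5 hexdigest of the raw bytes).
# PDF text is extracted page by page; for CSVs, commas are flattened to spaces.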
def get_file_contents(file):
    data = None
    byte_hash = ''
    with open(file.name, 'rb') as f:
        data = f.read()
        byte_hash = hash_algo(data).hexdigest()
    if file.name.endswith('.pdf'):
        rdr = PdfReader(BytesIO(data))
        data = ''
        for page in rdr.pages:
            data += page.extract_text()
    else:
        data = data.decode()
        if file.name.endswith(".csv"):
            data = data.replace(',', ' ')
    return (data, byte_hash)
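
# Both upload callbacks simply forward the selected temp-file paths into the
# corresponding File component.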
def upload_log_file(files):
    file_paths = [file.name for file in files]
    return file_paths

def upload_file(files):
    file_paths = [file.name for file in files]
    return file_paths
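
# UI: uploaders for compliance documents and a log file, generation controls,
# and a streaming Markdown output pane.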
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown(
        """<h1><center>Grid 5.0 Information Security Track</center></h1>
"""
    )
    rules = gr.File(file_count="multiple")
    upload_button = gr.UploadButton("Click to upload a new Compliance Document", file_types=[".txt", ".pdf"], file_count="multiple")
    upload_button.upload(upload_file, upload_button, rules)
    with gr.Row():
        with gr.Column():
            log = gr.File()
            upload_log_button = gr.UploadButton("Click to upload a log file", file_types=[".txt", ".csv", ".pdf"], file_count="multiple")
            upload_log_button.upload(upload_log_file, upload_log_button, log)
            with gr.Accordion("Advanced Options:", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.8,
                                minimum=0.1,
                                maximum=1.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.95,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=40,
                                minimum=5,
                                maximum=80,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens; 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            max_new_tokens = gr.Slider(
                                label="Maximum new tokens",
                                value=256,
                                minimum=0,
                                maximum=1024,
                                step=5,
                                interactive=True,
                                info="The maximum number of new tokens to generate",
                            )
                    with gr.Column():
                        with gr.Row():
                            seed = gr.Number(
                                label="Seed",
                                value=42,
                                interactive=True,
                                info="The seed to use for the generation",
                                precision=0,
                            )
            with gr.Row():
                submit = gr.Button("Submit")
            with gr.Row():
                with gr.Box():
                    gr.Markdown("**Output**")
                    output_7b = gr.Markdown()
    submit.click(
        process_stream,
        inputs=[rules, log, temperature, top_p, top_k, max_new_tokens, seed],
        outputs=output_7b,
    )

demo.queue(max_size=4, concurrency_count=1).launch(debug=True)