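# demo2 / app.py
# (web-page header residue kept as a comment: "patent's picture" . d4e1c74
#  -- d4e1c74 is presumably the revision id of this file; TODO confirm)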
import streamlit as st
import time
import requests
import os
import json
import glob
import re
import random
import difflib
from random import randrange
# Checkpoint prefixes as they appear in the result file names
# ('<prefix>_<id>_forward.json'), ordered from the largest model to the
# smallest.
prefix_lst = [
    "pgj_d_4096",
    "pgj_d_2048",
    "pgj_d_1024_v2",
    "pgj_d_1024_layer_14",
    "pgj_d_1024_layer_7",
    "pgj_d_1024_layer_2",
    "pgj_d_1024_layer_1",
]

# Display name of each checkpoint, paired positionally with prefix_lst.
model_names = dict(zip(prefix_lst, (
    'PatentGPT-J-6B',
    'PatentGPT-J-1.6B',
    'PatentGPT-J-456M',
    'PatentGPT-J-279M',
    'PatentGPT-J-191M',
    'PatentGPT-J-128M',
    'PatentGPT-J-115M',
)))
# ---------------------------------------------------------------------------
# Experiment selection.  One experiment is active; the other settings are
# kept below as commented-out alternatives.
#
# experiment 3:
#   folder = os.path.join('experiments', 'non_patent')
#   id_to_scroll = 1  # which of the above to scroll through
#   first_claim_only = True
#
# experiment 1:
#   folder = os.path.join('experiments', 'ipg22_500')
#   (previous) folder = "eval_ipg22_500"
#   id_to_scroll = 1  # which of the above to scroll through
#   first_claim_only = True
#
# An older prefix_lst, kept for reference:
#   ["my_gptj_6b_tpu_size_8", "pgj_d_4096", "pgj_d_2048",
#    "pgj_d_1024_layer_14", "pgj_d_1024_layer_7", "pgj_d_1024_layer_2",
#    "pgj_d_1024_layer_1"]
#   (also noted: "pgj_large", "pgj_medium", "pgj_small", "pgj_d_1024_layer_14")
# ---------------------------------------------------------------------------

# experiment 2 (active).  Alternative folder: "device_serve_results"
folder = os.path.join("experiments", "ipg20220104_500")
id_to_scroll = 1  # which of the above to scroll through
# When True, prepare_select_lst() keeps only ids that end in '_1'.
first_claim_only = False

# Module-level list referenced by update_selected().
select_lst = []
def handle_char_return(text):
    """Return *text* for display, blanking the ``'(none)'`` placeholder.

    The previous version wrote ``text == ''`` (a comparison whose result was
    discarded) where an assignment was intended, so the placeholder was
    returned unchanged.

    Args:
        text: Token text read from the result file.

    Returns:
        An empty string when *text* is exactly ``'(none)'``; otherwise *text*
        itself, unmodified.
    """
    if text == '(none)':  # "unicorn text" (original author's note)
        return ''
    return text
def calc_height(s):
    """Return a widget height for showing *s*: 30 plus 3 per 10 characters.

    The result is used as the ``height`` argument of ``st.text_area``.
    """
    base_height = 30
    scaled_length = len(s) / 10 * 3
    return base_height + int(scaled_length)
def remove_end_of_claim_text(gen_text):
    """Cut *gen_text* right after its first end marker, keeping the marker.

    ``'<|end_of_claim|>'`` is looked for first; ``'<|endoftext|>'`` is only
    considered when the former is absent. Text without either marker is
    returned unchanged.

    The previous version tested ``pos > 0``, so a marker located at index 0
    was treated as "not found" (``str.find`` signals a miss with -1, not 0)
    and everything after it was kept.

    Args:
        gen_text: Generated text, possibly followed by trailing content
            after an end marker.

    Returns:
        The text up to and including the first matching marker, or
        *gen_text* itself when no marker is present.
    """
    # Order matters: the end-of-claim marker wins even if '<|endoftext|>'
    # appears earlier in the string.
    for tag in ('<|end_of_claim|>', '<|endoftext|>'):
        pos = gen_text.find(tag)
        if pos != -1:
            return gen_text[:pos + len(tag)]
    return gen_text
def update_content():
    """``on_change`` hook of the token-position slider; deliberately a no-op."""
    # st.write("The value of the slider is:", st.session_state.myslider)
    return None
def prepare_select_lst():
    """Return the sorted ids for which every model has a result file.

    Every path directly under ``folder`` is matched against each prefix in
    ``prefix_lst``; the id is the part between ``<prefix>_`` and the next
    ``_`` and has the form ``<digits>_<digits>`` or, failing that,
    ``<word chars>_<digits>``. When ``first_claim_only`` is set, only ids
    ending in ``'_1'`` are kept. An id is returned only if
    ``<prefix>_<id>_forward.json`` exists in ``folder`` for *all* prefixes.

    Changes from the previous version: the patterns are raw strings (the old
    literals relied on invalid escapes such as ``\\_`` and ``\\d``), the
    prefix is passed through ``re.escape``, and the patterns are compiled
    once per prefix instead of once per file.

    Returns:
        list[str]: Sorted ids; empty when nothing qualifies.
    """
    # (primary, fallback) pattern per prefix; group 2 is the id.
    patterns = []
    for prefix in prefix_lst:
        escaped = re.escape(prefix)
        patterns.append((
            re.compile(r'(.*?)%s_(\d+_\d+)_(.*?)' % escaped),
            re.compile(r'(.*?)%s_(\w+_\d+)_(.*?)' % escaped),
        ))

    num_set = set()
    for fn in glob.glob(os.path.join(folder, '*')):
        for primary, fallback in patterns:
            v = primary.search(fn) or fallback.search(fn)
            if v is None:
                continue
            num = v.group(2)
            if first_claim_only and not num.endswith('_1'):
                continue
            num_set.add(num)

    def has_all_results(num):
        # One missing model file is enough to drop the id.
        return all(
            os.path.exists(
                os.path.join(folder, '%s_%s_forward.json' % (prefix, num)))
            for prefix in prefix_lst
        )

    return [num for num in sorted(num_set) if has_all_results(num)]
def update_selected():
    """Selectbox ``on_change`` callback: load the claim the user just chose.

    Reads the chosen label from ``st.session_state.myselectbox`` and hands it
    to ``pick_and_load``.

    The previous version always passed the module-level ``select_lst``.
    ``main()`` stores the real list in ``st.session_state['select_lst']`` and
    only assigns a local variable of the same name, so the module-level list
    stays empty; the session copy is therefore preferred when present.
    """
    # st.write("The value of the slider is:", st.session_state.myselectbox)
    selected = st.session_state.myselectbox
    if 'select_lst' in st.session_state:
        available = st.session_state['select_lst']
    else:
        available = select_lst  # module-level fallback
    pick_and_load(available, selected)
def pick_and_load(select_lst, selected=None, prefix="pgj_d_1024_v2"):
    """Load one result file and record the choice in ``st.session_state``.

    Args:
        select_lst: Available ids (e.g. ``'<number>_<claim>'``).
        selected: Either ``None`` to pick a random entry of *select_lst*, or
            a label such as ``'<number> (claim <claim>)'`` / a raw id.
        prefix: Checkpoint prefix of the file to load. Defaults to
            ``"pgj_d_1024_v2"`` (size: 456M), the value that used to be
            hard-coded.

    Returns:
        tuple: ``(pick, num, result)`` -- index of the id in *select_lst*,
        the id itself, and the parsed JSON content.

    Raises:
        ValueError: If *selected* is ``None`` and *select_lst* is empty
            (raised by ``random.randrange``).
        OSError, json.JSONDecodeError: If the result file cannot be read or
            parsed. Session state is left untouched in that case.

    Fixes over the previous version: ``pick`` was never assigned when
    *selected* was given, so the final ``return`` raised
    ``UnboundLocalError``; ``picked_flag`` was not updated for an explicit
    selection; and session state was modified before the file was known to
    load.
    """
    if selected is None:
        pick = random.randrange(len(select_lst))
        selected = select_lst[pick]
        num = selected.replace(')', '').replace(' (claim ', '_')
    else:
        # Convert the display label back to the id used in file names.
        num = selected.replace(')', '').replace(' (claim ', '_')
        try:
            pick = select_lst.index(num)
        except ValueError:
            # Unknown id (e.g. an empty list was passed): keep the previous
            # index if there is one so callers still get a usable value.
            if 'picked_flag' in st.session_state:
                pick = st.session_state['picked_flag']
            else:
                pick = 0

    base_fn = '%s_%s_forward.json' % (prefix, num)
    full_fn = os.path.join(folder, base_fn)
    with open(full_fn, encoding='utf-8') as f:
        result = json.load(f)
    print("Loaded: %s" % full_fn)

    # Commit all three values together, only after a successful load.
    st.session_state['picked_flag'] = pick
    st.session_state['num'] = num
    st.session_state['result'] = result
    return pick, num, result
def main():
    """Render the demo page.

    Flow: configure the page, build (or reuse from session state) the list
    of selectable claims, load one result file, then show its context, the
    prompt up to the chosen token position, the actual next token with its
    rank/probability, the top candidate tokens, and the generated text.

    Fixes over the previous version:
      * the slider ran from 0 to ``len(lst)``, so its last position indexed
        one past the end of ``lst``; it now stops at ``len(lst) - 1``;
      * an empty output list is reported instead of raising, and a
        single-entry list skips the slider;
      * a slider value remembered from a longer claim is clamped to the
        current range before the widget is created.
    """
    st.set_page_config(  # Alternate names: setup_page, page, layout
        layout="wide",  # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
        page_title="Patent-GPT-J demo",  # String or None. Strings get appended with "• Streamlit".
        page_icon=None,  # String, anything supported by st.image, or None.
    )
    st.subheader("PatentGPT-J Demo 2 (Autocomplete Effectiveness)")
    st.text("Data coverage: ipg220104 (in 2022-01-04)")

    # Scanning the folder is done once per session and cached.
    if 'select_lst' not in st.session_state:
        select_lst = prepare_select_lst()
        st.session_state['select_lst'] = select_lst
    else:
        select_lst = st.session_state['select_lst']

    if not select_lst:
        st.text('select_lst is empty')
        return

    show_patent_lst = [s.replace('_', ' (claim ') + ')' for s in select_lst]

    if 'picked_flag' not in st.session_state:
        pick, _, result = pick_and_load(select_lst)
    else:
        pick = st.session_state['picked_flag']
        result = st.session_state['result']

    if st.button('Random pick'):
        pick, _, result = pick_and_load(select_lst)

    # update_selected loads the file for the newly chosen entry.
    st.selectbox("Choose a patent claim", show_patent_lst, index=pick,
                 key='myselectbox', on_change=update_selected)

    recv = result['recv']
    lst = result['output']
    input_tokens = result['input']

    st.text_area('context:', recv['context'],
                 height=calc_height(recv['context']))

    if not lst:
        st.text('output is empty')
        return

    last_pos = len(lst) - 1
    if last_pos > 0:
        # NOTE(review): a remembered value above the new maximum is assumed
        # to be rejected by st.slider -- clamping first is harmless either way.
        if 'myslider' in st.session_state and st.session_state['myslider'] > last_pos:
            st.session_state['myslider'] = last_pos
        pos = st.slider("Token position", 0, last_pos, key="myslider",
                        on_change=update_content)
    else:
        # A slider needs a non-empty range; with one entry there is no choice.
        pos = 0

    prompt = ''.join(tok['text'] for tok in input_tokens[:pos + 1])
    st.text_area('prompt:', prompt, height=calc_height(prompt))

    entry = lst[pos]
    ch = handle_char_return(entry['actual_next_token_text'])
    st.text('actual_next_token_text: %s --> pick seq: %s (prob: %.2f) top 10 tokens:' % (
        ch,
        int(entry['actual_next_token_top_seq']) + 1,
        float(entry['actual_next_token_top_prob'])))

    # Candidates are printed five to a line: the first five, then the rest.
    msg = ''
    for i, v in enumerate(entry['top_n_lst']):
        ch = handle_char_return(v['top_n_text'])
        msg += '(%s)[%s](%.2f) ' % (i + 1, ch, float(v['top_n_prob']))
        if i == 4:
            st.text(msg)
            msg = ''
    st.text(msg)

    gen_text = remove_end_of_claim_text(entry['gen_text'])
    st.text_area('generated:', gen_text, height=calc_height(gen_text))
# Script entry point: build the page when this file is run directly.
if __name__ == "__main__":
    main()