|
import streamlit as st |
|
import time |
|
import requests |
|
|
|
import os |
|
import json |
|
import glob |
|
import re |
|
import random |
|
import difflib |
|
|
|
from random import randrange |
|
|
|
# Experiment file-name prefixes, one per PatentGPT-J model size (largest first).
prefix_lst = [
    "pgj_d_4096",
    "pgj_d_2048",
    "pgj_d_1024_v2",
    "pgj_d_1024_layer_14",
    "pgj_d_1024_layer_7",
    "pgj_d_1024_layer_2",
    "pgj_d_1024_layer_1",
]

# Human-readable model name for each prefix, paired in the same order as above.
model_names = dict(zip(prefix_lst, [
    'PatentGPT-J-6B',
    'PatentGPT-J-1.6B',
    'PatentGPT-J-456M',
    'PatentGPT-J-279M',
    'PatentGPT-J-191M',
    'PatentGPT-J-128M',
    'PatentGPT-J-115M',
]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the per-claim '<prefix>_<patent>_<claim>_forward.json' results.
folder = os.path.join('experiments', 'ipg20220104_500')

# NOTE(review): id_to_scroll is not referenced anywhere in this chunk — confirm
# whether it is used elsewhere or is dead configuration.
id_to_scroll = 1
# When True, prepare_select_lst() keeps only first claims (ids ending in '_1').
first_claim_only = False

# Module-level cache of selectable claim ids; populated via prepare_select_lst()
# and read by update_selected().
select_lst = []
|
|
|
def handle_char_return(text):
    """Return '' for the '(none)' placeholder token; pass all other text through.

    BUG FIX: the original body read `text == ''` — a comparison whose result was
    discarded — instead of an assignment, so '(none)' was returned unchanged.
    """
    if text == '(none)':
        text = ''

    return text
|
|
|
def calc_height(s):
    """Heuristic pixel height for a st.text_area holding string *s*."""
    base_height = 30
    scaled = int(len(s) / 10 * 3)  # roughly 3 extra pixels per 10 characters
    return scaled + base_height
|
|
|
def remove_end_of_claim_text(gen_text):
    """Truncate generated text just after the first end marker, if one appears.

    Checks '<|end_of_claim|>' first, then '<|endoftext|>'; the marker itself is
    kept in the returned string. Text without a marker (or with a marker at
    position 0) is returned unchanged.
    """
    for marker in ('<|end_of_claim|>', '<|endoftext|>'):
        cut = gen_text.find(marker)
        if cut > 0:
            return gen_text[:cut + len(marker)]

    return gen_text
|
|
|
def update_content():
    """Slider on_change callback; Streamlit reruns the script, so nothing to do here."""
    return None
|
|
|
def prepare_select_lst():
    """Build the list of claim ids that have results for every model prefix.

    Scans `folder` for files matching '<prefix>_<patent>_<claim>_...' and keeps
    a '<patent>_<claim>' id only when a '<prefix>_<id>_forward.json' file exists
    for ALL prefixes in `prefix_lst`. Honors the module-level
    `first_claim_only` switch. Returns the ids sorted ascending.
    """
    num_set = set()
    fn_lst = glob.glob(os.path.join(folder, '*'))
    for fn in fn_lst:
        for prefix in prefix_lst:
            # Raw strings: '\_' and '\d' inside plain strings are invalid
            # escape sequences (SyntaxWarning on Python 3.12+); the regex
            # behavior is unchanged ('\_' and '_' match identically).
            v = re.search(r'(.*?)%s_(\d+_\d+)_(.*?)' % prefix, fn)
            if v is None:
                v = re.search(r'(.*?)%s_(\w+_\d+)_(.*?)' % prefix, fn)
            if v is None:
                continue

            num = v.group(2)
            if first_claim_only:
                if num.endswith('_1'):
                    num_set.add(num)
            else:
                num_set.add(num)

    num_lst = sorted(num_set)

    # Keep only claims for which every model produced a forward.json file.
    select_lst = []
    for num in num_lst:
        if all(os.path.exists(os.path.join(folder, '%s_%s_forward.json' % (prefix, num)))
               for prefix in prefix_lst):
            select_lst.append(num)

    select_lst.sort()
    return select_lst
|
|
|
def update_selected():
    """Selectbox on_change callback: load the result for the newly chosen claim."""
    global select_lst

    chosen = st.session_state.myselectbox
    pick_and_load(select_lst, chosen)
|
|
|
def pick_and_load(select_lst, selected=None):
    """Load one claim's generation result from disk and cache it in session state.

    Parameters:
        select_lst: claim ids in '<patent>_<claim>' form.
        selected: a user-chosen claim (display form 'NNN (claim K)' or id form),
            or None to draw one at random from select_lst.

    Returns (pick, num, result): the selectbox index, the claim id, and the
    parsed JSON result dict.
    """
    # Normalize the selectbox display form 'NNN (claim K)' back to 'NNN_K';
    # this is a no-op for ids taken directly from select_lst.
    if selected is None:
        pick = random.randrange(len(select_lst))
        selected = select_lst[pick]
        num = selected.replace(')', '').replace(' (claim ', '_')
    else:
        num = selected.replace(')', '').replace(' (claim ', '_')
        # BUG FIX: on this path the original never assigned 'pick', so the
        # trailing `return pick, num, result` raised UnboundLocalError.
        pick = select_lst.index(num) if num in select_lst else 0
    st.session_state['picked_flag'] = pick
    st.session_state['num'] = num

    # Displayed results come from the 456M model's run (see model_names).
    prefix = "pgj_d_1024_v2"
    base_fn = '%s_%s_forward.json' % (prefix, num)
    full_fn = os.path.join(folder, base_fn)
    with open(full_fn) as f:
        result = json.loads(f.read())
    print("Loaded: %s" % full_fn)
    st.session_state['result'] = result

    return pick, num, result
|
|
|
def main():
    """Streamlit entry point: browse PatentGPT-J autocomplete results token by token.

    Lets the user pick a patent claim, scrub a token-position slider, and see
    the model's top-10 next-token candidates and generated continuation at
    that position.
    """
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Patent-GPT-J demo",
        page_icon=None,
    )
    st.subheader("PatentGPT-J Demo 2 (Autocomplete Effectiveness)")
    st.text("Data coverage: ipg220104 (in 2022-01-04)")

    # Scan the experiment folder only once per session; cache the claim list.
    if 'select_lst' not in st.session_state:
        select_lst = prepare_select_lst()
        st.session_state['select_lst'] = select_lst
    else:
        select_lst = st.session_state['select_lst']

    if len(select_lst) == 0:
        st.text('select_lst is empty')
        return

    # Display form for the selectbox: 'NNN_K' -> 'NNN (claim K)'.
    show_patent_lst = [ s.replace('_', ' (claim ') + ')' for s in select_lst]

    # First run: pick a random claim; later runs reuse the session-state cache.
    if 'picked_flag' not in st.session_state:
        pick, num, result = pick_and_load(select_lst)
    else:
        pick = st.session_state['picked_flag']
        num = st.session_state['num']
        result = st.session_state['result']

    if st.button('Random pick'):
        pick, num, result = pick_and_load(select_lst)

    selected = st.selectbox("Choose a patent claim", show_patent_lst, index=pick, key='myselectbox', on_change=update_selected)

    # Layout of the loaded *_forward.json: 'recv' holds the request context,
    # 'output' the per-position generation records, 'input' the prompt tokens.
    recv = result['recv']
    lst = result['output']
    input_tokens = result['input']

    height = calc_height(recv['context'])
    st.text_area('context:', recv['context'], height=height)

    # Rebuild the prompt from all input tokens up to and including the slider position.
    pos = st.slider("Token position", 0, len(lst), key="myslider", on_change=update_content)
    prompt = ''
    for i in range(pos+1):
        prompt += input_tokens[i]['text']
    height = calc_height(prompt)
    st.text_area('prompt:', prompt, height=height)

    # Show where the actual next token ranked among the model's predictions
    # (seq is stored 0-based, displayed 1-based).
    ch = handle_char_return(lst[pos]['actual_next_token_text'])
    st.text('actual_next_token_text: %s --> pick seq: %s (prob: %.2f) top 10 tokens:' % (ch, int(lst[pos]['actual_next_token_top_seq'])+1,
        float(lst[pos]['actual_next_token_top_prob'])))

    # Top-10 candidates rendered 5 per line as: (rank)[token](probability).
    msg = ''
    for i, v in enumerate(lst[pos]['top_n_lst']):
        ch = handle_char_return(v['top_n_text'])
        msg += '(%s)[%s](%.2f) ' % (i+1, ch, float(v['top_n_prob']))
        if i == 4:
            st.text(msg)
            msg = ''
    st.text(msg)

    # Model continuation from this position, truncated at the first end marker.
    gen_text = lst[pos]['gen_text']
    gen_text = remove_end_of_claim_text(gen_text)

    height = calc_height(gen_text)
    st.text_area('generated:', gen_text, height=height)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point when launched via `streamlit run <this file>`.
    main()
|
|