# patentability / patent_app.py
# VarshithaChennamsetti
# Update patent_app.py
# 0e78b28 unverified
# Import statements
import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
from datasets import load_dataset
# Torch and torch dataloader
import torch
from torch.utils.data import DataLoader
st.title('Patentability Decision App')

# Arguments for the HUPD loader. Going by their names, the *_filing_* dates
# bound which applications fall into the train and validation splits
# (NOTE(review): confirm against the HUPD dataset loading script).
_hupd_load_kwargs = dict(
    name='sample',
    data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
    icpr_label=None,
    train_filing_start_date='2016-01-01',
    train_filing_end_date='2016-01-21',
    val_filing_start_date='2016-01-22',
    val_filing_end_date='2016-01-31',
)
# Dictionary of splits; the app reads its 'validation' split.
dataset_dict = load_dataset('HUPD/hupd', **_hupd_load_kwargs)
# Integer class index for every decision status in the dataset. Despite its
# name, this maps a decision *string* to its integer code.
_DECISION_LABELS = (
    'REJECTED',
    'ACCEPTED',
    'PENDING',
    'CONT-REJECTED',
    'CONT-ACCEPTED',
    'CONT-PENDING',
)
decision_to_str = {label: code for code, label in enumerate(_DECISION_LABELS)}


def map_decision_to_string(example):
    """Return a column update replacing ``example['decision']`` with its integer code.

    Intended for ``Dataset.map``. Raises ``KeyError`` for a status that is not
    listed in ``decision_to_str``.
    """
    status = example['decision']
    return dict(decision=decision_to_str[status])
def _has_final_decision(record):
    """Return True for records coded REJECTED (0) or ACCEPTED (1)."""
    return record['decision'] <= decision_to_str['ACCEPTED']


# Validation split with integer decision codes, restricted to applications
# that were either rejected or accepted.
val_set = (
    dataset_dict['validation']
    .map(map_decision_to_string)
    .filter(_has_final_decision)
)
# Dropdown listing every remaining validation application by its patent number.
patent_choices = val_set['patent_number']
patent_num = st.selectbox("Select a patent based on its number", patent_choices)

# Session state survives Streamlit reruns, so the "Get Data to predict!" click
# is remembered here; start with the flag cleared.
if "button_clicked" not in st.session_state:
    st.session_state["button_clicked"] = False
def callback():
    """``on_click`` handler: remember that the data button has been pressed."""
    st.session_state["button_clicked"] = True
# Show the abstract and claims of the selected application once the user has
# pressed "Get Data to predict!" (the flag in session state keeps the panel
# visible on later reruns).
if patent_num and (st.button('Get Data to predict!', on_click=callback) or st.session_state.button_clicked):
    # Narrow the validation set down to the selected application.
    val_set = val_set.filter(lambda e: e['patent_number'] == patent_num)
    # Editable text areas pre-filled with the application's abstract and claims.
    abstract_text = st.text_area('Abstract', val_set['abstract'][0])
    claims_text = st.text_area('Claims', val_set['claims'][0])

    # Predict on those texts
    if abstract_text and claims_text and st.button('Predict!'):
        # Fine-tuned classifier weights, and the base checkpoint whose tokenizer is used.
        model_name_or_path = './models/'
        model_name = 'distilbert-base-uncased'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)

        # Section to classify. Use the text currently in the widget so that any
        # edits the user made are reflected in the prediction (tokenizing the
        # stored dataset row would silently ignore them).
        _SECTION_ = 'claims'
        section_text = {'abstract': abstract_text, 'claims': claims_text}[_SECTION_]

        # return_tensors='pt' yields (1, seq_len) tensors: a batch of one row,
        # which is the shape the model's forward() expects.
        encoded = tokenizer(section_text, truncation=True, padding='max_length', return_tensors='pt')

        with torch.no_grad():
            # attention_mask makes the model ignore the PAD tokens added by
            # padding='max_length'. No labels are passed: no loss is needed.
            logits = model(input_ids=encoded['input_ids'],
                           attention_mask=encoded['attention_mask']).logits

        # Softmax turns the single row of logits into class probabilities.
        probs = torch.softmax(logits, dim=-1)[0]
        # Index of the most likely class (a value, not a tensor stride).
        prediction = int(torch.argmax(probs).item())

        # Reverse lookup: integer class index -> decision label.
        index_to_decision = {index: label for label, index in decision_to_str.items()}
        predicted_label = index_to_decision.get(prediction, str(prediction))
        st.text('This is the predicted decision: ' + predicted_label)

        # Patentability score (class 0 = REJECTED, class 1 = ACCEPTED per decision_to_str)
        st.text('Probability that it will be rejected : ' + str(probs[0].item() * 100))
        st.text('Probability that it will be accepted : ' + str(probs[1].item() * 100))