Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from datasets import load_dataset | |
from transformers import AutoTokenizer | |
from transformers import AutoModelForSequenceClassification | |
from transformers import pipeline | |
# Load the "sample" configuration of the HUPD patent dataset.
# The date arguments choose which filing-date windows populate the training
# and validation splits (2016-01-01..2016-01-21 and 2016-01-22..2016-01-31).
# NOTE(review): Streamlit re-executes the script on each interaction, so this
# load presumably runs on every rerun unless it is cached -- confirm.
_HUPD_LOAD_OPTIONS = {
    "name": "sample",
    "data_files": "https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
    "icpr_label": None,
    "train_filing_start_date": "2016-01-01",
    "train_filing_end_date": "2016-01-21",
    "val_filing_start_date": "2016-01-22",
    "val_filing_end_date": "2016-01-31",
}
dataset_dict = load_dataset("HUPD/hupd", **_HUPD_LOAD_OPTIONS)
# Process data: keep only validation applications whose decision is final,
# then draw a small reproducible sample for the demo.
_FINAL_DECISIONS = ("ACCEPTED", "REJECTED")
filtered_dataset = dataset_dict["validation"].filter(
    lambda e: e["decision"] in _FINAL_DECISIONS
)
# Clamp the sample size so the selection cannot request rows that do not
# exist when fewer than 20 applications survive the filter.
_sample_size = min(20, len(filtered_dataset))
# A fixed seed keeps the same applications selected across script reruns.
dataset = filtered_dataset.shuffle(seed=42).select(range(_sample_size))
# Sort so the select-box options appear in patent-number order.
dataset = dataset.sort("patent_number")
# Create the inference pipeline from the model trained on Colab.
# NOTE(review): torch.load unpickles the checkpoint file, so it should only be
# pointed at a trusted file; loading onto the CPU device avoids needing a GPU.
_cpu_device = torch.device("cpu")
model = torch.load("patent_classifier_v2.pt", map_location=_cpu_device)
# Tokenizer is taken from the pretrained "distilbert-base-uncased" checkpoint
# (presumably the base the classifier was fine-tuned from -- confirm).
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
def load_patent():
    """Copy the selected application's text fields into session state.

    Looks up the row index for the patent number stored under the ``id``
    session key, reads that row from ``dataset``, and writes its
    ``abstract``, ``claims`` and ``title`` values to the session-state keys
    of the same names.
    """
    row_index = applications[st.session_state.id]
    record = dataset[row_index]
    for field in ("abstract", "claims", "title"):
        st.session_state[field] = record[field]
st.title("CS-GY-6613 Project Milestone 3")

# Map each patent number to its row index in `dataset`; the keys become the
# select-box options and the indices are used to fetch the matching row.
# NOTE(review): rows sharing a patent_number collapse into one entry (the
# last row wins) -- confirm patent numbers are unique in the sample.
applications = {
    example["patent_number"]: ds_index for ds_index, example in enumerate(dataset)
}
st.selectbox(
    label="Select a patent application:",
    options=applications,
    key="id",
    on_change=load_patent,
)
# Application title is displayed for additional context only; it is not
# passed to the model.
# NOTE(review): some Streamlit releases enforce a minimum text_area height
# larger than 50 px -- confirm against the deployed version.
st.text_area("Title", key="title", value=dataset[0]["title"], height=50)

# Decision names shown to the user for the pipeline's generic class labels.
_DECISION_BY_LABEL = {"LABEL_0": "ACCEPTED", "LABEL_1": "REJECTED"}

# Classifier input form: the text areas are seeded from the first application
# and can be edited freely before submitting.
with st.form("Input Form"):
    abstract = st.text_area(
        "Abstract", key="abstract", value=dataset[0]["abstract"], height=200
    )
    # Seed with the first application's claims so the default matches the
    # Title and Abstract defaults.
    claims = st.text_area(
        "Claims", key="claims", value=dataset[0]["claims"], height=200
    )
    submitted = st.form_submit_button("Get Patentability Score")
    if submitted:
        # NOTE(review): the text-classification pipeline documents a single
        # `inputs` argument (a string, or a {"text": ..., "text_pair": ...}
        # dict for pairs). A second positional argument may be ignored or
        # rejected depending on the transformers version -- confirm how the
        # model was trained and pass the claims accordingly.
        res = classifier(abstract, claims)
        top_prediction = res[0]
        raw_label = top_prediction["label"]
        # Fall back to the raw label so an unexpected class id is displayed
        # instead of raising a NameError.
        pred = _DECISION_BY_LABEL.get(raw_label, raw_label)
        score = top_prediction["score"]
        st.markdown(
            f"This text was classified as **{pred}** "
            f"with a confidence score of **{score}**."
        )