"""Streamlit demo that scores HUPD patent applications with a fine-tuned classifier.

The app loads a small slice of the HUPD sample dataset, lets the user pick an
application, shows its title, abstract and claims in editable text areas, and
runs a text-classification pipeline on the abstract/claims when the form is
submitted.
"""

import streamlit as st
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import pipeline

# Location of the model object saved from the Colab training run.
MODEL_PATH = "/workspaces/cs-gy-6613-project/patent_classification(1).pt"

# Maximum number of applications offered in the select box.
SAMPLE_SIZE = 20

# Human-readable names for the label ids emitted by the pipeline.
LABEL_NAMES = {"LABEL_0": "ACCEPTED", "LABEL_1": "REJECTED"}

# Session State keys of the text areas; they match the dataset column names.
TEXT_FIELDS = ("title", "abstract", "claims")


@st.cache_resource
def _load_sample_dataset():
    """Return up to SAMPLE_SIZE decided applications, sorted by patent number.

    Cached so that Streamlit's script reruns (one per widget interaction) do
    not reload and re-filter the dataset every time.
    """
    # Load HUPD dataset
    dataset_dict = load_dataset(
        'HUPD/hupd',
        name='sample',
        data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
        icpr_label=None,
        train_filing_start_date='2016-01-01',
        train_filing_end_date='2016-01-21',
        val_filing_start_date='2016-01-22',
        val_filing_end_date='2016-01-31',
    )

    # Process data: keep only applications with a final decision.
    decided = dataset_dict['validation'].filter(
        lambda e: e['decision'] in ('ACCEPTED', 'REJECTED')
    )
    # Clamp the sample size so a small filtered split does not raise.
    sample_size = min(SAMPLE_SIZE, len(decided))
    sample = decided.shuffle(seed=42).select(range(sample_size))
    return sample.sort("patent_number")


@st.cache_resource
def _load_classifier():
    """Build the text-classification pipeline from the model trained on Colab.

    Cached so the model file is deserialised once per server process rather
    than on every script rerun.
    """
    # torch.load unpickles the file; only point MODEL_PATH at a trusted file.
    model = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    return pipeline("text-classification", model=model, tokenizer=tokenizer)


dataset = _load_sample_dataset()
classifier = _load_classifier()


def load_patent():
    """Copy the selected application's text fields into Session State.

    Used as the select box's ``on_change`` callback, so the text areas show
    the newly selected application on the next rerun.
    """
    selected_application = dataset[applications[st.session_state.id]]
    for field in TEXT_FIELDS:
        st.session_state[field] = selected_application[field]


st.title("CS-GY-6613 Project Milestone 3")

if len(dataset) == 0:
    st.error("No ACCEPTED or REJECTED applications were found in the sample.")
    st.stop()

# List patent numbers for select box
# NOTE(review): if patent_number is not unique across rows, later rows
# overwrite earlier ones here -- confirm it is a suitable key.
applications = {
    example['patent_number']: ds_index
    for ds_index, example in enumerate(dataset)
}

# Seed the text areas once through Session State. Passing `value=` to a widget
# whose key is also written by the callback makes Streamlit warn about a
# default value conflicting with the Session State API.
for field in TEXT_FIELDS:
    if field not in st.session_state:
        st.session_state[field] = dataset[0][field]

st.selectbox("Select a patent application:", applications,
             on_change=load_patent, key="id")

# Application title displayed for additional context only, not used with model
st.text_area("Title", key="title", height=50)

# Classifier input form
with st.form('Input Form'):
    abstract = st.text_area("Abstract", key="abstract", height=200)
    claims = st.text_area("Claims", key="claims", height=200)
    submitted = st.form_submit_button("Get Patentability Score")

    if submitted:
        selected_application = dataset[applications[st.session_state.id]]

        # A text pair must be passed as a {"text", "text_pair"} dict; a second
        # positional argument is not treated as the pair. Truncation keeps
        # long claims within the model's maximum sequence length.
        # NOTE(review): assumes the model was fine-tuned on abstract/claims
        # pairs -- confirm against the training notebook.
        res = classifier(
            {"text": abstract, "text_pair": claims},
            truncation=True,
        )
        # Depending on the input form the pipeline returns a dict or a list.
        prediction = res[0] if isinstance(res, list) else res

        # Fall back to the raw label id rather than leaving `pred` unbound.
        pred = LABEL_NAMES.get(prediction["label"], prediction["label"])
        score = prediction["score"]
        label = selected_application['decision']

        st.markdown(
            f"This text was classified as **{pred}** "
            f"with a confidence score of **{score}**."
        )
        st.markdown(f"Actual Label: **{label}**.")