MKaan's picture
Update app.py
b7d25a0
import streamlit as st
from multiprocessing import Process
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
import torch
import pandas as pd
import json
import requests
import time
import os
# Hugging Face Hub id of the fine-tuned sector classifier (config + weights).
model_name_or_directory = "MKaan/multilingual-cpv-sector-classifier"

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the loads below repeat on each rerun; consider caching them with
# st.cache_resource (Streamlit >= 1.18) if the deployed version supports it.

# Tokenizer is taken from the base multilingual BERT checkpoint, while the
# classifier config and weights come from the fine-tuned repository.
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
config = AutoConfig.from_pretrained(model_name_or_directory)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_directory,
    config=config,
)

# NOTE(review): `device` is computed but the model is not moved onto it here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lookup table from predicted class index (column `indexes`) to the sector
# label shown to the user (column `sectors`).
_cpv_table = pd.read_csv("idx2cpv.csv")
idx2cpv = {
    class_index: sector_label
    for class_index, sector_label in zip(_cpv_table["indexes"], _cpv_table["sectors"])
}
del _cpv_table
def get_result(input):
    """Classify one contract description and return its predicted sector.

    Args:
        input: Free-text contract description. (The name shadows the builtin
            ``input``; it is kept so existing keyword callers keep working.)

    Returns:
        The sector label that the module-level ``idx2cpv`` table maps the
        highest-scoring class index to.
    """
    # Truncate to the tokenizer's model_max_length so overly long text cannot
    # overflow BERT's position embeddings.
    encoded = tokenizer(input, return_tensors="pt", truncation=True)
    # Place the ids on whatever device the model weights live on, so this
    # works whether or not the model has been moved to a GPU.
    input_ids = encoded.input_ids.to(model.device)
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        output = model(input_ids)
    # Single-sequence batch: take the argmax of the only row as a Python int.
    pred = output.logits.argmax(dim=-1)[0].item()
    return idx2cpv[pred]
def _show_classification(text):
    """Classify *text* with ``get_result`` and render the sector in the app.

    A blank (empty or whitespace-only) description is not sent to the model;
    a warning is shown instead.
    """
    if not text.strip():
        st.warning("Please enter a contract description first.")
        return
    with st.spinner('In progress.......'):
        sector_class = get_result(text)
    st.success(sector_class)


if __name__ == "__main__":
    st.title('Multilingual Sector Classifier 📄')  # 📊💼
    st.subheader('Finds the correct sector for the given contract description')
    # Static markdown containing only links, so raw-HTML rendering is not needed.
    st.markdown(
        "Built by Mustafa Kaan Görgün, "
        "[Linkedin](https://www.linkedin.com/in/mustafa-kaan-görgün-a2461288/), "
        "[Model Card](https://huggingface.co/MKaan/multilingual-cpv-sector-classifier) "
    )

    # Example descriptions, one per language (columns `lang` and `descr`).
    examples = pd.read_csv("examples.csv")
    lang2example = dict(zip(examples.lang, examples.descr))

    st.markdown('##### Try it now:')
    input_lang = st.selectbox(
        label=f"Choose a language from the list of {examples.lang.nunique()} languages",
        options=examples.lang,
        # Preselect the 6th example, clamped so a shorter examples.csv does not
        # produce an out-of-range index.
        index=min(5, max(len(examples) - 1, 0)),
    )
    input_text_1 = st.text_area(
        label="Example description in chosen language",
        # .get: selectbox yields None when there are no options.
        value=lang2example.get(input_lang, ""),
        height=150,
        max_chars=500,
    )
    button1 = st.button('Run the example')

    st.write("or")
    input_text_2 = st.text_area(
        label="Write your own contract description in any of 104 languages that MBERT supports.",
        value="Your description comes here..",
        height=100,
        max_chars=500,
    )
    button2 = st.button('Run your own')

    st.markdown('##### Classified Sector: ')
    if button1:
        _show_classification(input_text_1)
    if button2:
        _show_classification(input_text_2)