import torch |
from tokenizers import Tokenizer |
from torch.utils.data import DataLoader |
import streamlit as st |
import base64 |
from model import CustomDataset, TransformerEncoder |
# Page-level Streamlit settings: wide layout plus the browser-tab title and icon.
st.set_page_config(
    page_title="TeknoFest We Bears NLP Competition",
    page_icon="./media/3bears.ico",
    layout="wide",
)
# Tag name -> integer class id.
tag2id = {
    "O": 0,
    "olumsuz": 1,
    "nötr": 2,
    "olumlu": 3,
    "org": 4,
}
# Reverse lookup: integer class id -> tag name.
id2tag = dict(zip(tag2id.values(), tag2id.keys()))
# Device handed to the prediction routine; this app runs on the CPU.
device = torch.device("cpu")
@st.cache_resource
def load_model_to_cpu(_model, path="model.pth"):
    """Load the state dict stored at ``path`` into ``_model`` and return it.

    The checkpoint is deserialized with its tensors mapped onto the CPU.

    Args:
        _model: Module instance that receives the weights (modified in place).
            The leading underscore keeps Streamlit's resource cache from
            hashing this argument.
        path: Location of the checkpoint file holding the state dict.

    Returns:
        The same ``_model`` object, now carrying the loaded weights.
    """
    cpu = torch.device('cpu')
    state_dict = torch.load(path, map_location=cpu)
    _model.load_state_dict(state_dict)
    return _model
def get_base64(bin_file):
    """Return the contents of ``bin_file`` as a base64-encoded text string.

    Args:
        bin_file: Path of the file to read; it is opened in binary mode.

    Returns:
        str: Base64 representation of the file's raw bytes.
    """
    with open(bin_file, 'rb') as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
def predict_fonk(model, device, example, tokenizer):
    """Predict entity-level sentiment for a single piece of text.

    The text is tokenized, wrapped in a one-item ``CustomDataset`` /
    ``DataLoader`` and passed through ``model``. Every attended token whose
    predicted tag is neither ``0`` ("O") nor ``4`` ("org") starts an entity;
    the tokens that directly follow it and are tagged ``4`` are appended to
    the entity, separated by single spaces.

    Args:
        model: Network called as ``model(input_ids, attention_mask)``. Its
            output is flattened to ``(-1, last_dim)`` and the index of the
            largest value in each row is taken as the predicted tag id.
        device: Device the model and the input tensors are moved to.
        example: Raw text to analyse.
        tokenizer: Object whose ``encode(example)`` result exposes
            ``tokens``, ``ids``, ``attention_mask`` and ``type_ids``.

    Returns:
        dict: ``{"entity_list": [...], "results": [...]}``. ``entity_list``
        holds each distinct entity string once, in order of first
        appearance; ``results`` holds one
        ``{"entity": ..., "sentiment": ...}`` record per listed entity, where
        the sentiment is the tag name of the entity's first token.
    """
    model.to(device)
    model.eval()

    encoding = tokenizer.encode(example)
    attention_mask = encoding.attention_mask
    # The fifth argument is the label slot; the type ids are passed there,
    # presumably as a placeholder since no gold labels exist at inference
    # time -- confirm against CustomDataset.
    predict_data = CustomDataset(
        [encoding.tokens],
        [encoding.ids],
        [attention_mask],
        [encoding.type_ids],
        [encoding.type_ids],
    )
    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)

    predictions = []
    with torch.no_grad():
        for batch in predict_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_att_mask = batch['attention_mask'].to(device)
            outputs = model(batch_input_ids, batch_att_mask)
            # Flatten to one row of tag scores per token.
            logits = outputs.view(-1, outputs.size(-1))
            _, predicted = torch.max(logits, 1)
            predictions.append(predicted)

    # Loop-invariant inputs of the entity scan: fetch the dataset item and
    # convert the prediction tensor exactly once instead of once per token.
    tokens = predict_loader.dataset[0]["text"]
    labels = predictions[0].tolist()

    entity_list = []
    results_list = []
    for start, (token, label, attention) in enumerate(
        zip(tokens, labels, attention_mask)
    ):
        # Only attended tokens carrying a sentiment tag open an entity.
        if attention == 0 or label in (0, 4):
            continue

        entity = token
        # Absorb the run of "org" (4) tokens that directly follows.
        for nxt in range(start + 1, len(labels)):
            if labels[nxt] != 4:
                break
            entity = entity + " " + tokens[nxt]

        # Report each distinct entity once, keeping its first sentiment.
        if entity not in entity_list:
            entity_list.append(entity)
            results_list.append(
                {"entity": entity, "sentiment": id2tag.get(label)}
            )

    return {"entity_list": entity_list, "results": results_list}
# Build the network and restore its weights from the checkpoint file.
model = load_model_to_cpu(TransformerEncoder(), "model.pth")
tokenizer = Tokenizer.from_file("tokenizer.json")

# Substitute the base64-encoded background image into the stylesheet
# template and inject the result into the page.
background = get_base64("./media/background.jpg")
with open("./style/style.css", "r") as style:
    stylesheet = style.read().format(background=background)
css = f"""<style>{stylesheet}</style>"""
st.markdown(css, unsafe_allow_html=True)

# Three columns; the tabs are created inside the wider middle one.
left, middle, right = st.columns([1,1.5,1])
tab_labels = [" ", " ", " "]
main, comps , result = middle.tabs(tab_labels)

with main:
    example = st.text_area(
        label='Metin Kutusu: ',
        placeholder="Lütfen Şikayet veya Yorum Metnini Buraya Yazın, daha sonra Predicte tıklayın",
    )
    # NOTE(review): the button is assumed to belong to the main tab -- confirm
    # against the intended layout.
    if st.button("Predict"):
        predict_list = predict_fonk(model, device, example, tokenizer)
        st.write(predict_list)