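# Streamlit demo for the TeknoFest "We Bears" NLP competition: extract the entities
# mentioned in a Turkish complaint/review text and predict a sentiment for each one.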
import base64

import streamlit as st
import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

from model import CustomDataset, TransformerEncoder

st.set_page_config(layout="wide", page_title="TeknoFest We Bears NLP Competition",
                   page_icon="./media/3bears.ico")
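
# Label scheme (Turkish): "olumsuz" = negative, "nötr" = neutral, "olumlu" = positive;
# "O" marks non-entity tokens, and "org" tags the follow-up tokens of a multi-token
# entity, which predict_fonk merges back into a single entity string.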
tag2id = {"O": 0, "olumsuz": 1, "nötr": 2, "olumlu": 3, "org": 4}
id2tag = {value: key for key, value in tag2id.items()}

device = torch.device('cpu')


@st.cache_resource
def load_model_to_cpu(_model, path="model.pth"):
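    # The leading underscore tells st.cache_resource not to hash the model argument.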
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    _model.load_state_dict(checkpoint)
    return _model


def get_base64(bin_file):
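    # Read a binary file and return its contents as a base64 string.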
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()


def predict_fonk(model, device, example, tokenizer):
    model.to(device)
    model.eval()
    predictions = []

    # Tokenize the input text with the same tokenizer the model was trained with.
    encodings_predict = tokenizer.encode(example)

    predict_texts = [encodings_predict.tokens]
    predict_input_ids = [encodings_predict.ids]
    predict_attention_masks = [encodings_predict.attention_mask]
    predict_token_type_ids = [encodings_predict.type_ids]
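    # CustomDataset expects a labels argument; the type_ids are reused as placeholders at inference.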
    prediction_labels = [encodings_predict.type_ids]

    predict_data = CustomDataset(predict_texts, predict_input_ids, predict_attention_masks,
                                 predict_token_type_ids, prediction_labels)

    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)
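
    # Run the forward pass without gradient tracking and collect per-token tag predictions.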
    with torch.no_grad():
        for dataset in predict_loader:
            batch_input_ids = dataset['input_ids'].to(device)
            batch_att_mask = dataset['attention_mask'].to(device)

            outputs = model(batch_input_ids, batch_att_mask)
            # Flatten to (sequence_length, num_tags) and take the highest-scoring tag per token.
            logits = outputs.view(-1, outputs.size(-1))
            _, predicted = torch.max(logits, 1)

            predictions.append(predicted)

    results_list = []
    entity_list = []
    results_dict = {}
    trio = zip(predict_loader.dataset[0]["text"], predictions[0].tolist(), predict_attention_masks[0])

    for i, (token, label, attention) in enumerate(trio):
        # Keep real tokens (attention != 0) whose tag is a sentiment, i.e. neither "O" (0) nor "org" (4).
        if attention != 0 and label != 0 and label != 4:
            # Merge the following tokens tagged "org" into a single multi-token entity.
            for next_ones in predictions[0].tolist()[i + 1:]:
                i += 1
                if next_ones == 4:
                    token = token + " " + predict_loader.dataset[0]["text"][i]
                else:
                    break
            if token not in entity_list:
                entity_list.append(token)
                results_list.append({"entity": token, "sentiment": id2tag.get(label)})

    results_dict["entity_list"] = entity_list
    results_dict["results"] = results_list
    return results_dict


model = TransformerEncoder()
model = load_model_to_cpu(model, "model.pth")
tokenizer = Tokenizer.from_file("tokenizer.json")

background = get_base64("./media/background.jpg")
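
# Inject the stylesheet, substituting the base64-encoded background image into the CSS template.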
with open("./style/style.css", "r") as style:
    css = f"""<style>{style.read().format(background=background)}</style>"""
    st.markdown(css, unsafe_allow_html=True)
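
# Three-column layout: the middle column hosts the app inside three (visually unlabeled) tabs.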
left, middle, right = st.columns([1, 1.5, 1])
main, comps, result = middle.tabs([" ", " ", " "])
with main:
    example = st.text_area(label='Text box: ',
                           placeholder="Please write the complaint or review text here, then click Predict")

    if st.button("Predict"):
        prediction_results = predict_fonk(model=model, device=device, example=example, tokenizer=tokenizer)

        st.write(prediction_results)
|