import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
import streamlit as st
import base64
from model import CustomDataset, TransformerEncoder
st.set_page_config(layout="wide", page_title="TeknoFest We Bears NLP Competition", page_icon="./media/3bears.ico")
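# Tag set (Turkish): "olumsuz" = negative, "nötr" = neutral, "olumlu" = positive sentiment of an entity's
# first token; "org" (4) is used by predict_fonk below to merge continuation tokens of an entity name;
# "O" (0) means the token is outside any entity.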
tag2id = {"O": 0, "olumsuz": 1, "nötr": 2, "olumlu": 3, "org": 4}
id2tag = {value: key for key, value in tag2id.items()}
device = torch.device('cpu')
@st.cache_resource
def load_model_to_cpu(_model, path="model.pth"):
    # Load the checkpoint onto CPU and copy the weights into the model.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    _model.load_state_dict(checkpoint)
    return _model
def get_base64(bin_file):
    # Read a binary file and return its contents as a base64 string.
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()
def predict_fonk(model, device, example, tokenizer):
    model.to(device)
    model.eval()
    predictions = []

    # Tokenize the input text and wrap it in the same dataset/loader interface used for training.
    encodings_prdict = tokenizer.encode(example)
    predict_texts = [encodings_prdict.tokens]
    predict_input_ids = [encodings_prdict.ids]
    predict_attention_masks = [encodings_prdict.attention_mask]
    predict_token_type_ids = [encodings_prdict.type_ids]
    prediction_labels = [encodings_prdict.type_ids]  # placeholder labels, not used for inference
    predict_data = CustomDataset(predict_texts, predict_input_ids, predict_attention_masks, predict_token_type_ids,
                                 prediction_labels)
    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)
    with torch.no_grad():
        for dataset in predict_loader:
            batch_input_ids = dataset['input_ids'].to(device)
            batch_att_mask = dataset['attention_mask'].to(device)
            outputs = model(batch_input_ids, batch_att_mask)
            logits = outputs.view(-1, outputs.size(-1))  # Flatten the outputs to (seq_len, num_tags)
            _, predicted = torch.max(logits, 1)  # Most likely tag id per token
            # Padding tokens are filtered out later via the attention mask
            predictions.append(predicted)
    results_list = []
    entity_list = []
    results_dict = {}

    # Walk the tokens together with their predicted tags and attention mask values.
    trio = zip(predict_loader.dataset[0]["text"], predictions[0].tolist(), predict_attention_masks[0])
    for i, (token, label, attention) in enumerate(trio):
        # Keep only non-padding tokens that start an entity (skip "O" = 0 and "org" = 4).
        if attention != 0 and label != 0 and label != 4:
            # Append the following tokens tagged "org" (4) to build the full entity name.
            for next_ones in predictions[0].tolist()[i + 1:]:
                i += 1
                if next_ones == 4:
                    token = token + " " + predict_loader.dataset[0]["text"][i]
                else:
                    break
            if token not in entity_list:
                entity_list.append(token)
                results_list.append({"entity": token, "sentiment": id2tag.get(label)})

    results_dict["entity_list"] = entity_list
    results_dict["results"] = results_list
    return results_dict
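# Instantiate the encoder and load the fine-tuned weights from model.pth for CPU inference.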
model = TransformerEncoder()
model = load_model_to_cpu(model, "model.pth")
tokenizer = Tokenizer.from_file("tokenizer.json")
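# Inject the background image (inlined as base64) and the custom stylesheet into the page.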
background = get_base64("./media/background.jpg")
with open("./style/style.css", "r") as style:
    css = f"""<style>{style.read().format(background=background)}</style>"""
st.markdown(css, unsafe_allow_html=True)
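# Layout: three columns with the app content in the middle column, organised into tabs.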
left, middle, right = st.columns([1,1.5,1])
main, comps , result = middle.tabs([" ", " ", " "])
with main:
    # Turkish UI: text box asking the user to enter a complaint or comment, then click "Predict".
    example = st.text_area(label='Metin Kutusu: ', placeholder="Lütfen Şikayet veya Yorum Metnini Buraya Yazın, daha sonra Predicte tıklayın")
    if st.button("Predict"):
        predict_list = predict_fonk(model=model, device=device, example=example, tokenizer=tokenizer)
        st.write(predict_list)