# movie-genre / app.py  (header text scraped from the hosting repository page)
# Author: A-M-S · last commit cc8cfd2 "Update app.py" · 2.46 kB · scan: no virus
import streamlit as st
from transformers import AutoModelForSequenceClassification
import pandas as pd
import numpy as np
import torch
from preprocess import Preprocess
from utility import Utility
# --- Streamlit front end for the fine-tuned movie-genre classifier ---------

# Genre labels passed to Utility.tokenize. Presumably Utility builds the
# id2label / label2id mappings from this list in this order, so the order
# must match the one used at training time -- TODO confirm in utility.py.
GENRES = ["Action", "Drama", "Romance", "Comedy", "Thriller"]

# Multi-label decision rule: every genre whose sigmoid probability is at
# least this value is reported.
THRESHOLD = 0.5

CHECKPOINT_DIR = "./checkpoint-36819"


@st.cache_resource
def load_model():
    """Load the fine-tuned sequence-classification checkpoint once.

    Streamlit re-executes this whole script on every widget interaction;
    caching the model as a shared resource keeps it from being re-read from
    disk on each rerun.
    """
    # from_pretrained already returns the model in evaluation mode.
    return AutoModelForSequenceClassification.from_pretrained(CHECKPOINT_DIR)


def predict_genres(model, tokenized_plot, id2label, threshold=THRESHOLD):
    """Return the genre labels whose sigmoid probability reaches *threshold*.

    Args:
        model: the classifier returned by ``load_model``.
        tokenized_plot: mapping with ``"input_ids"`` and ``"attention_mask"``
            entries for a single text (as produced by ``Utility.tokenize``).
        id2label: mapping from output index to genre name.
            NOTE(review): assumed to be keyed by int -- confirm in utility.py.
        threshold: probability cut-off applied independently to each genre.

    Returns:
        List of predicted genre names (possibly empty).
    """
    # reshape(1, -1) yields a batch of one whether the tokenizer returned a
    # flat id list or an already-batched [[...]] structure for one text.
    input_ids = torch.as_tensor(
        np.asarray(tokenized_plot["input_ids"]), dtype=torch.long
    ).reshape(1, -1)
    attention_mask = torch.as_tensor(
        np.asarray(tokenized_plot["attention_mask"]), dtype=torch.long
    ).reshape(1, -1)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Independent sigmoid per genre (multi-label), not a softmax over genres.
    probs = torch.sigmoid(logits[0])
    hits = torch.nonzero(probs >= threshold).flatten().tolist()
    return [id2label[i] for i in hits]


st.title("Movie Genre Predictor")
st.subheader("Enter the text you'd like to analyze.")
text = st.text_input('Enter text')
model = load_model()

if st.button("Predict"):
    if not text or not text.strip():
        # Nothing to classify; don't run the preprocessing/tokenizer pipeline.
        st.warning("Please enter some text to analyze.")
    else:
        clean_plot = Preprocess().apply(text)
        id2label, _label2id, _tokenizer, tokenized_plot = Utility().tokenize(
            clean_plot, GENRES
        )
        genres = predict_genres(model, tokenized_plot, id2label)
        st.write("Genre: " + (", ".join(genres) if genres else "none above threshold"))

        # Keep the intermediate values visible for debugging without
        # cluttering the main result.
        with st.expander("Details"):
            st.write(clean_plot)
            st.write(tokenized_plot)