Booguy's picture
Update utils.py
fd0d102
raw
history blame contribute delete
No virus
819 Bytes
import pandas as pd
import streamlit as st
import torch
from typing import List
from src.config import CONFIG
from src.model import BertLightningModel
def get_target_columns() -> List[str]:
    """Return the scoring criterion names as a freshly built list.

    ``process_text`` pairs these names positionally with the model's output
    values, so the order of this list must match the order of those outputs.
    """
    criteria = [
        'cohesion',
        'syntax',
        'vocabulary',
        'phraseology',
        'grammar',
        'conventions',
    ]
    return criteria
# Prefer st.cache_resource (Streamlit >= 1.18), the documented replacement for
# the deprecated st.cache when caching shared, unhashable objects such as ML
# models. Fall back to the legacy decorator on older Streamlit releases.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model(ckpt_path: str = "./model.ckpt") -> BertLightningModel:
    """Load the scoring model from a Lightning checkpoint, cached across reruns.

    Args:
        ckpt_path: Path to the checkpoint file. Defaults to ``"./model.ckpt"``,
            which is resolved relative to the process's working directory.

    Returns:
        The model with its weights mapped onto the CPU, switched to
        evaluation mode for inference.
    """
    model = BertLightningModel.load_from_checkpoint(
        ckpt_path, config=CONFIG, map_location='cpu'
    )
    # Lightning's load_from_checkpoint does not switch the module to eval mode
    # (its docs call model.eval() explicitly before predicting). Without this,
    # dropout stays active and predictions are noisy and non-deterministic.
    model.eval()
    return model
@torch.no_grad()
def process_text(_text: str, _model: BertLightningModel) -> pd.DataFrame:
    """Score a single text and return one row per criterion.

    Runs under ``torch.no_grad()`` because this is inference only.

    Args:
        _text: The text to score.
        _model: A loaded model exposing a ``tokenizer`` attribute. The model
            is called on the tokenized one-text batch, and the first element
            of its output is used as the list of scores.

    Returns:
        DataFrame with a ``'Criterion'`` column (names from
        ``get_target_columns()``) and a ``'Grade'`` column holding the model
        output matched to those names by position.
    """
    # truncation=True caps the input at the tokenizer's model_max_length, so
    # long texts don't overflow the encoder's position embeddings (BERT-style
    # models typically accept at most 512 tokens); shorter inputs are unchanged.
    # NOTE(review): assumes `tokenizer` is a Hugging Face tokenizer, as the
    # `return_tensors='pt'` call suggests — confirm.
    tokens = _model.tokenizer([_text], return_tensors='pt', truncation=True)
    # The batch holds one text, so [0] presumably selects its score vector
    # (shape (1, num_criteria) assumed — TODO confirm against the model).
    outputs = _model(tokens)[0].tolist()
    return pd.DataFrame({
        'Criterion': get_target_columns(),
        'Grade': outputs,
    })