Spaces:
Runtime error
Runtime error
File size: 819 Bytes
dded391 fd0d102 bbc73b2 dded391 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import pandas as pd
import streamlit as st
import torch
from typing import List
from src.config import CONFIG
from src.model import BertLightningModel
def get_target_columns() -> List[str]:
    """Return the names of the six grading criteria.

    The order matters: ``process_text`` pairs these names positionally with
    the model's outputs. A new list is built on every call, so callers may
    mutate the result without affecting later calls.
    """
    criteria = (
        'cohesion',
        'syntax',
        'vocabulary',
        'phraseology',
        'grammar',
        'conventions',
    )
    return list(criteria)
# Prefer ``st.cache_resource`` (the replacement for the deprecated ``st.cache``
# for shared, unhashable objects such as models); fall back to the legacy
# ``st.cache(allow_output_mutation=True)`` on Streamlit releases without it.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model(ckpt_path: str = "./model.ckpt") -> BertLightningModel:
    """Load the BertLightningModel from a Lightning checkpoint, ready for inference.

    The result is cached by Streamlit, so the checkpoint is read once per
    ``ckpt_path`` instead of on every script rerun.

    Args:
        ckpt_path: Path of the checkpoint file. Defaults to ``"./model.ckpt"``,
            the path that was previously hard-coded.

    Returns:
        The model, with weights mapped onto the CPU and switched to eval mode.
    """
    model = BertLightningModel.load_from_checkpoint(
        ckpt_path, config=CONFIG, map_location='cpu'
    )
    # Lightning's docs call eval() explicitly after load_from_checkpoint; without
    # it any dropout layers stay active and identical inputs can score differently.
    model.eval()
    return model
@torch.no_grad()
def process_text(_text: str, _model: BertLightningModel) -> pd.DataFrame:
    """Score a single text and return one grade per criterion.

    Runs under ``torch.no_grad()`` so no autograd graph is built.

    Args:
        _text: The raw text to grade.
        _model: The scoring model. It must expose a ``tokenizer`` attribute
            and accept that tokenizer's output as its forward input.

    Returns:
        A DataFrame with columns ``Criterion`` and ``Grade``, one row per
        name returned by ``get_target_columns()``.
    """
    # truncation=True clips the input to the tokenizer's model_max_length so
    # long texts do not overflow the encoder's maximum sequence length.
    # NOTE(review): assumes a Hugging Face tokenizer (suggested by
    # return_tensors='pt') — confirm in src.model.
    tokens = _model.tokenizer([_text], return_tensors='pt', truncation=True)
    # The batch holds a single text, so row 0 carries all of its predictions.
    grades = _model(tokens)[0].tolist()
    return pd.DataFrame({
        'Criterion': get_target_columns(),
        'Grade': grades,
    })
|