File size: 819 Bytes
dded391
 
 
 
fd0d102
bbc73b2
 
 
 
 
 
dded391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import pandas as pd
import streamlit as st
import torch

from typing import List
from src.config import CONFIG
from src.model import BertLightningModel


def get_target_columns() -> List[str]:
    """Return the names of the six grading criteria.

    A new list is built on every call, so a caller may mutate the result
    without affecting what later calls return.
    """
    criteria = (
        'cohesion',
        'syntax',
        'vocabulary',
        'phraseology',
        'grammar',
        'conventions',
    )
    return list(criteria)


@st.cache(allow_output_mutation=True)
def load_model() -> BertLightningModel:
    """Load the scoring model from ``./model.ckpt`` onto the CPU.

    The result is wrapped in Streamlit's cache so the checkpoint is not
    re-read on every script rerun.

    Returns:
        The restored model, switched to evaluation mode.
    """
    # NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    # favour of ``st.cache_resource``; kept as-is because the installed
    # Streamlit version is not visible from this file -- confirm and migrate.
    ckpt_path = "./model.ckpt"
    model = BertLightningModel.load_from_checkpoint(ckpt_path, config=CONFIG, map_location='cpu')

    # Modules start out in training mode and restoring a checkpoint is not
    # documented to change that, so without this call dropout layers would
    # stay active and repeated scoring of the same text could differ.
    # (Assumes BertLightningModel is a torch/Lightning module -- it is
    # restored via ``load_from_checkpoint``.)
    model.eval()

    return model


@torch.no_grad()
def process_text(_text: str, _model: BertLightningModel) -> pd.DataFrame:
    """Score a single text with the model and tabulate the grades.

    Runs with gradient tracking disabled.

    Args:
        _text: The text to grade.
        _model: Model exposing a ``tokenizer`` attribute; it is called
            directly on the tokenizer's output.

    Returns:
        A DataFrame with a ``Criterion`` column (the names from
        ``get_target_columns``) and a ``Grade`` column holding the first
        row of the model output.
    """
    # truncation=True caps the input at the tokenizer's configured maximum
    # length so an over-long text cannot overflow the encoder's position
    # limit and crash the forward pass. Assumes a Hugging Face-style
    # tokenizer (``return_tensors='pt'`` is already used) -- TODO confirm the
    # limit matches the length used in training.
    tokens = _model.tokenizer([_text], return_tensors='pt', truncation=True)

    # The batch holds exactly one text, so row 0 is its prediction.
    outputs = _model(tokens)[0].tolist()

    df = pd.DataFrame({
        'Criterion': get_target_columns(),
        'Grade': outputs
    })

    return df