from typing import Optional

import pandas as pd
import streamlit as st
from datasets import Dataset  # type: ignore

from data import encode_dataset, get_collator, get_data, get_split_df
from model import get_encoder, get_model, get_tokenizer
from subpages import Context
from utils import align_sample, device, explode_df

# tokenizer override for the comma-prediction models; only the first tuple entry is used
_TOKENIZER_NAME = (
    "xlm-roberta-base",
    "gagan3012/bert-tiny-finetuned-ner",
    "distilbert-base-german-cased",
)[0]


def _load_models_and_tokenizer(
    encoder_model_name: str,
    model_name: str,
    tokenizer_name: Optional[str],
    device: str = "cpu",
):
    """Load the sentence encoder, the token-classification model, and its tokenizer."""
    sentence_encoder = get_encoder(encoder_model_name, device=device)
    tokenizer = get_tokenizer(tokenizer_name if tokenizer_name else model_name)
    labels = "O B-COMMA".split() if "comma" in model_name else None
    model = get_model(model_name, labels=labels)
    return sentence_encoder, model, tokenizer


@st.cache(allow_output_mutation=True)
def load_context(
    encoder_model_name: str,
    model_name: str,
    ds_name: str,
    ds_config_name: str,
    ds_split_name: str,
    split_sample_size: int,
    **kw_args,
) -> Context:
    """Load models and the dataset split, then assemble everything into a Context."""
    sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
        encoder_model_name=encoder_model_name,
        model_name=model_name,
        tokenizer_name=_TOKENIZER_NAME if "comma" in model_name else None,
        device=str(device),
    )
    collator = get_collator(tokenizer)

    # load the dataset split and encode it with the tokenizer
    split: Dataset = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
    tags = split.features["ner_tags"].feature
    split_encoded, word_ids, ids = encode_dataset(split, tokenizer)

    # transform the encoded split into a dataframe
    df = get_split_df(split_encoded, model, tokenizer, collator, tags)
    df["word_ids"] = word_ids
    df["ids"] = ids

    # explode to token level, drop tokens labelled 'IGN', and build an aligned/merged token view
    df_tokens = explode_df(df)
    df_tokens_cleaned = df_tokens.query("labels != 'IGN'")
    df_merged = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
    df_tokens_merged = explode_df(df_merged)

    return Context(
        model=model,
        tokenizer=tokenizer,
        sentence_encoder=sentence_encoder,
        df=df,
        df_tokens=df_tokens,
        df_tokens_cleaned=df_tokens_cleaned,
        df_tokens_merged=df_tokens_merged,
        tags=tags,
        labels=tags.names,
        split_sample_size=split_sample_size,
        ds_name=ds_name,
        ds_config_name=ds_config_name,
        ds_split_name=ds_split_name,
        split=split,
    )
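

# A minimal usage sketch (hypothetical argument values, not part of the original module):
# the cached loader is typically called once from the Streamlit entry point and the
# resulting Context is then passed to the individual subpages.
#
#     context = load_context(
#         encoder_model_name="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical
#         model_name="some-user/bert-finetuned-ner",                    # hypothetical
#         ds_name="conll2003",                                          # hypothetical
#         ds_config_name="conll2003",
#         ds_split_name="validation",
#         split_sample_size=100,
#     )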