File size: 1,842 Bytes
626eca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Any, Optional

import lightning
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler

from relik.reader.relik_reader_re import RelikReaderForTripletExtraction


class RelikReaderREPLModule(lightning.LightningModule):
    """Lightning wrapper around ``RelikReaderForTripletExtraction``.

    The module owns the reader model, forwards training batches to it, logs
    the losses found in the model output, and delegates optimizer
    construction to a factory injected via :meth:`set_optimizer_factory`.
    """

    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        *args: Any,
        **kwargs: Any
    ):
        """Build the wrapped reader model.

        Args:
            cfg: Configuration mapping. It is not read in this constructor;
                it is only captured by ``save_hyperparameters()``.
            transformer_model: Forwarded as the first positional argument
                of ``RelikReaderForTripletExtraction``.
            additional_special_symbols: Forwarded to the reader model.
            num_layers: Forwarded to the reader model.
            activation: Forwarded to the reader model.
            linears_hidden_size: Forwarded to the reader model.
            use_last_k_layers: Forwarded to the reader model.
            training: Forwarded to the reader model as ``training=``.
            *args: Extra positional arguments for ``LightningModule``.
            **kwargs: Extra keyword arguments for ``LightningModule``.
        """
        super().__init__(*args, **kwargs)
        # Record the constructor arguments as hyperparameters of the module.
        self.save_hyperparameters()

        self.relik_reader_re_model = RelikReaderForTripletExtraction(
            transformer_model,
            additional_special_symbols,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
        )
        # Injected later through set_optimizer_factory(); configure_optimizers()
        # refuses to run until a callable has been provided.
        self.optimizer_factory = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Run the reader on ``batch``, log its losses, and return the loss.

        ``batch`` is unpacked as keyword arguments into the reader model. The
        model output must provide the keys ``"loss"``, ``"ned_start_loss"``,
        ``"ned_end_loss"`` and ``"re_loss"``.

        Returns:
            The value stored under ``"loss"`` in the model output.
        """
        relik_output = self.relik_reader_re_model(**batch)
        loss = relik_output["loss"]
        self.log("train-loss", loss)
        self.log("train-start_loss", relik_output["ned_start_loss"])
        self.log("train-end_loss", relik_output["ned_end_loss"])
        self.log("train-relation_loss", relik_output["re_loss"])
        return loss

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        """Do nothing for validation batches and return ``None``."""
        return None

    def set_optimizer_factory(self, optimizer_factory) -> None:
        """Store the callable that ``configure_optimizers`` will invoke.

        Args:
            optimizer_factory: Callable invoked with the wrapped reader model
                as its only argument; its return value is handed back from
                ``configure_optimizers`` unchanged.
        """
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Return whatever the injected optimizer factory builds.

        Raises:
            TypeError: If no callable optimizer factory has been set. The
                exception type matches what calling ``None`` raised before,
                but the message now says how to fix the problem.
        """
        factory = self.optimizer_factory
        if not callable(factory):
            raise TypeError(
                "optimizer_factory is not set or not callable "
                f"(got {type(factory).__name__}); call "
                "set_optimizer_factory(...) before starting training."
            )
        return factory(self.relik_reader_re_model)