Skip to content

Last Observation

ts_bolt.naive_forecasters.last_observation¤

LastObservationForecaster ¤

Bases: LightningModule

Spits out the forecasts using the last observation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `horizon` | `int` | horizon of the forecast. | required |
Source code in ts_bolt/naive_forecasters/last_observation.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class LastObservationForecaster(L.LightningModule):
    """Naive baseline that repeats the most recent observation as the forecast.

    Nothing is learned here: each forecast step is a copy of the final step
    of the input, taken along the input's second-to-last axis.

    :param horizon: horizon of the forecast.
    """

    def __init__(self, horizon: int):
        super().__init__()
        self.horizon = horizon

    def _last_observation(self, x: torch.Tensor) -> torch.Tensor:
        """Slice out the final step of ``x`` along its second-to-last axis.

        A slice (``-1:``) is used rather than an index, so the axis is kept
        and the result has the same number of axes as ``x``.

        :param x: input tensor with at least two axes.
        :return: the trailing step of ``x``.
        """
        return x[..., -1:, :]

    def _naive_forecast(self, x: torch.Tensor) -> torch.Tensor:
        """Tile the last observation ``horizon`` times and squeeze the last axis.

        ``repeat`` is given three factors, so ``x`` is presumably laid out as
        (batch, time, feature) -- TODO confirm against the data loaders.

        :param x: input history.
        :return: the tiled last observation; ``squeeze(-1)`` removes the
            last axis only when it has length one.
        """
        tiling = (1, self.horizon, 1)
        return self._last_observation(x).repeat(*tiling).squeeze(-1)

    def predict_step(
        self, batch: Sequence[torch.Tensor], batch_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Produce the naive forecast for one batch.

        :param batch: pair of tensors; only the first one is used, the
            second is unpacked and ignored.
        :param batch_idx: index of the batch (unused).
        :return: the input history and the forecast, each passed through
            ``squeeze(-1)``.
        """
        # Unpacking enforces the two-element batch layout even though the
        # second element plays no part in the forecast.
        history, _ = batch
        return history.squeeze(-1), self._naive_forecast(history)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cast ``x`` to the module dtype and forecast from it.

        :param x: input history.
        :return: the cast history and the forecast, each passed through
            ``squeeze(-1)``.
        """
        # Unlike predict_step, the input is first converted to self.dtype.
        history = x.type(self.dtype)
        forecast = self._naive_forecast(history)
        return history.squeeze(-1), forecast