import pathlib
import tempfile
import zipfile
from copy import deepcopy
from typing import List
from typing import Optional
import pandas as pd
from typing_extensions import Self
from etna.core import SaveMixin
from etna.core import load
from etna.datasets import TSDataset
from etna.loggers import tslogger
from etna.pipeline.base import BasePipeline
class EnsembleMixin:
"""Base mixin for the ensembles."""
@staticmethod
def _validate_pipeline_number(pipelines: List[BasePipeline]):
"""Check that given valid number of pipelines."""
if len(pipelines) < 2:
raise ValueError("At least two pipelines are expected.")
@staticmethod
def _get_horizon(pipelines: List[BasePipeline]) -> int:
"""Get ensemble's horizon."""
horizons = {pipeline.horizon for pipeline in pipelines}
if len(horizons) > 1:
raise ValueError("All the pipelines should have the same horizon.")
return horizons.pop()
@staticmethod
def _fit_pipeline(pipeline: BasePipeline, ts: TSDataset) -> BasePipeline:
"""Fit given pipeline with ``ts``."""
tslogger.log(msg=f"Start fitting {pipeline}.")
pipeline.fit(ts=ts)
tslogger.log(msg=f"Pipeline {pipeline} is fitted.")
return pipeline
@staticmethod
def _forecast_pipeline(pipeline: BasePipeline, ts: TSDataset) -> TSDataset:
"""Make forecast with given pipeline."""
tslogger.log(msg=f"Start forecasting with {pipeline}.")
forecast = pipeline.forecast(ts=ts)
tslogger.log(msg=f"Forecast is done with {pipeline}.")
return forecast
@staticmethod
def _predict_pipeline(
ts: TSDataset,
pipeline: BasePipeline,
start_timestamp: Optional[pd.Timestamp],
end_timestamp: Optional[pd.Timestamp],
) -> TSDataset:
"""Make predict with given pipeline."""
tslogger.log(msg=f"Start prediction with {pipeline}.")
prediction = pipeline.predict(ts=ts, start_timestamp=start_timestamp, end_timestamp=end_timestamp)
tslogger.log(msg=f"Prediction is done with {pipeline}.")
return prediction
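

# Illustrative sketch (hypothetical, not part of the library): a concrete
# ensemble built on ``EnsembleMixin`` would typically call the helpers above
# from its constructor and ``fit``, along these lines:
#
#     class _ExampleEnsemble(EnsembleMixin, BasePipeline):
#         def __init__(self, pipelines: List[BasePipeline]):
#             self._validate_pipeline_number(pipelines=pipelines)  # requires >= 2 pipelines
#             self.pipelines = pipelines
#             super().__init__(horizon=self._get_horizon(pipelines=pipelines))
#
#         def fit(self, ts: TSDataset) -> "_ExampleEnsemble":
#             self.ts = ts
#             self.pipelines = [self._fit_pipeline(pipeline=p, ts=deepcopy(ts)) for p in self.pipelines]
#             return self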
class SaveEnsembleMixin(SaveMixin):
"""Implementation of ``AbstractSaveable`` abstract class for ensemble pipelines.
It saves object to the zip archive with 3 entities:
* metadata.json: contains library version and class name.
* object.pkl: pickled without pipelines and ts.
* pipelines: folder with saved pipelines.
"""
pipelines: List[BasePipeline]
ts: Optional[TSDataset]
    def save(self, path: pathlib.Path):
"""Save the object.
Parameters
----------
path:
Path to save object to.
"""
pipelines = self.pipelines
ts = self.ts
try:
            # temporarily remove attributes that can't be pickled directly
delattr(self, "pipelines")
delattr(self, "ts")
# save the remaining part
super().save(path=path)
finally:
self.pipelines = pipelines
self.ts = ts
with zipfile.ZipFile(path, "a") as archive:
with tempfile.TemporaryDirectory() as _temp_dir:
temp_dir = pathlib.Path(_temp_dir)
                # save pipelines separately
pipelines_dir = temp_dir / "pipelines"
pipelines_dir.mkdir()
num_digits = 8
for i, pipeline in enumerate(pipelines):
save_name = f"{i:0{num_digits}d}.zip"
pipeline_save_path = pipelines_dir / save_name
pipeline.save(pipeline_save_path)
archive.write(pipeline_save_path, f"pipelines/{save_name}")
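
    # Resulting archive layout (sketch): with the default ``num_digits = 8``
    # zero-padding above, saving an ensemble of two pipelines produces
    # entries like:
    #
    #     metadata.json
    #     object.pkl
    #     pipelines/00000000.zip
    #     pipelines/00000001.zip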
    @classmethod
def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Self:
"""Load an object.
Warning
-------
        This method uses the :py:mod:`dill` module, which is not secure.
It is possible to construct malicious data which will execute arbitrary code during loading.
Never load data that could have come from an untrusted source, or that could have been tampered with.
Parameters
----------
path:
Path to load object from.
ts:
TSDataset to set into loaded pipeline.
Returns
-------
:
Loaded object.
"""
obj = super().load(path=path)
obj.ts = deepcopy(ts)
with zipfile.ZipFile(path, "r") as archive:
with tempfile.TemporaryDirectory() as _temp_dir:
temp_dir = pathlib.Path(_temp_dir)
archive.extractall(temp_dir)
# load pipelines
pipelines_dir = temp_dir / "pipelines"
pipelines = []
                for pipeline_path in sorted(pipelines_dir.iterdir()):
                    pipelines.append(load(pipeline_path, ts=ts))
obj.pipelines = pipelines
return obj
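

# Usage sketch (names are hypothetical): a save/load round trip for any
# ensemble using ``SaveEnsembleMixin``:
#
#     ensemble.save(path=pathlib.Path("ensemble.zip"))
#     restored = type(ensemble).load(path=pathlib.Path("ensemble.zip"), ts=ts)
#     forecast = restored.forecast()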