From 161c619733865e83acdcff7318a5833da81e6425 Mon Sep 17 00:00:00 2001 From: Will Usher Date: Wed, 21 Jun 2023 15:01:06 +0200 Subject: [PATCH] Adds a wrapper around ReadStategy.read to return an xarray.DataSet --- requirements.txt | 1 + setup.cfg | 1 + src/otoole/input.py | 32 ++++++++++++++++++++++++++++++++ tests/test_input.py | 20 ++++++++++++++++++++ 4 files changed, 54 insertions(+) diff --git a/requirements.txt b/requirements.txt index 7a9f8eae..7e39ae5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ pandas pydantic pydot pyyaml +xarray xlrd diff --git a/setup.cfg b/setup.cfg index b410ffb8..f5ee88b5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,6 +52,7 @@ install_requires = flatten_dict openpyxl pydantic + xarray [options.packages.find] where = src exclude = diff --git a/src/otoole/input.py b/src/otoole/input.py index 4bb50921..d23e28e2 100644 --- a/src/otoole/input.py +++ b/src/otoole/input.py @@ -35,6 +35,7 @@ from typing import Any, Dict, List, Optional, TextIO, Tuple, Union import pandas as pd +import xarray as xr from otoole.exceptions import OtooleIndexError, OtooleNameMismatchError @@ -593,3 +594,34 @@ def read( self, filepath: Union[str, TextIO], **kwargs ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, Any]]: raise NotImplementedError() + + def to_xarray(self, filepath) -> xr.Dataset: + """Returns input data as an xarray.Dataset + + Arguments + --------- + filepath: Union[str, TextIO] + + """ + + model, defaults = self.read(filepath) + config = self.input_config + + data_vars = { + x: y.VALUE.to_xarray() + for x, y in model.items() + if config[x]["type"] == "param" + } + coords = { + x: y.values.T[0] for x, y in model.items() if config[x]["type"] == "set" + } + ds = xr.Dataset(data_vars=data_vars, coords=coords) + # ds = ds.assign_coords({'_REGION': model['REGION'].values.T[0]}) + + for param, default in defaults.items(): + if param in config and param in model and config[param]["type"] == "param": + ds[param].attrs["default"] = default + if default != 0: + ds[param] = ds[param].fillna(default) + + return ds diff --git a/tests/test_input.py b/tests/test_input.py index 37fd4729..3d629f8f 100644 --- a/tests/test_input.py +++ b/tests/test_input.py @@ -1,11 +1,13 @@ from typing import Any, Dict, TextIO, Tuple, Union import pandas as pd +import xarray as xr from pandas.testing import assert_frame_equal from pytest import fixture, mark, raises from otoole.exceptions import OtooleIndexError, OtooleNameMismatchError from otoole.input import ReadStrategy, WriteStrategy +from otoole.read_strategies import ReadMemory @fixture @@ -533,3 +535,21 @@ def test_compare_read_to_expected_exception(self, simple_user_config, expected): reader = DummyReadStrategy(simple_user_config) with raises(OtooleNameMismatchError): reader._compare_read_to_expected(names=expected) + + +class TestXarray: + def test_xarray(self, user_config): + """Test that xarray can be imported""" + + data = [ + ["SIMPLICITY", "ETH", 2014, 1.0], + ["SIMPLICITY", "RAWSUG", 2014, 0.5], + ["SIMPLICITY", "ETH", 2015, 1.03], + ["SIMPLICITY", "RAWSUG", 2015, 0.51], + ] + df = pd.DataFrame(data=data, columns=["REGION", "FUEL", "YEAR", "VALUE"]) + parameters = {"AccumulatedAnnualDemand": df} + + reader = ReadMemory(parameters, user_config) + actual = reader.to_xarray("") + assert isinstance(actual, xr.Dataset)