diff --git a/deadtrees/deployment/inference.py b/deadtrees/deployment/inference.py
index 91320e3..a968879 100644
--- a/deadtrees/deployment/inference.py
+++ b/deadtrees/deployment/inference.py
@@ -1,7 +1,7 @@
 import io
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Union
+from typing import Iterable, Union
 
 import numpy as np
 import torch
@@ -12,15 +12,18 @@
 
 
 class Inference(ABC):
-    def __init__(self, model_file: Union[str, Path]) -> None:
-        self._model_file = (
-            model_file if isinstance(model_file, Path) else Path(model_file)
-        )
-        super().__init__()
+    # def __init__(self, model_file: Union[str, Path]) -> None:
+    #     self._model_file = (
+    #         model_file if isinstance(model_file, Path) else Path(model_file)
+    #     )
+    #     super().__init__()
 
     @property
    def model_file(self) -> str:
-        return self._model_file.name
+        if hasattr(self, "_model_files"):
+            return ",".join([m.name for m in self._model_files])
+        else:
+            return self._model_file.name
 
     @abstractmethod
     def run(self, input_tensor: torch.Tensor):
@@ -29,7 +32,11 @@ def run(self, input_tensor: torch.Tensor):
 
 class PyTorchInference(Inference):
     def __init__(self, model_file) -> None:
-        super().__init__(model_file)
+        # super().__init__(model_file)
+
+        self._model_file = (
+            model_file if isinstance(model_file, Path) else Path(model_file)
+        )
 
         if self._model_file.suffix != ".ckpt":
             raise ValueError(
@@ -44,7 +51,17 @@ def __init__(self, model_file) -> None:
         # TODO: this is ugly, rename or restructure
         self._model = model.model
 
-    def run(self, input_tensor, device: str = "cpu"):
+    @property
+    def channels(self) -> int:
+        return self._channels
+
+    @property
+    def classes(self) -> int:
+        return self._model.classes
+
+    def run(self, input_tensor, device: str = "cpu", return_raw: bool = False):
+        """run the model, return either the raw logits or the argmax class map"""
+
         if not isinstance(input_tensor, torch.Tensor):
             raise TypeError("no pytorch tensor provided")
 
@@ -59,7 +76,10 @@ def run(self, input_tensor, device: str = "cpu"):
             input_tensor = input_tensor[:, 0:3, :, :]
 
         out = self._model(input_tensor)
-        return out.argmax(dim=1).squeeze()
+        if return_raw:
+            return out
+        else:
+            return out.argmax(dim=1).squeeze()
 
 
 class PyTorchEnsembleInference:
@@ -67,6 +87,11 @@ def __init__(self, *model_files: Path):
         self._models = []
         self._channels = None
 
+        self._model_files = [
+            model_file if isinstance(model_file, Path) else Path(model_file)
+            for model_file in model_files
+        ]
+
         if len(model_files) % 2 == 0:
             raise ValueError(
                 "PyTorchEnsembleInference requires an uneven number of models"
             )
@@ -93,7 +118,16 @@ def __init__(self, *model_files: Path):
             # TODO: this is ugly, rename or restructure
             self._models.append(model.model)
 
-    def run(self, input_tensor, device: str = "cpu"):
+    @property
+    def channels(self) -> int:
+        return self._channels
+
+    @property
+    def classes(self) -> int:
+        return self._models[0].classes
+
+    def run(self, input_tensor, device: str = "cpu", return_raw: bool = False):
+        """run the model(s), return either the raw logits of all models or the mode"""
         if not isinstance(input_tensor, torch.Tensor):
             raise TypeError("No PyTorch tensor provided")
 
@@ -107,15 +141,21 @@ def run(self, input_tensor, device: str = "cpu"):
         outs = []
         for model in self._models:
             model.to(device)
-
             with torch.no_grad():
                 out = model(input_tensor)
-                outs.append(out.argmax(dim=1).squeeze())
+                outs.append(out)
 
-        return torch.mode(torch.stack(outs, dim=1), axis=1)[0]
+        if return_raw:
+            # dims: m, bs, c, h, w
+            return torch.stack(outs, dim=0)
+        else:
+            # dims: bs, h, w
+            model_results = [out.argmax(dim=1).squeeze() for out in outs]
+            return torch.mode(torch.stack(model_results, dim=1), axis=1).values
 
 
+# deprecated, do not use
 class ONNXInference(Inference):
     def __init__(self, model_file) -> None:
         super().__init__(model_file)
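For reference, a minimal sketch of the per-pixel majority vote that PyTorchEnsembleInference.run performs when return_raw=False (toy class maps instead of real model output; only torch is required):

    import torch

    # three per-model class maps, shape (bs=1, h=2, w=2)
    a = torch.tensor([[[0, 1], [1, 1]]])
    b = torch.tensor([[[0, 1], [0, 1]]])
    c = torch.tensor([[[1, 1], [1, 0]]])

    stacked = torch.stack([a, b, c], dim=1)       # (bs, models, h, w)
    majority = torch.mode(stacked, dim=1).values  # (bs, h, w) -> tensor([[[0, 1], [1, 1]]])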
diff --git a/deadtrees/deployment/tiler.py b/deadtrees/deployment/tiler.py
deleted file mode 100644
index ddbd7af..0000000
--- a/deadtrees/deployment/tiler.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# flake8: noqa: E402
-import argparse
-import warnings
-from pathlib import Path
-from typing import Optional, Tuple, Union
-
-import xarray
-
-warnings.filterwarnings("ignore", category=UserWarning)
-
-import math
-from dataclasses import dataclass
-
-import numpy as np
-import rioxarray
-from deadtrees.utils.data_handling import (
-    make_blocks_vectorized,
-    unmake_blocks_vectorized,
-)
-
-
-@dataclass
-class TileInfo:
-    size: Tuple[int, int]
-    subtiles: Tuple[int, int]
-
-
-def divisible_without_remainder(a, b):
-    if b == 0:
-        return False
-    return True if a % b == 0 else False
-
-
-def inspect_tile(
-    infile: Union[str, Path, xarray.DataArray],
-    tile_shape: Tuple[int, int] = (8192, 8192),
-    subtile_shape: Tuple[int, int] = (512, 512),
-) -> TileInfo:
-    with rioxarray.open_rasterio(infile).sel(band=1, drop=True) if not isinstance(
-        infile, xarray.DataArray
-    ) else infile as da:
-
-        shape = tuple(da.shape)
-
-        if not divisible_without_remainder(tile_shape[0], subtile_shape[0]):
-            raise ValueError(f"Shapes unaligned (v): {tile_shape[0], subtile_shape[0]}")
-
-        if not divisible_without_remainder(tile_shape[1], subtile_shape[1]):
-            raise ValueError(f"Shapes unaligned (h): {tile_shape[1], subtile_shape[1]}")
-
-        subtiles = (
-            math.ceil(shape[0] / subtile_shape[0]),
-            math.ceil(shape[1] / subtile_shape[1]),
-        )
-
-        return TileInfo(size=shape, subtiles=subtiles)
-
-
-class Tiler:
-    def __init__(
-        self,
-        infile: Optional[Union[str, Path]] = None,
-        tile_shape: Optional[Tuple[int, int]] = (2048, 2048),
-        subtile_shape: Optional[Tuple[int, int]] = (256, 256),
-    ) -> None:
-        self._infile = infile
-        self._tile_shape = tile_shape
-        self._subtile_shape = subtile_shape
-
-        if subtile_shape[0] != subtile_shape[1]:
-            raise ValueError("Subtile required to have matching x/y dims")
-
-        self._source: Optional[xarray.DataArray] = None
-        self._target: Optional[xarray.DataArray] = None
-        self._indata: Optional[np.ndarray] = None
-        self._outdata: Optional[np.ndarray] = None
-        self._batch_shape: Optional[np.ndarray] = None
-        self._subtiles_to_use: Optional[np.ndarray] = None
-
-        self._tile_info: Optional[TileInfo] = None
-
-    def load_file(
-        self,
-        infile: Union[str, Path],
-        tile_shape: Optional[Tuple[int, int]] = None,
-        subtile_shape: Optional[Tuple[int, int]] = None,
-    ) -> None:
-
-        self._infile = infile
-        self._tile_shape = tile_shape or self._tile_shape
-
-        if subtile_shape:
-            if subtile_shape[0] != subtile_shape[1]:
-                raise ValueError("Subtile required to have matching x/y dims")
-        self._subtile_shape = subtile_shape or self._subtile_shape
-
-        self._tile_info = inspect_tile(
-            self._infile, self._tile_shape, self._subtile_shape
-        )
-
-        self._source = rioxarray.open_rasterio(
-            self._infile, chunks={"band": 4, "x": 256, "y": 256}
-        )
-
-        # define padded indata array and place original data inside
-        sv = self._source.values
-        if self._tile_shape != self._tile_info.size:
-            self._indata = np.zeros((4, *self._tile_shape), dtype=self._source.dtype)
-            self._indata[:, 0 : sv.shape[1], 0 : sv.shape[2]] = sv
-        else:
-            self._indata = sv
-
-        # output xarray (single band)
-        self._target = (
-            self._source.sel(band=1, drop=True).astype("uint8").copy(deep=True)
-        )
-
-        # define padded outdata array
-        self._outdata = np.zeros(self._tile_shape, dtype="uint8")
-
-        # mark only necessary subtiles
-        subtiles_mask = np.zeros(
-            (
-                self._tile_shape[0] // self._subtile_shape[0],
-                self._tile_shape[1] // self._subtile_shape[1],
-            ),
-            dtype=bool,
-        )
-        subtiles_mask[
-            0 : self._tile_info.subtiles[0], 0 : self._tile_info.subtiles[1]
-        ] = 1
-        self._subtiles_to_use = subtiles_mask.ravel()
-
-    def write_file(self, outfile: Union[str, Path]) -> None:
-        if self._target is not None:
-            # copy data from outdata array into dataarray
-            self._target[:] = self._outdata[
-                0 : self._tile_info.size[0], 0 : self._tile_info.size[1]
-            ]
-            self._target.rio.to_raster(outfile, compress="LZW", tiled=True)
-
-    def get_batches(self) -> np.ndarray:
-        subtiles = make_blocks_vectorized(self._indata, self._subtile_shape[0])
-        self._batch_shape = self._batch_shape or subtiles.shape
-        return subtiles[self._subtiles_to_use]
-
-    def put_batches(self, batches: np.ndarray) -> None:
-        batches_expanded = []
-        batch_idx = 0
-        for flag in self._subtiles_to_use:
-            if flag == 1:
-                batches_expanded.append(batches[batch_idx])
-                batch_idx += 1
-            else:
-                batches_expanded.append(np.zeros(batches[0].shape))
-
-        batches_expanded = np.array(batches_expanded)
-
-        self._outdata = unmake_blocks_vectorized(
-            batches_expanded,
-            self._subtile_shape[0],
-            self._tile_shape[0],
-            self._tile_shape[1],
-        )
-
-        # pass data into geo-registered rioxarray object (only subset of expanded tile if not complete tile)
-        self._target = self._target.load()
-        self._target.loc[:] = self._outdata[
-            0 : self._tile_info.size[0], 0 : self._tile_info.size[1]
-        ]
diff --git a/dvc.lock b/dvc.lock
index 96cb897..0753a13 100644
--- a/dvc.lock
+++ b/dvc.lock
@@ -110,47 +110,61 @@ stages:
       size: 109902740888
       nfiles: 30489
   inference@2017:
-    cmd: mkdir -p data/predicted.2017; stdbuf -i0 -o0 -e0 python scripts/inference.py
-      --all --nopreview -o data/predicted.2017 data/processed.images.2017; gdal_merge.py -co
-      "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
+    cmd: mkdir -p data/predicted.2017; stdbuf -i0 -o0 -e0 python scripts/inference.py
+      --all --nopreview -o data/predicted.2017 data/processed.images.2017 -m checkpoints/earnest-dew-216_epoch_235.ckpt -m
+      checkpoints/fine-lake-207_epoch_279.ckpt -m checkpoints/sage-glitter-214_epoch_106.ckpt;
+      gdal_merge.py -co "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
       -o data/predicted_mosaic_2017.tif data/predicted.2017/ortho_ms_2017_EPSG3044_*
     deps:
-    - path: checkpoints/bestmodel.ckpt
-      md5: fa7e507a107381e5dd9c2ebc9ddef09f
-      size: 378649412
+    - path: checkpoints/earnest-dew-216_epoch_235.ckpt
+      md5: 9ee3273773fb824722c7438d6e6bb994
+      size: 378642436
+    - path: checkpoints/fine-lake-207_epoch_279.ckpt
+      md5: 68d739c35961b1d218f931f673d57b8b
+      size: 378644356
+    - path: checkpoints/sage-glitter-214_epoch_106.ckpt
+      md5: d69b30981b4f43333d9adc802e09c67d
+      size: 378642436
     - path: data/processed.images.2017
       md5: 66de96d201af5a47a09997691b992370.dir
       size: 109902740888
       nfiles: 30489
     outs:
     - path: data/predicted.2017
-      md5: 6991e5e7cf8144ff46445ef25214fce6.dir
-      size: 601515230
+      md5: d121ec6b99c1c20ad9c8d8c5a79f2717.dir
+      size: 590920267
       nfiles: 20827
     - path: data/predicted_mosaic_2017.tif
-      md5: 7cef75e791ea66524e8afdac1419e7e5
-      size: 839252195
+      md5: e8123ba2b1c70b76bf547cbd07c65cc4
+      size: 827137067
   inference@2019:
-    cmd: mkdir -p data/predicted.2019; stdbuf -i0 -o0 -e0 python scripts/inference.py
-      --all --nopreview -o data/predicted.2019 data/processed.images.2019; gdal_merge.py -co
-      "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
+    cmd: mkdir -p data/predicted.2019; stdbuf -i0 -o0 -e0 python scripts/inference.py
+      --all --nopreview -o data/predicted.2019 data/processed.images.2019 -m checkpoints/earnest-dew-216_epoch_235.ckpt -m
+      checkpoints/fine-lake-207_epoch_279.ckpt -m checkpoints/sage-glitter-214_epoch_106.ckpt;
+      gdal_merge.py -co "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
       -o data/predicted_mosaic_2019.tif data/predicted.2019/ortho_ms_2019_EPSG3044_*
     deps:
-    - path: checkpoints/bestmodel.ckpt
-      md5: fa7e507a107381e5dd9c2ebc9ddef09f
-      size: 378649412
+    - path: checkpoints/earnest-dew-216_epoch_235.ckpt
+      md5: 9ee3273773fb824722c7438d6e6bb994
+      size: 378642436
+    - path: checkpoints/fine-lake-207_epoch_279.ckpt
+      md5: 68d739c35961b1d218f931f673d57b8b
+      size: 378644356
+    - path: checkpoints/sage-glitter-214_epoch_106.ckpt
+      md5: d69b30981b4f43333d9adc802e09c67d
+      size: 378642436
     - path: data/processed.images.2019
       md5: 35c31b781cb0bdb19650329132d83d05.dir
       size: 144926802415
       nfiles: 30489
     outs:
     - path: data/predicted.2019
-      md5: 88d9fe4afe7054d15e5014b9743d87c9.dir
-      size: 484982007
+      md5: 59ded5bebfabb2fbdc933cc6a058c79c.dir
+      size: 472154703
       nfiles: 16227
     - path: data/predicted_mosaic_2019.tif
-      md5: fd472cc45d61c7223cbd09ab2872a30a
-      size: 811774816
+      md5: 66c45c751409f6916e6652ee4e028fce
+      size: 797539107
   computestats:
     cmd: 'python scripts/computestats.py --frac 0.1 data/processed.images.2017 data/processed.images.2018
      data/processed.images.2019 data/processed.images.2020 '
    deps:
@@ -229,47 +243,61 @@
       size: 163333629599
       nfiles: 30489
   inference@2018:
-    cmd: mkdir -p data/predicted.2018; stdbuf -i0 -o0 -e0 python scripts/inference.py
-      --all --nopreview -o data/predicted.2018 data/processed.images.2018; gdal_merge.py -co
-      "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
+    cmd: mkdir -p data/predicted.2018; stdbuf -i0 -o0 -e0 python scripts/inference.py
+      --all --nopreview -o data/predicted.2018 data/processed.images.2018 -m checkpoints/earnest-dew-216_epoch_235.ckpt -m
+      checkpoints/fine-lake-207_epoch_279.ckpt -m checkpoints/sage-glitter-214_epoch_106.ckpt;
+      gdal_merge.py -co "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
       -o data/predicted_mosaic_2018.tif data/predicted.2018/ortho_ms_2018_EPSG3044_*
     deps:
-    - path: checkpoints/bestmodel.ckpt
-      md5: fa7e507a107381e5dd9c2ebc9ddef09f
-      size: 378649412
+    - path: checkpoints/earnest-dew-216_epoch_235.ckpt
+      md5: 9ee3273773fb824722c7438d6e6bb994
+      size: 378642436
+    - path: checkpoints/fine-lake-207_epoch_279.ckpt
+      md5: 68d739c35961b1d218f931f673d57b8b
+      size: 378644356
+    - path: checkpoints/sage-glitter-214_epoch_106.ckpt
+      md5: d69b30981b4f43333d9adc802e09c67d
+      size: 378642436
     - path: data/processed.images.2018
       md5: cfa0adee6401f838f162a0510085becf.dir
       size: 144861203951
       nfiles: 30489
     outs:
     - path: data/predicted.2018
-      md5: 6dc4d6140a683a68243cd01d5d733001.dir
-      size: 551422735
+      md5: 5dea00cb9a4b6987d2dd0966abaf14cd.dir
+      size: 545737314
       nfiles: 19182
     - path: data/predicted_mosaic_2018.tif
-      md5: f8ae9346bf27514a30592e745ab87bc1
-      size: 830985617
+      md5: 165d623d21d5e47a4eaf8f2f5e45fd27
+      size: 824301626
   inference@2020:
-    cmd: mkdir -p data/predicted.2020; stdbuf -i0 -o0 -e0 python scripts/inference.py
-      --all --nopreview -o data/predicted.2020 data/processed.images.2020; gdal_merge.py -co
-      "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
+    cmd: mkdir -p data/predicted.2020; stdbuf -i0 -o0 -e0 python scripts/inference.py
+      --all --nopreview -o data/predicted.2020 data/processed.images.2020 -m checkpoints/earnest-dew-216_epoch_235.ckpt -m
+      checkpoints/fine-lake-207_epoch_279.ckpt -m checkpoints/sage-glitter-214_epoch_106.ckpt;
+      gdal_merge.py -co "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
       -o data/predicted_mosaic_2020.tif data/predicted.2020/ortho_ms_2020_EPSG3044_*
     deps:
-    - path: checkpoints/bestmodel.ckpt
-      md5: fa7e507a107381e5dd9c2ebc9ddef09f
-      size: 378649412
+    - path: checkpoints/earnest-dew-216_epoch_235.ckpt
+      md5: 9ee3273773fb824722c7438d6e6bb994
+      size: 378642436
+    - path: checkpoints/fine-lake-207_epoch_279.ckpt
+      md5: 68d739c35961b1d218f931f673d57b8b
+      size: 378644356
+    - path: checkpoints/sage-glitter-214_epoch_106.ckpt
+      md5: d69b30981b4f43333d9adc802e09c67d
+      size: 378642436
     - path: data/processed.images.2020
       md5: 76d4daf0d9e69904a9370a0c74006d17.dir
       size: 163333629599
       nfiles: 30489
     outs:
     - path: data/predicted.2020
-      md5: ca8bf8b4c611989f5c962b07dee85344.dir
-      size: 489635809
+      md5: ee17adb78816e89fc3f699927ca0ac0e.dir
+      size: 478573657
       nfiles: 16125
     - path: data/predicted_mosaic_2020.tif
-      md5: 5e3c82736e20844bb9b1fc6f4d6cbbb6
-      size: 812671767
+      md5: 49988d362425fe43c496ecc75d4611ef
+      size: 799740259
   createmasks@2017:
     cmd: python scripts/createmasks.py data/processed.images.2017 data/processed.masks.2017
       data/raw/shapefiles/deadtrees_2017/deadtrees_2017.shp
     deps:
diff --git a/dvc.yaml b/dvc.yaml
index a0a4df2..0f64672 100644
--- a/dvc.yaml
+++ b/dvc.yaml
@@ -124,6 +124,7 @@
   # train: do this manually
 
   # inference
+  # overlap values possible: 32px or 128px
   inference:
     foreach:
       - 2017
@@ -133,14 +134,20 @@
     do:
       cmd: >-
        mkdir -p data/predicted.${item};
-        stdbuf -i0 -o0 -e0 python scripts/inference.py --all --nopreview -o data/predicted.${item} data/processed.images.${item};
+        stdbuf -i0 -o0 -e0
+        python scripts/inference.py --all --overlap 32 -o data/predicted.${item} data/processed.images.${item}
+        -m checkpoints/earnest-dew-216_epoch_235.ckpt
+        -m checkpoints/fine-lake-207_epoch_279.ckpt
+        -m checkpoints/sage-glitter-214_epoch_106.ckpt;
        gdal_merge.py
        -co "TILED=YES" -co "COMPRESS=LZW" -co "PREDICTOR=2" -co "NUM_THREADS=ALL_CPUS"
        -o data/predicted_mosaic_${item}.tif data/predicted.${item}/ortho_ms_${item}_EPSG3044_*
      deps:
        - data/processed.images.${item}
-        - checkpoints/bestmodel.ckpt
+        - checkpoints/earnest-dew-216_epoch_235.ckpt
+        - checkpoints/fine-lake-207_epoch_279.ckpt
+        - checkpoints/sage-glitter-214_epoch_106.ckpt
      outs:
        - data/predicted.${item}
        - data/predicted_mosaic_${item}.tif
@@ -160,3 +167,23 @@
      - data/predicted.2020
    outs:
      - data/predicted.stats.csv
+
+  computeaggregatemaps:
+    cmd: >-
+      mkdir data/maps;
+      python scripts/aggregate_results.py
+      data/processed.lus.2017
+      data/processed.lus.2018
+      data/processed.lus.2019
+      data/processed.lus.2020
+    deps:
+      - data/predicted.2017
+      - data/predicted.2018
+      - data/predicted.2019
+      - data/predicted.2020
+      - data/processed.lus.2017
+      - data/processed.lus.2018
+      - data/processed.lus.2019
+      - data/processed.lus.2020
+    outs:
+      - data/maps
diff --git a/scripts/inference.py b/scripts/inference.py
index 17a7678..86808ac 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -3,12 +3,14 @@
 from pathlib import Path
 from typing import List
 
+from tiler import Merger, Tiler
+
 import numpy as np
 import rioxarray
 import torch
+import xarray as xr
 from deadtrees.data.deadtreedata import val_transform
 from deadtrees.deployment.inference import PyTorchEnsembleInference, PyTorchInference
-from deadtrees.deployment.tiler import Tiler
 from PIL import Image
 from tqdm import tqdm
@@ -35,6 +37,14 @@
         help="output directory",
     )
 
+    parser.add_argument(
+        "--overlap",
+        dest="overlap",
+        type=int,
+        default=32,
+        help="overlap of subtiles (256x256px tile: 32 or 128, def:32)",
+    )
+
     parser.add_argument(
         "--all",
         action="store_true",
@@ -43,19 +53,15 @@
         help="process complete directory",
     )
 
-    parser.add_argument(
-        "--nopreview",
-        action="store_false",
-        dest="preview",
-        default=True,
-        help="produce preview images",
-    )
-
     args = parser.parse_args()
 
     if len(args.model) == 0:
         args.model = [Path("checkpoints/bestmodel.ckpt")]
 
+    if args.overlap not in [32, 128]:
+        print("Currently only 32 or 128 allowed for subtile overlap")
+        exit(-1)
+
     bs = 64
 
     INFILE = args.infile
@@ -64,7 +70,6 @@ def is_valid_tile(infile):
         with rioxarray.open_rasterio(infile).sel(band=1) as t:
             return False if np.isin(t, [0, 255]).all() else True
 
-    # inference = ONNXInference("checkpoints/bestmodel.onnx")
     if len(args.model) == 1:
         print("Default inference: single model")
         inference = PyTorchInference(args.model[0])
@@ -72,6 +77,13 @@ def is_valid_tile(infile):
         print(f"Ensemble inference: {len(args.model)} models")
         inference = PyTorchEnsembleInference(*args.model)
 
+    n_channel = inference.channels
+    n_classes = inference.classes
+
+    print(
+        f"Inference using {n_channel}-channel model(s) and {n_classes} output classes"
+    )
+
     if args.all:
         INFILES = sorted(INFILE.glob("ortho*.tif"))
     else:
@@ -82,37 +94,75 @@ def is_valid_tile(infile):
         if not is_valid_tile(INFILE):
             continue
 
-        tiler = Tiler()
-        tiler.load_file(INFILE)
-
-        batches = tiler.get_batches()
-        batches = np.array_split(batches, math.ceil(len(batches) / bs), axis=0)
+        OUTFILE = args.outpath / INFILE.name
 
-        out_batches = []
+        # read geotiff
+        source: xr.DataArray = rioxarray.open_rasterio(
+            INFILE, chunks={"band": 4, "x": 256, "y": 256}
+        )
+        if n_channel > len(source.band):
+            print(
+                f"Source image {source.name} has wrong number of channels: image: {len(source.band)} model: {n_channel}"
+            )
+            continue
 
-        for b, batch in enumerate(tqdm(batches, desc=INFILE.name)):
+        # prepare output geotiff
+        target: xr.DataArray = (
+            source.sel(band=1, drop=True).astype("uint8").copy(deep=True).load()
+        )
+
+        in_tiler = Tiler(
+            data_shape=source.values.shape,
+            tile_shape=(n_channel, 256, 256),
+            overlap=args.overlap,
+            channel_dimension=0,
+        )
+        out_tiler = Tiler(
+            data_shape=source.values.shape,
+            tile_shape=(n_classes, 256, 256),
+            overlap=args.overlap,
+            channel_dimension=0,
+        )
+
+        # one merger for each model
+        out_merger = [Merger(out_tiler) for _ in range(len(args.model))]
+
+        # make sure n_channels of model match data size (use RGB aka 0:3 if model requires)
+        batches = [
+            batch
+            for _, batch in in_tiler(source.values[0:n_channel, ...], batch_size=bs)
+        ]
+
+        for batch_id, batch in enumerate(tqdm(batches, desc=INFILE.name)):
             batch_tensor = torch.stack(
                 [val_transform(image=i.transpose(1, 2, 0))["image"] for i in batch]
             )
 
-            # pytorch
             out_batch = (
-                inference.run(batch_tensor.detach().to("cuda"), device="cuda")
+                inference.run(
+                    batch_tensor.detach().to("cuda"), device="cuda", return_raw=True
+                )
                 .cpu()
                 .numpy()
             )
 
-            out_batches.append(out_batch)
-
-        OUTFILE = args.outpath / INFILE.name
-        OUTFILE_PREVIEW = Path(str(args.outpath) + "_preview") / INFILE.name
-
-        tiler.put_batches(np.concatenate(out_batches, axis=0))
-        tiler.write_file(OUTFILE)
-
-        if args.preview:
-            image = Image.fromarray(np.uint8(tiler._target.values * 255), "L")
-            image.save(OUTFILE_PREVIEW)
+            # dims: model, bs, c, h, w
+            if isinstance(inference, PyTorchEnsembleInference):
+                for i in range(len(args.model)):
+                    out_merger[i].add_batch(batch_id, bs, out_batch[i])
+            else:
+                out_merger[0].add_batch(batch_id, bs, out_batch)
+
+        # this is still based on logits since we used return_raw in inference.run() !
+        # 1) use merger to recreate the full tile
+        # 2) argmax over logits to find the dominant class in each pixel
+        # 3) take the mode over all models to derive the final px class value (this is done via torch)
+        output_per_model = np.array(
+            [np.argmax(m.merge(unpad=True), axis=0) for m in out_merger]
+        )
+        output = torch.mode(torch.Tensor(output_per_model), axis=0).values.numpy()
+
+        target[:] = output
+        target.rio.to_raster(OUTFILE, compress="LZW", tiled=True)
 
 
 if __name__ == "__main__":
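The rewritten inference script above delegates (sub)tiling to the external tiler package, pinned in setup.py below. A minimal round-trip sketch of the overlap-tiling flow as it is used here, assuming the pinned fork keeps the upstream Tiler/Merger signatures shown in the hunk above (array shape and batch size are arbitrary):

    import numpy as np
    from tiler import Merger, Tiler

    data = np.random.rand(4, 1024, 1024).astype("float32")  # (bands, h, w)

    tiler = Tiler(
        data_shape=data.shape,
        tile_shape=(4, 256, 256),
        overlap=32,
        channel_dimension=0,
    )
    merger = Merger(tiler)

    bs = 8
    for batch_id, batch in tiler(data, batch_size=bs):
        # a model would transform each (4, 256, 256) tile here; we pass it through unchanged
        merger.add_batch(batch_id, bs, batch)

    restored = merger.merge(unpad=True)
    assert restored.shape == data.shape  # overlapping regions are blended back together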
diff --git a/setup.py b/setup.py
index ad20e45..3738265 100644
--- a/setup.py
+++ b/setup.py
@@ -37,11 +37,13 @@
         "python-dotenv",
         "hydra-core>=1.1.0",
         "hydra-colorlog>=1.1.0",
+        "numpy<=1.19",
         "pydantic",
         "torch>=1.10.0",
         "torchvision>=0.12.0",
         "pytorch-lightning>=1.5",
         "rich",
+        "tiler @ git+https://github.com/cwerner/tiler.git#egg=tiler-0.5.7",
         "tqdm",
         "webdataset==0.1.62",
         "segmentation_models_pytorch>=0.2.1",
diff --git a/tests/test_tiler.py b/tests/test_tiler.py
deleted file mode 100644
index c2ee381..0000000
--- a/tests/test_tiler.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import tempfile
-from math import prod
-from pathlib import Path
-from typing import Tuple, Union
-
-import pytest
-from attr import dataclass
-
-import numpy as np
-import rioxarray
-from deadtrees.deployment.tiler import divisible_without_remainder, inspect_tile, Tiler
-from deadtrees.utils.data_handling import (
-    make_blocks_vectorized,
-    unmake_blocks_vectorized,
-)
-
-
-@pytest.fixture
-def tiler():
-    return Tiler()
-
-
-@dataclass
-class TileData:
-    filename: Union[str, Path]
-    size: Tuple[int, int]
-    subtiles: Tuple[int, int]
-
-
-example1 = TileData(
-    Path("tests/testdata/tiles/ortho_2019_ESPG3044_49_11.tif"),
-    (8192, 8192),
-    (16, 16),
-)
-
-example2 = TileData(
-    Path("tests/testdata/tiles/ortho_2019_ESPG3044_27_37.tif"),
-    (8192, 7433),
-    (16, 15),
-)
-
-example3 = TileData(
-    Path("tests/testdata/tiles/ortho_2019_ESPG3044_52_26.tif"),
-    (2649, 8192),
-    (6, 16),
-)
-
-examples = [example1, example2, example3]
-
-
-@pytest.mark.parametrize("a,b,result", [(10, 2, True), (5, 4, False), (2, 0, False)])
-def test_divisible_without_remainder(a, b, result):
-    assert divisible_without_remainder(a, b) == result
-
-
-class TestBlocksVectorized:
-    source = np.array([np.arange(16).reshape(4, 4)] * 3)
-    target = np.array(
-        [
-            [[[0, 1], [4, 5]], [[0, 1], [4, 5]], [[0, 1], [4, 5]]],
-            [[[2, 3], [6, 7]], [[2, 3], [6, 7]], [[2, 3], [6, 7]]],
-            [[[8, 9], [12, 13]], [[8, 9], [12, 13]], [[8, 9], [12, 13]]],
-            [[[10, 11], [14, 15]], [[10, 11], [14, 15]], [[10, 11], [14, 15]]],
-        ]
-    )
-
-    def test_make_blocks_vectorized(self):
-        """break tile into batch of subtiles"""
-        np.testing.assert_array_equal(
-            make_blocks_vectorized(self.source, 2), self.target
-        )
-
-    def test_unmake_blocks_vectorized(self):
-        """place batches back into tile (2d)"""
-        np.testing.assert_array_equal(
-            unmake_blocks_vectorized(self.target[:, 0, :, :], 2, 4, 4), self.source[0]
-        )
-
-
-@pytest.mark.parametrize("tile", examples)
-def test_tiler_inspect_tile_size(tile):
-    assert inspect_tile(tile.filename).size == tile.size
-
-
-@pytest.mark.parametrize("tile", examples)
-def test_tiler_inspect_tile_subtiles(tile):
-    assert inspect_tile(tile.filename).subtiles == tile.subtiles
-
-
-@pytest.mark.parametrize("tile", examples[0:1])
-def test_tiler_inspect_tile_subtile_not_divisible(tile):
-    with pytest.raises(ValueError):
-        inspect_tile(tile.filename, subtile_shape=(512, 211))
-
-
-@pytest.mark.parametrize(
-    "tile",
-    [
-        str(example1.filename),
-        example1.filename,
-        rioxarray.open_rasterio(example1.filename).sel(band=1, drop=True),
-    ],
-)
-def test_tiler_infile_types(tile):
-    assert inspect_tile(tile).size == example1.size
-
-
-def test_tiler_catch_bad_subtile_dims():
-    with pytest.raises(ValueError):
-        Tiler(example1.filename, tile_shape=(8192, 8192), subtile_shape=(256, 250))
-
-
-@pytest.mark.parametrize("tile", examples)
-def test_tiler_load_file_subtiles_to_use(tiler, tile):
-    tiler.load_file(tile.filename)
-    assert sum(tiler._subtiles_to_use) == prod(tile.subtiles)
-
-
-@pytest.mark.parametrize("tile", examples)
-def test_tiler_get_batches(tiler, tile):
-    tiler.load_file(tile.filename)
-    assert tiler.get_batches().shape == (prod(tile.subtiles), 3, 512, 512)
-
-
-@pytest.mark.parametrize("tile", examples)
-def test_tiler_put_batches(tiler, tile):
-    tiler.load_file(tile.filename)
-    batches = tiler.get_batches()
-    pred_batches = np.random.choice(
-        a=[1, 0], size=(len(batches), 512, 512), p=[0.1, 0.9]
-    )  # single layer
-    tiler.put_batches(pred_batches)
-    assert tiler._outdata.shape == (8192, 8192)
-
-
-@pytest.mark.parametrize("tile", examples)
-def test_tiler_write_file(tiler, tile):
-    tiler.load_file(tile.filename)
-    batches = tiler.get_batches()
-    pred_batches = np.random.choice(
-        a=[1, 0], size=(len(batches), 512, 512), p=[0.1, 0.9]
-    )  # single layer
-    tiler.put_batches(pred_batches)
-
-    with tempfile.NamedTemporaryFile(suffix=".tif") as tmp:
-        tiler.write_file(tmp.name)
-
-        assert rioxarray.open_rasterio(tmp.name).values.shape == (1, *tile.size)
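The deleted tests above are not replaced in this diff. One piece of coverage that could be kept cheaply is the odd-model-count guard in PyTorchEnsembleInference, which (per the hunk above) raises before any checkpoint is opened; a possible sketch, offered only as a suggestion:

    import pytest

    from deadtrees.deployment.inference import PyTorchEnsembleInference


    def test_ensemble_requires_uneven_model_count():
        # the guard fires before the checkpoints are loaded, so dummy paths suffice
        with pytest.raises(ValueError):
            PyTorchEnsembleInference("a.ckpt", "b.ckpt")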