Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 54 additions & 14 deletions deadtrees/deployment/inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import io
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union
from typing import Iterable, Union

import numpy as np
import torch
Expand All @@ -12,15 +12,18 @@


class Inference(ABC):
def __init__(self, model_file: Union[str, Path]) -> None:
self._model_file = (
model_file if isinstance(model_file, Path) else Path(model_file)
)
super().__init__()
# def __init__(self, model_file: Union[str, Path]) -> None:
# self._model_file = (
# model_file if isinstance(model_file, Path) else Path(model_file)
# )
# super().__init__()

@property
def model_file(self) -> str:
return self._model_file.name
if hasattr(self, "_models"):
return ",".join([m.name for m in self._models])
else:
return self._model.name

@abstractmethod
def run(self, input_tensor: torch.Tensor):
Expand All @@ -29,7 +32,11 @@ def run(self, input_tensor: torch.Tensor):

class PyTorchInference(Inference):
def __init__(self, model_file) -> None:
super().__init__(model_file)
# super().__init__(model_file)

self._model_file = (
model_file if isinstance(model_file, Path) else Path(model_file)
)

if self._model_file.suffix != ".ckpt":
raise ValueError(
Expand All @@ -44,7 +51,17 @@ def __init__(self, model_file) -> None:
# TODO: this is ugly, rename or restructure
self._model = model.model

def run(self, input_tensor, device: str = "cpu"):
@property
def channels(self) -> int:
return self._channels

@property
def classes(self) -> int:
return self._model.classes

def run(self, input_tensor, device: str = "cpu", return_raw: bool = False):
"""run the model, return either the raw logits of all models or the mode"""

if not isinstance(input_tensor, torch.Tensor):
raise TypeError("no pytorch tensor provided")

Expand All @@ -59,14 +76,22 @@ def run(self, input_tensor, device: str = "cpu"):
input_tensor = input_tensor[:, 0:3, :, :]
out = self._model(input_tensor)

return out.argmax(dim=1).squeeze()
if return_raw:
return out
else:
return out.argmax(dim=1).squeeze()


class PyTorchEnsembleInference:
def __init__(self, *model_files: Path):
self._models = []
self._channels = None

self._model_files = [
model_file if isinstance(model_file, Path) else Path(model_file)
for model_file in model_files
]

if len(model_files) % 2 == 0:
raise ValueError(
"PyTorchEnsembleInference requires an uneven number of models"
Expand All @@ -93,7 +118,16 @@ def __init__(self, *model_files: Path):
# TODO: this is ugly, rename or restructure
self._models.append(model.model)

def run(self, input_tensor, device: str = "cpu"):
@property
def channels(self) -> int:
return self._channels

@property
def classes(self) -> int:
return self._models[0].classes

def run(self, input_tensor, device: str = "cpu", return_raw: bool = False):
"""run the model(s), return either the raw logits of all models or the mode"""
if not isinstance(input_tensor, torch.Tensor):
raise TypeError("No PyTorch tensor provided")

Expand All @@ -107,15 +141,21 @@ def run(self, input_tensor, device: str = "cpu"):
outs = []
for model in self._models:
model.to(device)

with torch.no_grad():
out = model(input_tensor)

outs.append(out.argmax(dim=1).squeeze())
outs.append(out)

return torch.mode(torch.stack(outs, dim=1), axis=1)[0]
if return_raw:
# dims: m, bs, c, h, w
return torch.stack(outs, dim=0)
else:
# dims: bs, h, w
model_results = [out.argmax(dim=1).squeeze() for out in outs]
return torch.mode(torch.stack(model_results, dim=1), axis=1).values


# deprecated, do not use
class ONNXInference(Inference):
def __init__(self, model_file) -> None:
super().__init__(model_file)
Expand Down
170 changes: 0 additions & 170 deletions deadtrees/deployment/tiler.py

This file was deleted.

Loading