Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added deepforest-evaluation.datagrid
Binary file not shown.
Binary file added deepforest-predictions.datagrid
Binary file not shown.
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pydata-sphinx-theme
geopandas
huggingface_hub>=0.25.0
h5py
kangas
matplotlib
nbmake
nbsphinx
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- geopandas
- huggingface_hub>=0.25.0
- h5py
- kangas
- matplotlib
- nbmake
- nbsphinx
Expand Down
Binary file added evaluation-metrics.datagrid
Binary file not shown.
259 changes: 246 additions & 13 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import typing
import warnings

import kangas as kg
import numpy as np
import pandas as pd
import pytorch_lightning as pl
Expand Down Expand Up @@ -35,6 +35,10 @@ class deepforest(pl.LightningModule, PyTorchModelHubMixin):
existing_train_dataloader: a Pytorch dataloader that yields a tuple path, images, targets
existing_val_dataloader: a Pytorch dataloader that yields a tuple path, images, targets

Notes:
Kangas visualization is supported in predict_* and evaluate methods with visualize_with="kangas".
Install Kangas with `pip install kangas` to enable this optional feature.

Returns:
self: a deepforest pytorch lightning module
"""
Expand Down Expand Up @@ -126,6 +130,134 @@ def __init__(self,

self.save_hyperparameters()

def visualize_evaluation_kangas(
        self,
        predictions: pd.DataFrame,
        ground_df: pd.DataFrame,
        root_dir: str,
        evaluation_results: typing.Optional[dict] = None) -> None:
    """Use Kangas to visualize predictions and ground truth data from
    evaluation.

    One DataGrid row is produced per bounding box. Ground-truth labels are
    prefixed with "GT_" and given a score of 1.0; prediction labels are
    prefixed with "Pred_" and keep their own score. When numeric metrics are
    supplied they are shown in a second grid.

    Args:
        predictions: Predictions DataFrame with columns "image_path", "xmin", "ymin", "xmax", "ymax", "label", "score".
            May be None or empty, in which case only ground truth is shown.
        ground_df: Ground truth DataFrame with columns "image_path", "xmin", "ymin", "xmax", "ymax", "label".
            May be None or empty.
        root_dir: Directory where images are stored. If falsy, image paths are used as given.
        evaluation_results: Optional dictionary of metrics (e.g., precision, recall) to display alongside.
            Only real-number values are displayed; other values are skipped.

    Returns:
        None
    """
    import numbers

    def _box_rows(frame, prefix, fixed_score=None):
        # Convert every annotation row into one Kangas row holding one box.
        if frame is None or frame.empty:
            return []
        rows = []
        for record in frame.to_dict(orient="records"):
            name = record["image_path"]
            rows.append({
                "image_path": os.path.join(root_dir, name) if root_dir else name,
                "bounding_boxes": [{
                    "xmin": record["xmin"],
                    "ymin": record["ymin"],
                    "xmax": record["xmax"],
                    "ymax": record["ymax"],
                    "label": f"{prefix}_{record['label']}",
                    "score": record["score"] if fixed_score is None else fixed_score,
                }],
            })
        return rows

    # Ground truth first, then predictions, matching the original row order.
    kangas_data = _box_rows(ground_df, "GT", fixed_score=1.0)
    kangas_data.extend(_box_rows(predictions, "Pred"))

    # numbers.Real also accepts NumPy scalar metrics (e.g. np.float32), which
    # a plain (int, float) check would silently drop.
    metrics_data = []
    if isinstance(evaluation_results, dict):
        metrics_data = [{
            "Metric": key,
            "Value": value
        } for key, value in evaluation_results.items()
                        if isinstance(value, numbers.Real)]

    # Kangas is optional: import it lazily so a missing install produces a
    # hint here rather than an unreachable `kg is None` check.
    try:
        import kangas as kg
    except ImportError:
        print(
            "Kangas is not installed. Run 'pip install kangas' to enable visualization."
        )
        return

    try:
        grid = kg.DataGrid(kangas_data, name="DeepForest Evaluation")
        grid.show()
    except Exception as e:
        # Visualization is best-effort; never break evaluation because of it.
        print(f"Kangas evaluation failed :{e}")

    if metrics_data:
        # Guard the metrics grid as well so both grids are equally best-effort.
        try:
            metrics_grid = kg.DataGrid(metrics_data, name="Evaluation Metrics")
            metrics_grid.show()
        except Exception as e:
            print(f"Kangas metrics visualization failed: {e}")

def visualize_kangas(self,
                     predictions: typing.Union[pd.DataFrame,
                                               typing.List[pd.DataFrame]],
                     image_paths: typing.Optional[typing.List[str]] = None) -> None:
    """Visualize predictions using Kangas.

    Predictions are grouped by image so that each DataGrid row holds one
    image path and the list of its bounding boxes. The caller's DataFrames
    are never modified.

    Args:
        predictions: DataFrame or list of DataFrames with "image_path", "xmin", "ymin", "xmax", "ymax", "label", "score"
        image_paths: Optional list of image paths if not included in predictions.
            Must contain one path per DataFrame (a single path for a single DataFrame).

    Returns:
        None

    Raises:
        ValueError: If predictions is not a DataFrame or list of DataFrames,
            or if image paths are needed but missing or of the wrong length.
    """
    # Handle different prediction formats
    if isinstance(predictions, pd.DataFrame):
        frames = [predictions]
    elif isinstance(predictions, list) and all(
            isinstance(p, pd.DataFrame) for p in predictions):
        frames = predictions
    else:
        raise ValueError("Predictions must be a DataFrame or list of DataFrames")

    if not frames:
        # pd.concat([]) would raise; an empty list simply has nothing to show.
        print("No predictions to visualize with Kangas.")
        return

    # Ensure image paths are available
    attach_paths = all("image_path" not in frame.columns for frame in frames)
    if attach_paths:
        if not image_paths:
            raise ValueError("image_paths must be provided if not in predictions")
        if len(image_paths) != len(frames):
            raise ValueError("Length of image_paths must match predictions")

    def _resolve(path, root):
        # Only relative string paths are anchored at root.
        if isinstance(path, str) and not os.path.isabs(path):
            return os.path.join(root, path)
        return path

    prepared = []
    for position, frame in enumerate(frames):
        # Read root_dir before .assign(), which returns a new frame that does
        # not carry ad-hoc attributes set on the original.
        root = getattr(frame, "root_dir", None)
        if attach_paths:
            # Tag each frame with its own path *before* concatenating, so
            # frames with several boxes get the path repeated on every row.
            frame = frame.assign(image_path=image_paths[position])
        if isinstance(root, str) and root and "image_path" in frame.columns:
            frame = frame.assign(
                image_path=[_resolve(p, root) for p in frame["image_path"]])
        prepared.append(frame)

    df = pd.concat(prepared, ignore_index=True)

    # Group predictions by image_path
    box_columns = ["xmin", "ymin", "xmax", "ymax", "label", "score"]
    kangas_data = [{
        "image_path": image_path,
        "bounding_boxes": group[box_columns].to_dict(orient="records")
    } for image_path, group in df.groupby("image_path")]

    # Kangas is optional: import it lazily so a missing install produces a
    # hint here rather than an unreachable `kg is None` check.
    try:
        import kangas as kg
    except ImportError:
        print(
            "Kangas is not installed. Run 'pip install kangas' to enable visualization."
        )
        return

    try:
        grid = kg.DataGrid(kangas_data, name="DeepForest Predictions")
        grid.show()
    except Exception as e:
        # Visualization is best-effort; never break prediction because of it.
        print(f"Kangas visualization failed: {e}")

def load_model(self, model_name="weecology/deepforest-tree", revision='main'):
"""Loads a model that has already been pretrained for a specific task,
like tree crown detection.
Expand Down Expand Up @@ -336,7 +468,10 @@ def val_dataloader(self):
batch_size=self.config["batch_size"])
return loader

def predict_dataloader(self, ds):
def predict_dataloader(
self,
ds: torch.utils.data.Dataset,
visualize_with: typing.Optional[str] = None) -> torch.utils.data.DataLoader:
"""Create a PyTorch dataloader for prediction.

Args:
Expand All @@ -350,14 +485,31 @@ def predict_dataloader(self, ds):
shuffle=False,
num_workers=self.config["workers"])

if visualize_with == "kangas":
self.model.eval() # Ensure evaluation mode
predictions = []
for batch_idx, batch in enumerate(loader):
if isinstance(batch, (tuple, list)) and len(batch) > 1:
images, _, paths = batch
else:
images = batch if isinstance(batch, (tuple, list)) else batch
paths = [ds.paths[i] for i in range(batch_idx * self.config["batch_size"],
min((batch_idx + 1) * self.config["batch_size"], len(ds.paths)))]
batch_predictions = self.predict_step(images, batch_idx)
for pred, path in zip(batch_predictions, paths):
pred['image_path'] = path
predictions.append(pred)
if predictions:
self.visualize_kangas(predictions)
return loader

def predict_image(self,
image: typing.Optional[np.ndarray] = None,
path: typing.Optional[str] = None,
return_plot: bool = False,
color: typing.Optional[tuple] = (0, 165, 255),
thickness: int = 1):
thickness: int = 1,
visualize_with: typing.Optional[str] = None):
"""Predict a single image with a deepforest model.

Deprecation warning: The 'return_plot', and related 'color' and 'thickness' arguments
Expand Down Expand Up @@ -428,10 +580,19 @@ def predict_image(self,
else:
root_dir = os.path.dirname(path)
result = utilities.read_file(result, root_dir=root_dir)
# Visualize if requested
if visualize_with == "kangas":
self.visualize_kangas(result)

return result

def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1):
def predict_file(self,
csv_file,
root_dir,
savedir=None,
color=None,
thickness=1,
visualize_with=None):
"""Create a dataset and predict entire annotation file Csv file format
is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax"
for the image name and bounding box position. Image_path is the
Expand Down Expand Up @@ -469,6 +630,9 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
thickness=thickness)

results.root_dir = root_dir
# Visualize if requested
if visualize_with == "kangas":
self.visualize_kangas(results)

return results

Expand All @@ -487,7 +651,8 @@ def predict_tile(self,
thickness=1,
crop_model=None,
crop_transform=None,
crop_augment=False):
crop_augment=False,
visualize_with=None):
"""For images too large to input into the model, predict_tile cuts the
image into overlapping windows, predicts trees on each window and
reassembles into a single array.
Expand Down Expand Up @@ -625,6 +790,11 @@ def predict_tile(self,
root_dir = os.path.dirname(raster_path)
results = utilities.read_file(results, root_dir=root_dir)

if visualize_with == "kangas" and mosaic:
self.visualize_kangas(results)
elif visualize_with == "kangas":
print("Kangas visualization only supported with mosaic=True.")

return results

def training_step(self, batch, batch_idx):
Expand Down Expand Up @@ -819,6 +989,14 @@ def on_validation_epoch_end(self):
if empty_accuracy is not None:
results["empty_frame_accuracy"] = empty_accuracy

# Check config for Kangas visualization
if self.config.get("visualize_with") == "kangas":
self.visualize_evaluation_kangas(
predictions=self.predictions_df,
ground_df=ground_df,
root_dir=self.config["validation"]["root_dir"],
evaluation_results=results)

# Log each key value pair of the results dict
if not results["class_recall"] is None:
for key, value in results.items():
Expand All @@ -844,16 +1022,45 @@ def on_validation_epoch_end(self):
except MisconfigurationException:
pass

def predict_step(self, batch, batch_idx):
batch_results = self.model(batch)
def predict_step(
self,
batch: typing.Any,
batch_idx: int,
visualize_with: typing.Optional[str] = None,
image_paths: typing.Optional[typing.List[str]] = None
) -> typing.List[pd.DataFrame]:

self.model.eval() # Ensure evaluation mode
with torch.no_grad():
batch_results = self.model(batch) # Returns list of dicts: [{"boxes": tensor, "labels": tensor, "scores": tensor}, ...]

results = []
for result in batch_results:
boxes = visualize.format_boxes(result)
results.append(boxes)
# Ensure result is a dict from RetinaNet
if isinstance(result, dict):
boxes = visualize.format_boxes(result)
results.append(boxes)
else:
# Handle unexpected format (e.g., empty or malformed output)
empty_df = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax", "label", "score"])
results.append(empty_df)

if visualize_with == "kangas" and image_paths:
if len(image_paths) != len(results):
raise ValueError(f"Length of image_paths ({len(image_paths)}) must match predictions ({len(results)})")
for pred, path in zip(results, image_paths):
pred['image_path'] = path
self.visualize_kangas(results)

return results

def predict_batch(self, images, preprocess_fn=None):
def predict_batch(
self,
images: typing.Union[torch.Tensor, np.ndarray],
preprocess_fn: typing.Optional[typing.Callable] = None,
visualize_with: typing.Optional[str] = None,
image_paths: typing.Optional[typing.List[str]] = None
) -> typing.List[pd.DataFrame]:
"""Predict a batch of images with the deepforest model.

Args:
Expand Down Expand Up @@ -884,9 +1091,23 @@ def predict_batch(self, images, preprocess_fn=None):
with torch.no_grad():
predictions = self.predict_step(images, 0)

#convert predictions to dataframes
results = [utilities.read_file(pred) for pred in predictions if pred is not None]
# Handle predictions, including empty ones
results = []
for i, pred in enumerate(predictions):
if pred is None or pred.empty:
# Create an empty DataFrame with expected columns
empty_df = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax", "label", "score", "image_path"])
if image_paths and i < len(image_paths):
empty_df["image_path"] = [image_paths[i]]
results.append(empty_df)
else:
pred_df = utilities.read_file(pred)
if image_paths and i < len(image_paths):
pred_df["image_path"] = image_paths[i]
results.append(pred_df)

if visualize_with == "kangas":
self.visualize_kangas(results)
return results

def configure_optimizers(self):
Expand Down Expand Up @@ -945,7 +1166,12 @@ def configure_optimizers(self):
else:
return optimizer

def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None):
def evaluate(self,
csv_file,
root_dir,
iou_threshold=None,
savedir=None,
visualize_with=None):
"""Compute intersection-over-union and precision/recall for a given
iou_threshold.

Expand All @@ -971,5 +1197,12 @@ def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None):
root_dir=root_dir,
iou_threshold=iou_threshold,
numeric_to_label_dict=self.numeric_to_label_dict)
# If user wants Kangas visualization
if visualize_with == "kangas":
self.visualize_evaluation_kangas(predictions=predictions,
ground_df=ground_df,
root_dir=root_dir,
evaluation_results=results)

return results

Loading
Loading