diff --git a/deepforest-evaluation.datagrid b/deepforest-evaluation.datagrid new file mode 100644 index 000000000..d645d2e86 Binary files /dev/null and b/deepforest-evaluation.datagrid differ diff --git a/deepforest-predictions.datagrid b/deepforest-predictions.datagrid new file mode 100644 index 000000000..8bf952cd4 Binary files /dev/null and b/deepforest-predictions.datagrid differ diff --git a/dev_requirements.txt b/dev_requirements.txt index a19c1a94c..5d890eacb 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -8,6 +8,7 @@ pydata-sphinx-theme geopandas huggingface_hub>=0.25.0 h5py +kangas matplotlib nbmake nbsphinx diff --git a/environment.yml b/environment.yml index 37d14329f..3ba260ded 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - geopandas - huggingface_hub>=0.25.0 - h5py + - kangas - matplotlib - nbmake - nbsphinx diff --git a/evaluation-metrics.datagrid b/evaluation-metrics.datagrid new file mode 100644 index 000000000..6d532261d Binary files /dev/null and b/evaluation-metrics.datagrid differ diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 1db109d76..1c303aec1 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -3,7 +3,7 @@ import os import typing import warnings - +import kangas as kg import numpy as np import pandas as pd import pytorch_lightning as pl @@ -35,6 +35,10 @@ class deepforest(pl.LightningModule, PyTorchModelHubMixin): existing_train_dataloader: a Pytorch dataloader that yields a tuple path, images, targets existing_val_dataloader: a Pytorch dataloader that yields a tuple path, images, targets + Notes: + Kangas visualization is supported in predict_* and evaluate methods with visualize_with="kangas". + Install Kangas with `pip install kangas` to enable this optional feature. 
def visualize_evaluation_kangas(
        self,
        predictions: pd.DataFrame,
        ground_df: pd.DataFrame,
        root_dir: str,
        evaluation_results: typing.Optional[dict] = None) -> None:
    """Show ground truth and predicted boxes from an evaluation in Kangas.

    One grid row is built per image; its "bounding_boxes" entry lists the
    ground truth boxes (labels prefixed with "GT_", score fixed at 1.0)
    followed by the predicted boxes (labels prefixed with "Pred_"). Numeric
    entries of ``evaluation_results`` are shown in a second grid.

    Args:
        predictions: Predictions DataFrame with columns "image_path", "xmin", "ymin", "xmax", "ymax", "label", "score". None or an empty frame contributes no boxes.
        ground_df: Ground truth DataFrame with columns "image_path", "xmin", "ymin", "xmax", "ymax", "label". None or an empty frame contributes no boxes.
        root_dir: Directory that is joined in front of every "image_path".
        evaluation_results: Optional dictionary of metrics (e.g., precision, recall); only int/float values are displayed.

    Returns:
        None
    """
    # Resolve the module-level alias at call time so it may be missing or None.
    # NOTE(review): the guard only helps if the module import of kangas is
    # itself optional - confirm how `kg` is bound at the top of the file.
    kangas = globals().get("kg")
    if kangas is None:
        print(
            "Kangas is not installed. Run 'pip install kangas' to enable visualization."
        )
        return

    coordinate_columns = ["xmin", "ymin", "xmax", "ymax"]
    layers = []
    for frame, prefix in ((ground_df, "GT"), (predictions, "Pred")):
        if frame is None or frame.empty:
            continue
        # Work on a copy so the caller's frames are never modified.
        boxes = frame[coordinate_columns].copy()
        boxes["label"] = [f"{prefix}_{label}" for label in frame["label"]]
        # Ground truth has no confidence; it is reported as certain.
        boxes["score"] = frame["score"] if prefix == "Pred" else 1.0
        boxes["image_path"] = [
            os.path.join(root_dir, name) for name in frame["image_path"]
        ]
        layers.append(boxes)

    kangas_data = []
    if layers:
        combined = pd.concat(layers, ignore_index=True)
        # sort=False keeps images in first-seen order; rows inside a group keep
        # the concat order, i.e. ground truth before predictions.
        kangas_data = [{
            "image_path": image_path,
            "bounding_boxes": group.drop(columns="image_path").to_dict(orient="records")
        } for image_path, group in combined.groupby("image_path", sort=False)]

    # NOTE(review): rows hold plain path strings and box dictionaries; Kangas
    # may need its own image asset type to draw overlays - confirm in its docs.
    # Visualization is best effort: a failure is reported, never raised.
    try:
        kangas.DataGrid(kangas_data, name="DeepForest Evaluation").show()
    except Exception as e:
        print(f"Kangas evaluation failed :{e}")

    if not evaluation_results:
        return

    metrics_data = [{
        "Metric": key,
        "Value": value
    } for key, value in evaluation_results.items() if isinstance(value, (int, float))]
    if not metrics_data:
        return

    # Same best-effort policy as the box grid above.
    try:
        kangas.DataGrid(metrics_data, name="Evaluation Metrics").show()
    except Exception as e:
        print(f"Kangas evaluation failed :{e}")


def visualize_kangas(self,
                     predictions: typing.Union[pd.DataFrame,
                                               typing.List[pd.DataFrame]],
                     image_paths: typing.Optional[typing.List[str]] = None) -> None:
    """Visualize predictions using Kangas, one grid row per image.

    The input frames are copied before any column is added or rewritten, so
    the caller's DataFrames are left untouched.

    Args:
        predictions: DataFrame or list of DataFrames with "xmin", "ymin", "xmax", "ymax", "label", "score" and optionally "image_path". A frame may carry a string ``root_dir`` attribute, which is joined in front of its image paths. None or an empty list shows nothing.
        image_paths: One path per prediction frame; required when any frame has no "image_path" column.

    Returns:
        None

    Raises:
        ValueError: If predictions is neither a DataFrame nor a list of DataFrames, or if image_paths is needed but missing or of the wrong length.
    """
    # Resolve the module-level alias at call time so it may be missing or None.
    kangas = globals().get("kg")
    if kangas is None:
        print(
            "Kangas is not installed. Run 'pip install kangas' to enable visualization."
        )
        return

    if predictions is None:
        # Nothing was predicted, so there is nothing to draw.
        return

    # Handle different prediction formats
    if isinstance(predictions, pd.DataFrame):
        frames = [predictions]
    elif isinstance(predictions, list) and all(
            isinstance(p, pd.DataFrame) for p in predictions):
        frames = predictions
    else:
        raise ValueError("Predictions must be a DataFrame or list of DataFrames")

    if not frames:
        return

    # Ensure image paths are available
    if any("image_path" not in frame.columns for frame in frames):
        if not image_paths:
            raise ValueError("image_paths must be provided if not in predictions")
        if len(image_paths) != len(frames):
            raise ValueError("Length of image_paths must match predictions")

    prepared = []
    for index, frame in enumerate(frames):
        # Read the attribute before copying: copy() does not carry it over.
        # A column named "root_dir" would come back as a Series, hence the
        # explicit string check.
        root_dir = getattr(frame, "root_dir", None)
        frame = frame.copy()
        if "image_path" not in frame.columns:
            # A scalar broadcasts to every box of this image.
            frame["image_path"] = image_paths[index]
        if isinstance(root_dir, str) and root_dir:
            # os.path.join leaves absolute paths unchanged.
            frame["image_path"] = [
                os.path.join(root_dir, path) for path in frame["image_path"]
            ]
        prepared.append(frame)

    df = pd.concat(prepared, ignore_index=True)

    # Group predictions by image_path
    box_columns = ["xmin", "ymin", "xmax", "ymax", "label", "score"]
    kangas_data = [{
        "image_path": image_path,
        "bounding_boxes": group[box_columns].to_dict(orient="records")
    } for image_path, group in df.groupby("image_path")]

    # Visualization is best effort: a failure is reported, never raised.
    try:
        kangas.DataGrid(kangas_data, name="DeepForest Predictions").show()
    except Exception as e:
        print(f"Kangas visualization failed: {e}")
Deprecation warning: The 'return_plot', and related 'color' and 'thickness' arguments @@ -428,10 +580,19 @@ def predict_image(self, else: root_dir = os.path.dirname(path) result = utilities.read_file(result, root_dir=root_dir) + # Visualize if requested + if visualize_with == "kangas": + self.visualize_kangas(result) return result - def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1): + def predict_file(self, + csv_file, + root_dir, + savedir=None, + color=None, + thickness=1, + visualize_with=None): """Create a dataset and predict entire annotation file Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position. Image_path is the @@ -469,6 +630,9 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1 thickness=thickness) results.root_dir = root_dir + # Visualize if requested + if visualize_with == "kangas": + self.visualize_kangas(results) return results @@ -487,7 +651,8 @@ def predict_tile(self, thickness=1, crop_model=None, crop_transform=None, - crop_augment=False): + crop_augment=False, + visualize_with=None): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and reassambles into a single array. 
def predict_step(
    self,
    batch: typing.Any,
    batch_idx: int,
    visualize_with: typing.Optional[str] = None,
    image_paths: typing.Optional[typing.List[str]] = None
) -> typing.List[pd.DataFrame]:
    """Run the model on one batch and format each output as a DataFrame.

    Args:
        batch: Input passed unchanged to ``self.model``.
        batch_idx: Index of the batch; not used by this method.
        visualize_with: Set to "kangas" to pass the results to ``self.visualize_kangas``.
        image_paths: Optional list with one path per model output. When given, each result gets an "image_path" column holding its path.

    Returns:
        A list with one DataFrame per element of the model output. Outputs that are not dictionaries yield an empty DataFrame with the columns "xmin", "ymin", "xmax", "ymax", "label", "score".

    Raises:
        ValueError: If image_paths is given and its length differs from the number of model outputs.
    """
    # Run in evaluation mode without gradients, then put the model back into
    # whatever mode the caller had it in so training is not silently disabled.
    was_training = self.model.training
    self.model.eval()
    try:
        with torch.no_grad():
            batch_results = self.model(batch)
    finally:
        self.model.train(was_training)

    empty_columns = ["xmin", "ymin", "xmax", "ymax", "label", "score"]
    results = [
        visualize.format_boxes(result)
        if isinstance(result, dict) else pd.DataFrame(columns=empty_columns)
        for result in batch_results
    ]

    if image_paths:
        if len(image_paths) != len(results):
            raise ValueError(f"Length of image_paths ({len(image_paths)}) must match predictions ({len(results)})")
        for frame, path in zip(results, image_paths):
            frame["image_path"] = path

    if visualize_with == "kangas":
        if image_paths:
            self.visualize_kangas(results)
        else:
            # Without paths the boxes cannot be tied to images, so say so
            # instead of skipping the request silently.
            warnings.warn(
                "visualize_with='kangas' needs image_paths; skipping visualization.")

    return results
Args: @@ -884,9 +1091,23 @@ def predict_batch(self, images, preprocess_fn=None): with torch.no_grad(): predictions = self.predict_step(images, 0) - #convert predictions to dataframes - results = [utilities.read_file(pred) for pred in predictions if pred is not None] + # Handle predictions, including empty ones + results = [] + for i, pred in enumerate(predictions): + if pred is None or pred.empty: + # Create an empty DataFrame with expected columns + empty_df = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax", "label", "score", "image_path"]) + if image_paths and i < len(image_paths): + empty_df["image_path"] = [image_paths[i]] + results.append(empty_df) + else: + pred_df = utilities.read_file(pred) + if image_paths and i < len(image_paths): + pred_df["image_path"] = image_paths[i] + results.append(pred_df) + if visualize_with == "kangas": + self.visualize_kangas(results) return results def configure_optimizers(self): @@ -945,7 +1166,12 @@ def configure_optimizers(self): else: return optimizer - def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None): + def evaluate(self, + csv_file, + root_dir, + iou_threshold=None, + savedir=None, + visualize_with=None): """Compute intersection-over-union and precision/recall for a given iou_threshold. 
@@ -971,5 +1197,12 @@ def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None): root_dir=root_dir, iou_threshold=iou_threshold, numeric_to_label_dict=self.numeric_to_label_dict) + # If user wants Kangas visualization + if visualize_with == "kangas": + self.visualize_evaluation_kangas(predictions=predictions, + ground_df=ground_df, + root_dir=root_dir, + evaluation_results=results) return results + diff --git a/tests/test_main.py b/tests/test_main.py index 96fa67c1a..7e3cc07db 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -348,7 +348,7 @@ def test_predict_tile_from_array(m, raster_path): m.create_trainer() prediction = m.predict_tile(image=image, patch_size=300) - + assert not prediction.empty @@ -719,13 +719,13 @@ def test_batch_prediction(m, raster_path): tile = np.array(Image.open(raster_path)) ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300) dl = DataLoader(ds, batch_size=3) - + # Perform prediction predictions = [] for batch in dl: prediction = m.predict_batch(batch) predictions.append(prediction) - + # Check results assert len(predictions) == len(dl) for batch_pred in predictions: @@ -739,21 +739,21 @@ def test_batch_inference_consistency(m, raster_path): tile = np.array(Image.open(raster_path)) ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300) dl = DataLoader(ds, batch_size=4) - + batch_predictions = [] for batch in dl: prediction = m.predict_batch(batch) batch_predictions.extend(prediction) - + single_predictions = [] for image in ds: image = image.permute(1,2,0).numpy() * 255 prediction = m.predict_image(image=image) single_predictions.append(prediction) - + batch_df = pd.concat(batch_predictions, ignore_index=True) single_df = pd.concat(single_predictions, ignore_index=True) - + # Make all xmin, ymin, xmax, ymax integers for col in ["xmin", "ymin", "xmax", "ymax"]: batch_df[col] = batch_df[col].astype(int) @@ -864,4 +864,102 @@ def test_evaluate_on_epoch_interval(m): 
def _stub_kangas(monkeypatch):
    """Replace ``deepforest.main.kg`` with a recorder so no Kangas UI is started.

    Returns:
        The list that receives one ``(name, data)`` tuple for every grid whose
        ``show`` method is called.
    """
    shown = []

    class _RecordingGrid:

        def __init__(self, data=None, name=None, **kwargs):
            self.data = data
            self.name = name

        def show(self, *args, **kwargs):
            shown.append((self.name, self.data))

    class _FakeKangas:
        DataGrid = _RecordingGrid

    monkeypatch.setattr("deepforest.main.kg", _FakeKangas)
    return shown


# Test predict_dataloader with Kangas
def test_predict_dataloader_kangas(m, monkeypatch):
    """predict_dataloader with visualize_with="kangas" still returns a DataLoader."""
    _stub_kangas(monkeypatch)
    csv_file = get_data("example.csv")
    root_dir = os.path.dirname(csv_file)
    ds = dataset.TreeDataset(csv_file=csv_file,
                             root_dir=root_dir,
                             transforms=None,
                             train=False)
    ds.paths = [
        os.path.join(root_dir, img) for img in ds.annotations.image_path.unique()
    ]
    loader = m.predict_dataloader(ds, visualize_with="kangas")
    assert isinstance(loader, torch.utils.data.DataLoader), "Should return a DataLoader"


# Test predict_image with Kangas
def test_predict_image_kangas(m, monkeypatch):
    """predict_image requests a predictions grid and returns its usual output."""
    shown = _stub_kangas(monkeypatch)
    image_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
    result = m.predict_image(path=image_path, visualize_with="kangas")
    assert isinstance(result, pd.DataFrame) or result is None, "Should return DataFrame or None"
    if result is not None:
        assert "image_path" in result.columns, "Result should include image_path"
        assert not result.empty, "Should predict trees with pre-trained model"
        assert any(name == "DeepForest Predictions" for name, _ in shown), "Should request a predictions grid"


# Test predict_file with Kangas
def test_predict_file_kangas(m, monkeypatch):
    """predict_file requests a predictions grid and returns a DataFrame."""
    shown = _stub_kangas(monkeypatch)
    csv_file = get_data("OSBS_029.csv")
    root_dir = os.path.dirname(csv_file)
    result = m.predict_file(csv_file=csv_file, root_dir=root_dir, visualize_with="kangas")
    assert isinstance(result, pd.DataFrame), "Should return a DataFrame"
    assert "image_path" in result.columns, "Result should include image_path"
    assert not result.empty, "Should predict trees with pre-trained model"
    assert any(name == "DeepForest Predictions" for name, _ in shown), "Should request a predictions grid"


# Test predict_tile with Kangas
def test_predict_tile_kangas(m, raster_path, monkeypatch):
    """predict_tile requests a predictions grid when mosaic is left at its default."""
    shown = _stub_kangas(monkeypatch)
    result = m.predict_tile(raster_path=raster_path,
                            patch_size=300,
                            patch_overlap=0.1,
                            visualize_with="kangas")
    assert isinstance(result, pd.DataFrame) or result is None, "Should return DataFrame or None"
    if result is not None:
        assert "image_path" in result.columns, "Result should include image_path"
        assert not result.empty, "Should predict trees with pre-trained model"
        assert any(name == "DeepForest Predictions" for name, _ in shown), "Should request a predictions grid"


# Test predict_tile no visualization with mosaic=False
def test_predict_tile_kangas_no_mosaic(m, raster_path, monkeypatch):
    """predict_tile does not request any grid when mosaic=False."""
    shown = _stub_kangas(monkeypatch)
    result = m.predict_tile(raster_path=raster_path,
                            patch_size=300,
                            patch_overlap=0.1,
                            mosaic=False,
                            visualize_with="kangas")
    assert isinstance(result, list), "Should return a list of (prediction, crop) tuples"
    assert all(isinstance(r[0], pd.DataFrame) and isinstance(r[1], np.ndarray) for r in result), "Each item should be (DataFrame, array)"
    assert shown == [], "No grid should be requested with mosaic=False"


# Test predict_step with Kangas
def test_predict_step_kangas(m, monkeypatch):
    """predict_step requests a predictions grid when given image paths."""
    shown = _stub_kangas(monkeypatch)
    image_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
    image = np.array(Image.open(image_path).convert("RGB")).astype("float32")
    batch = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255
    result = m.predict_step(batch, 0, visualize_with="kangas", image_paths=[image_path])
    assert isinstance(result, list), "Should return a list of predictions"
    assert all(isinstance(r, pd.DataFrame) for r in result), "Each item should be a DataFrame"
    assert any(name == "DeepForest Predictions" for name, _ in shown), "Should request a predictions grid"


# Test predict_batch with Kangas
def test_predict_batch_kangas(m, monkeypatch):
    """predict_batch requests a predictions grid and keeps the given image paths."""
    shown = _stub_kangas(monkeypatch)
    image_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
    image = np.array(Image.open(image_path).convert("RGB")).astype("float32")
    images = np.stack([image] * 2)  # Batch of 2 images
    image_paths = [image_path, image_path]
    result = m.predict_batch(images, visualize_with="kangas", image_paths=image_paths)
    assert isinstance(result, list), "Should return a list of DataFrames"
    assert all(isinstance(r, pd.DataFrame) for r in result), "Each item should be a DataFrame"
    assert all("image_path" in r.columns for r in result if not r.empty), "Results should include image_path"
    assert any(not r.empty for r in result), "At least one result should have predictions"
    assert any(name == "DeepForest Predictions" for name, _ in shown), "Should request a predictions grid"


# Test evaluate with Kangas
def test_evaluate_kangas(m, monkeypatch):
    """evaluate requests an evaluation grid and returns its metrics dictionary."""
    shown = _stub_kangas(monkeypatch)
    csv_file = get_data("OSBS_029.csv")
    root_dir = os.path.dirname(csv_file)
    result = m.evaluate(csv_file=csv_file, root_dir=root_dir, visualize_with="kangas")
    assert isinstance(result, dict), "Should return a dict of evaluation metrics"
    assert "box_precision" in result, "Should include precision"
    assert "box_recall" in result, "Should include recall"
    assert isinstance(result["results"], pd.DataFrame), "Results should be a DataFrame"
    assert any(name == "DeepForest Evaluation" for name, _ in shown), "Should request an evaluation grid"


# Test without Kangas installed (mock Kangas unavailable)
def test_predict_image_no_kangas(m, monkeypatch):
    """predict_image still returns its result when the kangas alias is None."""
    monkeypatch.setattr("deepforest.main.kg", None)  # Simulate Kangas not installed
    image_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
    result = m.predict_image(path=image_path, visualize_with="kangas")
    assert isinstance(result, pd.DataFrame) or result is None, "Should return DataFrame or None despite no Kangas"


# Test empty predictions with Kangas
def test_predict_image_empty_kangas(m, monkeypatch):
    """predict_image on a blank image returns None or an empty DataFrame."""
    _stub_kangas(monkeypatch)
    image = np.zeros((400, 400, 3), dtype=np.float32)  # Black image, likely no predictions
    result = m.predict_image(image=image, visualize_with="kangas")
    assert result is None or (isinstance(result, pd.DataFrame) and result.empty), "Should return None or empty DataFrame"