Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions tests/test_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,15 @@ def forward(self, x):
@pytest.fixture
def example_df():
"""
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ numeric ┆ text ┆ list_col ┆ nomal_obj ┆ self_obj β”‚
β”‚ --- ┆ --- ┆ --- ┆ --- ┆ --- β”‚
β”‚ i64 ┆ str ┆ list[i64] ┆ object ┆ object β”‚
β•žβ•β•β•β•β•β•β•β•β•β•ͺ══════β•ͺ═══════════β•ͺ═══════════β•ͺ═════════════║
β”‚ 1 ┆ a ┆ [1, 2] ┆ example ┆ 100 K β”‚
β”‚ 2 ┆ null ┆ [3] ┆ dataframe ┆ null β”‚
β”‚ null ┆ c ┆ null ┆ null ┆ 0.00 Β± 0.00 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ numeric ┆ text ┆ list_col ┆ null_col ┆ nomal_obj β”‚
β”‚ --- ┆ --- ┆ --- ┆ --- ┆ --- β”‚
β”‚ i64 ┆ str ┆ list[i64] ┆ null ┆ object β”‚
β•žβ•β•β•β•β•β•β•β•β•β•ͺ══════β•ͺ═══════════β•ͺ══════════β•ͺ═══════════║
β”‚ 1 ┆ a ┆ [1, 2] ┆ null ┆ example β”‚
β”‚ 2 ┆ null ┆ [3] ┆ null ┆ dataframe β”‚
β”‚ null ┆ c ┆ null ┆ null ┆ null β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""

from polars import Object as pl_obj
Expand All @@ -175,6 +175,7 @@ def example_df():
"numeric": [1, 2, None],
"text": ["a", None, "c"],
"list_col": [[1, 2], [3], None],
"null_col": [None, None, None],
"nomal_obj": [
Text("example"),
Text("dataframe"),
Expand Down Expand Up @@ -1339,16 +1340,16 @@ def test_df2tb_structure(self, simple_tabular_renderer, example_df) -> None:
res = simple_tabular_renderer.df2tb(example_df, show_raw=False)

assert isinstance(res, Table)
assert len(res.columns) == 5
assert len(res.columns) == 6
assert res.row_count == 3

tb_headers = [col_obj.header for col_obj in res.columns]
assert tb_headers == ["numeric", "text", "list_col", "nomal_obj", "self_obj"]
assert tb_headers == ["numeric", "text", "list_col", "null_col", "nomal_obj", "self_obj"]

assert str(example_df[0, 0]) == self.tbval_getter(0, 0, res)
assert str(example_df[2, 1]) == self.tbval_getter(2, 1, res)
assert str(example_df[1, 2].to_list()) == self.tbval_getter(1, 2, res)
assert str(example_df[0, 3]) == self.tbval_getter(0, 3, res)
assert str(example_df[0, 4]) == self.tbval_getter(0, 4, res)

# ιͺŒθ―ζ ·εΌεΊ”用调用
mock_apply.assert_any_call(
Expand Down Expand Up @@ -1378,11 +1379,16 @@ def test_df2tb_none_handling(self, simple_tabular_renderer, example_df) -> None:
# list none
assert self.tbval_getter(2, 2, res) == "-"

# normal object none
# null none
assert self.tbval_getter(0, 3, res) == "-"
assert self.tbval_getter(1, 3, res) == "-"
assert self.tbval_getter(2, 3, res) == "-"

# normal object none
assert self.tbval_getter(2, 4, res) == "-"

# self object none
assert self.tbval_getter(1, 4, res) == "test none_str"
assert self.tbval_getter(1, 5, res) == "test none_str"

def test_df2tb_show_raw(self, simple_tabular_renderer, example_df) -> None:
"""Test whether the show_raw argument works well"""
Expand All @@ -1394,17 +1400,17 @@ def test_df2tb_show_raw(self, simple_tabular_renderer, example_df) -> None:
assert self.tbval_getter(0, 0, noraml_res) == "1"
assert self.tbval_getter(0, 1, noraml_res) == "a"
assert self.tbval_getter(1, 2, noraml_res) == "[3]"
assert self.tbval_getter(1, 3, noraml_res) == "dataframe"
assert self.tbval_getter(0, 4, noraml_res) == "100 K"
assert self.tbval_getter(2, 4, noraml_res) == "0.00 Β± 0.00"
assert self.tbval_getter(1, 4, noraml_res) == "dataframe"
assert self.tbval_getter(0, 5, noraml_res) == "100 K"
assert self.tbval_getter(2, 5, noraml_res) == "0.00 Β± 0.00"

# verify raw display
assert self.tbval_getter(0, 0, raw_res) == "1"
assert self.tbval_getter(0, 1, raw_res) == "a"
assert self.tbval_getter(1, 2, raw_res) == "[3]"
assert self.tbval_getter(1, 3, raw_res) == "dataframe"
assert self.tbval_getter(0, 4, raw_res) == "100000.0"
assert self.tbval_getter(2, 4, raw_res) == "0.0"
assert self.tbval_getter(1, 4, raw_res) == "dataframe"
assert self.tbval_getter(0, 5, raw_res) == "100000.0"
assert self.tbval_getter(2, 5, raw_res) == "0.0"

def test_clear(self, simple_tabular_renderer, example_df, monkeypatch) -> None:
"""Test the stat dataframe clearing logic"""
Expand Down Expand Up @@ -1612,7 +1618,7 @@ def test_new_col(self, simple_tabular_renderer, example_df) -> None:
col_func=lambda x: ["test"] * len(x),
col_idx=0,
)
assert new_df.shape == (3, 6)
assert new_df.shape == (3, 7)
assert new_df.columns[0] == "new_col"
assert new_df["new_col"].to_list() == ["test"] * 3

Expand All @@ -1623,9 +1629,9 @@ def test_new_col(self, simple_tabular_renderer, example_df) -> None:
col_func=lambda df: df.drop_in_place(name="numeric"),
col_idx=0,
)
assert new_df.shape == (3, 6)
assert example_df.shape == (3, 5)
assert example_df.columns == ["numeric", "text", "list_col", "nomal_obj", "self_obj"]
assert new_df.shape == (3, 7)
assert example_df.shape == (3, 6)
assert example_df.columns == ["numeric", "text", "list_col", "null_col", "nomal_obj", "self_obj"]
assert new_df["origin_numeric"].to_list() == example_df["numeric"].to_list()

# verify col_idx
Expand All @@ -1636,7 +1642,7 @@ def test_new_col(self, simple_tabular_renderer, example_df) -> None:
col_func=lambda x: ["test"] * len(x),
col_idx=1,
)
assert new_df.shape == (3, 6)
assert new_df.shape == (3, 7)
assert new_df.columns[1] == "new_col"

# non-negative and out of range (add last)
Expand All @@ -1646,8 +1652,8 @@ def test_new_col(self, simple_tabular_renderer, example_df) -> None:
col_func=lambda x: ["test"] * len(x),
col_idx=8,
)
assert new_df.shape == (3, 6)
assert new_df.columns[5] == "new_col"
assert new_df.shape == (3, 7)
assert new_df.columns[6] == "new_col"

# negative and in range
new_df = new_col(
Expand All @@ -1656,7 +1662,7 @@ def test_new_col(self, simple_tabular_renderer, example_df) -> None:
col_func=lambda x: ["test"] * len(x),
col_idx=-1,
)
assert new_df.shape == (3, 6)
assert new_df.shape == (3, 7)
assert new_df.columns[-1] == "new_col"

# negative and out of range (add first)
Expand All @@ -1666,7 +1672,7 @@ def test_new_col(self, simple_tabular_renderer, example_df) -> None:
col_func=lambda x: ["test"] * len(x),
col_idx=-9,
)
assert new_df.shape == (3, 6)
assert new_df.shape == (3, 7)
assert new_df.columns[0] == "new_col"

# verify return_type is correctly applied
Expand Down
8 changes: 6 additions & 2 deletions torchmeter/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,8 @@ def col_args(self, custom_args: Dict[str, Any]) -> None:
self.col_args.mark_change()

def df2tb(self, df: DataFrame, show_raw: bool = False) -> Table:
from polars import Null as pl_null

# create rich table
tb_fields = df.columns
tb = apply_setting(
Expand All @@ -670,8 +672,10 @@ def df2tb(self, df: DataFrame, show_raw: bool = False) -> Table:
return tb

# collect each column's none replacing string
col_none_str = {col_name: getattr(df[col_name].drop_nulls()[0], "none_str", "-")
for col_name in df.schema} # fmt: skip
col_none_str = {
col_name: getattr(df[col_name].drop_nulls()[0], "none_str", "-") if not col_type.is_(pl_null) else "-"
for col_name, col_type in df.schema.items()
} # fmt: skip

# fill table
for vals_dict in df.iter_rows(named=True):
Expand Down
Loading