Skip to content

Commit 6b32b48

Browse files
authored
Merge pull request #78 from russellb/pr-72-alternative
Handle type conversion errors in FilterByValueBlock
2 parents 6842582 + 0de3f62 commit 6b32b48

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

src/instructlab/sdg/filterblock.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,20 @@ def __init__(
2020
self.convert_dtype = convert_dtype
2121
self.num_procs = batch_kwargs.get("num_procs", 1)
2222

23+
def _convert_dtype(self, sample):
    """Convert ``sample[self.column_name]`` using ``self.convert_dtype``.

    Used as the per-sample callable passed to ``samples.map`` in
    ``generate``. The sample is modified in place and returned.

    A failed conversion is logged and the column is set to ``None``
    rather than letting the exception abort the whole ``map`` call.
    ``TypeError`` is handled together with ``ValueError`` because
    built-in converters such as ``int`` and ``float`` raise ``TypeError``
    (not ``ValueError``) for ``None`` or other unsupported input types.

    Args:
        sample: A single mapping-like row containing ``self.column_name``.

    Returns:
        The same ``sample`` object, with the column converted or set to
        ``None`` when the conversion failed.
    """
    try:
        sample[self.column_name] = self.convert_dtype(sample[self.column_name])
    except (ValueError, TypeError) as e:
        logger.error(
            "Error converting dtype: %s, filling with None to be filtered later", e
        )
        # None marks the row as unconvertible; per the log message it is
        # expected to be dropped by the filtering step that follows.
        sample[self.column_name] = None
    return sample
32+
2333
def generate(self, samples) -> Dataset:
2434
if self.convert_dtype:
2535
samples = samples.map(
26-
lambda x: {
27-
**x,
28-
self.column_name: self.convert_dtype(x[self.column_name]),
29-
},
36+
self._convert_dtype,
3037
num_proc=self.num_procs,
3138
)
3239

tests/test_filterblock.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Standard
2+
from unittest.mock import patch
3+
import operator
4+
import unittest
5+
6+
# Third Party
7+
from datasets import Dataset, Features, Value
8+
9+
# First Party
10+
from instructlab.sdg.filterblock import FilterByValueBlock
11+
12+
13+
class TestFilterByValueBlock(unittest.TestCase):
    """Tests for FilterByValueBlock when a column holds unconvertible values."""

    def setUp(self):
        """Build a string-typed "age" dataset and an int-converting filter block."""
        # "forty" cannot be parsed by int(); the other entries can.
        raw_ages = ["25", "30", "35", "forty", "45"]
        string_schema = Features({"age": Value("string")})
        self.dataset = Dataset.from_dict({"age": raw_ages}, features=string_schema)
        self.block = FilterByValueBlock(
            filter_column="age",
            filter_value=30,
            operation=operator.eq,
            convert_dtype=int,
        )

    @patch("instructlab.sdg.filterblock.logger")
    def test_generate_mixed_types(self, mock_logger):
        """Only the row equal to 30 survives, and the bad value is logged."""
        result = self.block.generate(self.dataset)
        self.assertEqual(len(result), 1)
        self.assertEqual(result["age"], [30])
        # The failed conversion of "forty" must have been reported.
        mock_logger.error.assert_called()

0 commit comments

Comments
 (0)