
Commit 7271863

Authored by rmitsch (Raphael Mitsch) and a co-author
chore: Change classification multi_label bool to mode: Literal["multi", "single"]. (#244)
Co-authored-by: Raphael Mitsch <raphael@climatiq.com>
1 parent 3fd21fe · commit 7271863

File tree

11 files changed: +64 / -61 lines

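For orientation, the call-site change looks like this. A minimal sketch, assuming `tasks` is imported from sieves as in the demo scripts and `model` is an already-configured model object (both placeholders, not part of this commit):

```python
from sieves import tasks  # assumed import path, as used in the demos

# `model` is assumed to be configured elsewhere.

# Before this commit:
# classifier = tasks.Classification(labels=["spam", "not spam"], multi_label=False, model=model)

# After this commit, the boolean flag becomes an explicit mode literal:
classifier = tasks.Classification(
    labels=["spam", "not spam"],
    mode="single",  # "multi" (the default) returns confidence scores for all labels
    model=model,
)
```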

demos/crisis_tweets/case_study.py

Lines changed: 2 additions & 2 deletions
@@ -220,7 +220,7 @@ def _(batch_size, data_sampled, model):
     crisis_label_classifier = tasks.Classification(
         task_id="crisis_label_classifier",
         labels=data_sampled.label.unique(),
-        multi_label=False,
+        mode='single',
         model=model,
         batch_size=batch_size,
     )
@@ -249,7 +249,7 @@ def related_to_crisis(doc: Doc) -> bool:
     crisis_type_classifier = tasks.Classification(
         task_id="crisis_type_classifier",
         labels=data_sampled.crisis_type.unique(),
-        multi_label=False,
+        mode='single',
         model=model,
         condition=related_to_crisis,
         batch_size=batch_size,

demos/demo_spam.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def _(model):

     classifier = tasks.Classification(
         labels=["spam", "not spam"],
-        multi_label=False,
+        mode='single',
         model=model,
     )
     summarizer = tasks.Summarization(n_words=10, model=model)

docs/guides/distillation.md

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ The distillation process automatically handles both classification modes:
 task = Classification(
     labels=["technology", "politics", "sports"],
     model=model,
-    multi_label=False,
+    mode='single',
 )
 ```

docs/tasks/predictive/classification.md

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@ The `Classification` task returns a unified result schema regardless of the mode
 --8<-- "sieves/tasks/predictive/schemas/classification.py:Result"
 ```

-- When `multi_label=True` (default): results are of type `ResultMultiLabel`, containing a list of `(label, score)` tuples.
-- When `multi_label=False`: results are of type `ResultSingleLabel`, containing a single `label` and `score`.
+- When `mode == 'multi'` (default): results are of type `ResultMultiLabel`, containing a list of `(label, score)` tuples.
+- When `mode == 'single'`: results are of type `ResultSingleLabel`, containing a single `label` and `score`.

 ---
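
To make the result-type distinction concrete, here is a hedged sketch of reading a processed document under each mode. It assumes `doc` has already been run through a `Classification` task registered under the hypothetical task id "classifier", and uses only the fields named in the bullet points above:

```python
# Assumed: `doc` was processed by a Classification task with task_id="classifier".
result = doc.results["classifier"]

if hasattr(result, "label_scores"):
    # mode == 'multi' (default): ResultMultiLabel holds (label, score) tuples for every label.
    for label, score in result.label_scores:
        print(label, score)
else:
    # mode == 'single': ResultSingleLabel holds a single label and its score.
    print(result.label, result.score)
```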

sieves/model_wrappers/huggingface_.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def execute(values: Sequence[dict[str, Any]]) -> Sequence[tuple[Result | None, A
     sequences=[doc_values["text"] for doc_values in values],
     candidate_labels=prompt_signature,
     hypothesis_template=template,
-    multi_label=True,
+    mode="multi",
     **self._inference_kwargs,
 )

sieves/tasks/predictive/classification/bridges.py

Lines changed: 21 additions & 21 deletions
@@ -34,15 +34,15 @@ def __init__(
     task_id: str,
     prompt_instructions: str | None,
     labels: list[str] | dict[str, str],
-    multi_label: bool,
+    mode: Literal["single", "multi"],
     model_settings: ModelSettings,
 ):
     """Initialize ClassificationBridge.

     :param task_id: Task ID.
     :param prompt_instructions: Custom prompt instructions. If None, default instructions are used.
     :param labels: Labels to classify. Can be a list of label strings, or a dict mapping labels to descriptions.
-    :param multi_label: If True, task returns confidence scores for all specified labels. If False, task returns
+    :param mode: If 'multi'', task returns confidence scores for all specified labels. If 'single', task returns
         most likely class label. In the latter case label forcing mechanisms are utilized, which can lead to higher
         accuracy.
     :param model_settings: Model settings.
@@ -59,7 +59,7 @@ def __init__(
     else:
         self._labels = labels
         self._label_descriptions = {}
-    self._multi_label = multi_label
+    self._mode = mode

 def _get_label_descriptions(self) -> str:
     """Return a string with the label descriptions.
@@ -87,7 +87,7 @@ class DSPyClassification(ClassificationBridge[dspy_.PromptSignature, dspy_.Resul
 @override
 @property
 def _default_prompt_instructions(self) -> str:
-    if self._multi_label:
+    if self._mode == "multi":
         return f"""
         Multi-label classification of the provided text given the labels {self._labels}.
         For each label, provide the confidence with which you believe that the provided text should be assigned
@@ -121,7 +121,7 @@ def prompt_signature(self) -> type[dspy_.PromptSignature]:
 labels = self._labels
 LabelType = Literal[*labels] # type: ignore[valid-type]

-if self._multi_label:
+if self._mode == "multi":

     class MultiLabelTextClassification(dspy.Signature): # type: ignore[misc]
         text: str = dspy.InputField(description="Text to classify.")
@@ -164,7 +164,7 @@ def integrate(self, results: Sequence[dspy_.Result], docs: list[Doc]) -> list[Do
     reverse=True,
 )

-if self._multi_label:
+if self._mode == "multi":
     doc.results[self._task_id] = ResultMultiLabel(label_scores=sorted_preds)
 else:
     if isinstance(sorted_preds, list) and len(sorted_preds) > 0:
@@ -188,7 +188,7 @@ def consolidate(

 # Clamp score to range between 0 and 1. Alternatively we could force this in the prompt signature,
 # but this fails occasionally with some models and feels too strict.
-if self._multi_label:
+if self._mode == "multi":
     for label, score in res.confidence_per_label.items():
         label_scores[label] += max(0, min(score, 1))
 else:
@@ -228,7 +228,7 @@ def _default_prompt_instructions(self) -> str:
 @override
 @property
 def _prompt_example_template(self) -> str | None:
-    if self._multi_label:
+    if self._mode == "multi":
         return """
         {% if examples|length > 0 -%}
@@ -285,7 +285,7 @@ def inference_mode(self) -> huggingface_.InferenceMode:
 def integrate(self, results: Sequence[huggingface_.Result], docs: list[Doc]) -> list[Doc]:
     for doc, result in zip(docs, results):
         label_scores = [(label, score) for label, score in zip(result["labels"], result["scores"])]
-        if self._multi_label:
+        if self._mode == "multi":
             doc.results[self._task_id] = ResultMultiLabel(label_scores=label_scores)
         else:
             if len(label_scores) > 0:
@@ -333,7 +333,7 @@ class PydanticBasedClassification(
 @override
 @property
 def _default_prompt_instructions(self) -> str:
-    if self._multi_label:
+    if self._mode == "multi":
         return (
             f"""
             Perform multi-label classification of the provided text given the provided labels: {",".join(self._labels)}.
@@ -369,7 +369,7 @@ def _default_prompt_instructions(self) -> str:
 @override
 @property
 def _prompt_example_template(self) -> str | None:
-    if self._multi_label:
+    if self._mode == "multi":
         return """
         {% if examples|length > 0 -%}
         Examples:
@@ -417,7 +417,7 @@ def _prompt_conclusion(self) -> str | None:
 @override
 @cached_property
 def prompt_signature(self) -> type[pydantic.BaseModel] | list[str]:
-    if self._multi_label:
+    if self._mode == "multi":
         prompt_sig = pydantic.create_model( # type: ignore[no-matching-overload]
             "MultilabelClassification",
             __base__=pydantic.BaseModel,
@@ -442,7 +442,7 @@ class SingleLabelClassification(pydantic.BaseModel):
 @override
 def integrate(self, results: Sequence[pydantic.BaseModel | str], docs: list[Doc]) -> list[Doc]:
     for doc, result in zip(docs, results):
-        if self._multi_label:
+        if self._mode == "multi":
             assert isinstance(result, pydantic.BaseModel)
             label_scores = result.model_dump()
             sorted_label_scores = sorted(
@@ -471,7 +471,7 @@ def consolidate(

 # We clamp the score to 0 <= x <= 1. Alternatively we could force this in the prompt signature, but
 # this fails occasionally with some models and feels too strict.
-if self._multi_label:
+if self._mode == "multi":
     for label in self._labels:
         label_scores[label] += max(0, min(getattr(res, label), 1))
 else:
@@ -482,7 +482,7 @@ def consolidate(
 assert issubclass(prompt_signature, pydantic.BaseModel) # type: ignore[arg-type]
 assert callable(prompt_signature)

-if self._multi_label:
+if self._mode == "multi":
     consolidated_results.append(prompt_signature(**avg_label_scores))
 else:
     max_score_label = max(avg_label_scores, key=avg_label_scores.__getitem__)
@@ -510,12 +510,12 @@ class PydanticBasedClassificationWithLabelForcing(PydanticBasedClassification[Mo
 @override
 @cached_property
 def prompt_signature(self) -> type[pydantic.BaseModel] | list[str]:
-    return super().prompt_signature if self._multi_label else self._labels
+    return super().prompt_signature if self._mode == "multi" else self._labels

 @override
 @property
 def _default_prompt_instructions(self) -> str:
-    if self._multi_label:
+    if self._mode == "multi":
         return super()._default_prompt_instructions

     return f"""
@@ -534,7 +534,7 @@ def _default_prompt_instructions(self) -> str:
 @override
 @property
 def _prompt_example_template(self) -> str | None:
-    if self._multi_label:
+    if self._mode == "multi":
         return super()._prompt_example_template

     return """
@@ -555,7 +555,7 @@ def _prompt_example_template(self) -> str | None:

 @override
 def integrate(self, results: Sequence[pydantic.BaseModel | str], docs: list[Doc]) -> list[Doc]:
-    if self._multi_label:
+    if self._mode == "multi":
         return super().integrate(results, docs)

     for doc, result in zip(docs, results):
@@ -572,7 +572,7 @@ def integrate(self, results: Sequence[pydantic.BaseModel | str], docs: list[Doc]
 def consolidate(
     self, results: Sequence[pydantic.BaseModel | str], docs_offsets: list[tuple[int, int]]
 ) -> Sequence[pydantic.BaseModel | str]:
-    if self._multi_label:
+    if self._mode == "multi":
         return super().consolidate(results, docs_offsets)

     else:
@@ -592,5 +592,5 @@ class OutlinesClassification(PydanticBasedClassificationWithLabelForcing[outline
 @property
 def inference_mode(self) -> outlines_.InferenceMode:
     return self._model_settings.inference_mode or (
-        outlines_.InferenceMode.json if self._multi_label else outlines_.InferenceMode.choice
+        outlines_.InferenceMode.json if self._mode == "multi" else outlines_.InferenceMode.choice
     )

sieves/tasks/predictive/classification/core.py

Lines changed: 15 additions & 17 deletions
@@ -5,7 +5,7 @@
 import json
 from collections.abc import Callable, Iterable, Sequence
 from pathlib import Path
-from typing import Any, override
+from typing import Any, Literal, override

 import datasets
 import dspy
@@ -74,7 +74,7 @@ def __init__(
     batch_size: int = -1,
     prompt_instructions: str | None = None,
     fewshot_examples: Sequence[FewshotExample] = (),
-    multi_label: bool = True,
+    mode: Literal["single", "multi"] = "multi",
     model_settings: ModelSettings = ModelSettings(),
     condition: Callable[[Doc], bool] | None = None,
 ) -> None:
@@ -90,7 +90,7 @@ def __init__(
 :param batch_size: Batch size to use for inference. Use -1 to process all documents at once.
 :param prompt_instructions: Custom prompt instructions. If None, default instructions are used.
 :param fewshot_examples: Few-shot examples.
-:param multi_label: If True, task returns confidence scores for all specified labels. If False, task returns
+:param mode: If 'multi', task returns confidence scores for all specified labels. If 'single', task returns
     most likely class label. In the latter case label forcing mechanisms are utilized, which can lead to higher
     accuracy.
 :param model_settings: Model settings.
@@ -102,7 +102,7 @@ def __init__(
 else:
     self._labels = list(labels)
     self._label_descriptions = {}
-self._multi_label = multi_label
+self._mode = mode

 super().__init__(
     model=model,
@@ -137,7 +137,7 @@ def _init_bridge(self, model_type: ModelType) -> _TaskBridge:
 prompt_signature=gliner2.inference.engine.Schema().classification(
     task="classification",
     labels=labels,
-    multi_label=self._multi_label,
+    mode=self._mode,
 ),
 model_settings=self._model_settings,
 inference_mode=gliner_.InferenceMode.classification,
@@ -158,7 +158,7 @@ def _init_bridge(self, model_type: ModelType) -> _TaskBridge:
     task_id=self._task_id,
     prompt_instructions=self._custom_prompt_instructions,
     labels=labels,
-    multi_label=self._multi_label,
+    mode=self._mode,
     model_settings=self._model_settings,
 )
 except KeyError as err:
@@ -179,12 +179,12 @@ def _validate_fewshot_examples(self) -> None:
 label_error_text = (
     "Label mismatch: {task_id} has labels {labels}. Few-shot examples have labels {example_labels}."
 )
-example_type_error_text = "Fewshot example type mismatch: multi_label = {multi_label} requires {example_type}."
+example_type_error_text = "Fewshot example type mismatch: mode = {mode} requires {example_type}."

 for fs_example in self._fewshot_examples or []:
-    if self._multi_label:
+    if self._mode == "multi":
         assert isinstance(fs_example, FewshotExampleMultiLabel), TypeError(
-            example_type_error_text.format(example_type=FewshotExampleMultiLabel, multi_label=self._multi_label)
+            example_type_error_text.format(example_type=FewshotExampleMultiLabel, mode=self._mode)
         )
         if any([label not in self._labels for label in fs_example.confidence_per_label]) or not all(
             [label in fs_example.confidence_per_label for label in self._labels]
@@ -196,9 +196,7 @@ def _validate_fewshot_examples(self) -> None:
     )
 else:
     assert isinstance(fs_example, FewshotExampleSingleLabel), TypeError(
-        example_type_error_text.format(
-            example_type=FewshotExampleSingleLabel, multi_label=self._multi_label
-        )
+        example_type_error_text.format(example_type=FewshotExampleSingleLabel, mode=self._mode)
     )
     if fs_example.label not in self._labels:
         raise ValueError(
@@ -283,7 +281,7 @@ def distill(
 default_init_kwargs: dict[str, Any] = {}
 metric_kwargs: dict[str, Any] = {}

-if self._multi_label:
+if self._mode == "multi":
     default_init_kwargs["multi_target_strategy"] = "multi-output"
     metric_kwargs = {"average": "macro"}

@@ -369,7 +367,7 @@ def to_hf_dataset(self, docs: Iterable[Doc], threshold: float = 0.5) -> datasets
 data: list[dict[str, str | list[bool]]] = []

 # Define metadata and features (multi-hot across declared labels for multi-label).
-if self._multi_label:
+if self._mode == "multi":
     features = datasets.Features(
         {"text": datasets.Value("string"), "labels": datasets.Sequence(datasets.Value("bool"))}
     )
@@ -380,7 +378,7 @@ def to_hf_dataset(self, docs: Iterable[Doc], threshold: float = 0.5) -> datasets

 info = datasets.DatasetInfo(
     description=(
-        f"{'Multi-label' if self._multi_label else 'Single-label'} classification dataset with labels "
+        f"{'Multi-label' if self._mode == 'multi' else 'Single-label'} classification dataset with labels "
         f"{self._labels}. Generated with sieves v{Config.get_version()}."
     ),
     features=features,
@@ -391,7 +389,7 @@ def to_hf_dataset(self, docs: Iterable[Doc], threshold: float = 0.5) -> datasets
 scores = Classification._result_to_scores(doc.results[self._task_id])

 # If multi-label: store one-hot representation.
-if self._multi_label:
+if self._mode == "multi":
     result_normalized = [int(scores.get(label, 0.0) >= threshold) for label in self._labels] # type: ignore[no-matching-overload]
 # If single-label: get single-label result as is.
 else:
@@ -410,7 +408,7 @@ def to_hf_dataset(self, docs: Iterable[Doc], threshold: float = 0.5) -> datasets
 def _evaluate_optimization_example(
     self, truth: dspy.Example, pred: dspy.Prediction, trace: Any, model: dspy.LM
 ) -> float:
-    if not self._multi_label:
+    if self._mode == "single":
         return 1 - abs(truth["confidence"] - pred["confidence"]) if truth["label"] == pred["label"] else 0

     # For multi-label: compute label-wise accuracy as
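
Because `mode` defaults to `"multi"` (mirroring the old `multi_label=True` default), only call sites that previously passed `multi_label=False` need updating. A hedged migration sketch, with `model` and the label list as placeholders:

```python
# `model` is assumed to be configured elsewhere; labels are illustrative.

# Unchanged behaviour: multi-label remains the default, so this call needs no edit.
task = Classification(labels=["technology", "politics", "sports"], model=model)

# Was: Classification(..., multi_label=False)
# Now:
task = Classification(labels=["technology", "politics", "sports"], model=model, mode="single")
```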

sieves/tasks/predictive/gliner_bridge.py

Lines changed: 2 additions & 2 deletions
@@ -155,9 +155,9 @@ def integrate(self, results: Sequence[gliner_.Result], docs: list[Doc]) -> list[
 # Used by: Classification
 case gliner_.InferenceMode.classification:
     assert hasattr(self._prompt_signature.schema, "__getitem__")
-    is_multilabel = self._prompt_signature.schema["classifications"][0]["multi_label"]
+    mode = self._prompt_signature.schema["classifications"][0]["mode"]

-    if is_multilabel:
+    if mode == "multi":
         label_scores: list[tuple[str, float]] = []
         for res in sorted(result, key=lambda x: x["score"], reverse=True):
             assert isinstance(res, dict)

sieves/tests/docs/test_optimization.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def test_basic_optimization_example(small_dspy_model):
     },
     model=model,
     fewshot_examples=examples,
-    multi_label=False,
+    mode='single',
     model_settings=ModelSettings(),
 )
 # --8<-- [end:optimization-task-setup]

0 commit comments
