@@ -34,15 +34,15 @@ def __init__(
3434 task_id : str ,
3535 prompt_instructions : str | None ,
3636 labels : list [str ] | dict [str , str ],
37- multi_label : bool ,
37+ mode : Literal [ "single" , "multi" ] ,
3838 model_settings : ModelSettings ,
3939 ):
4040 """Initialize ClassificationBridge.
4141
4242 :param task_id: Task ID.
4343 :param prompt_instructions: Custom prompt instructions. If None, default instructions are used.
4444 :param labels: Labels to classify. Can be a list of label strings, or a dict mapping labels to descriptions.
45- :param multi_label : If True , task returns confidence scores for all specified labels. If False , task returns
45+ :param mode : If 'multi' , task returns confidence scores for all specified labels. If 'single' , task returns
4646 most likely class label. In the latter case label forcing mechanisms are utilized, which can lead to higher
4747 accuracy.
4848 :param model_settings: Model settings.
@@ -59,7 +59,7 @@ def __init__(
5959 else :
6060 self ._labels = labels
6161 self ._label_descriptions = {}
62- self ._multi_label = multi_label
62+ self ._mode = mode
6363
6464 def _get_label_descriptions (self ) -> str :
6565 """Return a string with the label descriptions.
@@ -87,7 +87,7 @@ class DSPyClassification(ClassificationBridge[dspy_.PromptSignature, dspy_.Resul
8787 @override
8888 @property
8989 def _default_prompt_instructions (self ) -> str :
90- if self ._multi_label :
90+ if self ._mode == "multi" :
9191 return f"""
9292 Multi-label classification of the provided text given the labels { self ._labels } .
9393 For each label, provide the confidence with which you believe that the provided text should be assigned
@@ -121,7 +121,7 @@ def prompt_signature(self) -> type[dspy_.PromptSignature]:
121121 labels = self ._labels
122122 LabelType = Literal [* labels ] # type: ignore[valid-type]
123123
124- if self ._multi_label :
124+ if self ._mode == "multi" :
125125
126126 class MultiLabelTextClassification (dspy .Signature ): # type: ignore[misc]
127127 text : str = dspy .InputField (description = "Text to classify." )
@@ -164,7 +164,7 @@ def integrate(self, results: Sequence[dspy_.Result], docs: list[Doc]) -> list[Do
164164 reverse = True ,
165165 )
166166
167- if self ._multi_label :
167+ if self ._mode == "multi" :
168168 doc .results [self ._task_id ] = ResultMultiLabel (label_scores = sorted_preds )
169169 else :
170170 if isinstance (sorted_preds , list ) and len (sorted_preds ) > 0 :
@@ -188,7 +188,7 @@ def consolidate(
188188
189189 # Clamp score to range between 0 and 1. Alternatively we could force this in the prompt signature,
190190 # but this fails occasionally with some models and feels too strict.
191- if self ._multi_label :
191+ if self ._mode == "multi" :
192192 for label , score in res .confidence_per_label .items ():
193193 label_scores [label ] += max (0 , min (score , 1 ))
194194 else :
@@ -228,7 +228,7 @@ def _default_prompt_instructions(self) -> str:
228228 @override
229229 @property
230230 def _prompt_example_template (self ) -> str | None :
231- if self ._multi_label :
231+ if self ._mode == "multi" :
232232 return """
233233 {% if examples|length > 0 -%}
234234
@@ -285,7 +285,7 @@ def inference_mode(self) -> huggingface_.InferenceMode:
285285 def integrate (self , results : Sequence [huggingface_ .Result ], docs : list [Doc ]) -> list [Doc ]:
286286 for doc , result in zip (docs , results ):
287287 label_scores = [(label , score ) for label , score in zip (result ["labels" ], result ["scores" ])]
288- if self ._multi_label :
288+ if self ._mode == "multi" :
289289 doc .results [self ._task_id ] = ResultMultiLabel (label_scores = label_scores )
290290 else :
291291 if len (label_scores ) > 0 :
@@ -333,7 +333,7 @@ class PydanticBasedClassification(
333333 @override
334334 @property
335335 def _default_prompt_instructions (self ) -> str :
336- if self ._multi_label :
336+ if self ._mode == "multi" :
337337 return (
338338 f"""
339339 Perform multi-label classification of the provided text given the provided labels: { "," .join (self ._labels )} .
@@ -369,7 +369,7 @@ def _default_prompt_instructions(self) -> str:
369369 @override
370370 @property
371371 def _prompt_example_template (self ) -> str | None :
372- if self ._multi_label :
372+ if self ._mode == "multi" :
373373 return """
374374 {% if examples|length > 0 -%}
375375 Examples:
@@ -417,7 +417,7 @@ def _prompt_conclusion(self) -> str | None:
417417 @override
418418 @cached_property
419419 def prompt_signature (self ) -> type [pydantic .BaseModel ] | list [str ]:
420- if self ._multi_label :
420+ if self ._mode == "multi" :
421421 prompt_sig = pydantic .create_model ( # type: ignore[no-matching-overload]
422422 "MultilabelClassification" ,
423423 __base__ = pydantic .BaseModel ,
@@ -442,7 +442,7 @@ class SingleLabelClassification(pydantic.BaseModel):
442442 @override
443443 def integrate (self , results : Sequence [pydantic .BaseModel | str ], docs : list [Doc ]) -> list [Doc ]:
444444 for doc , result in zip (docs , results ):
445- if self ._multi_label :
445+ if self ._mode == "multi" :
446446 assert isinstance (result , pydantic .BaseModel )
447447 label_scores = result .model_dump ()
448448 sorted_label_scores = sorted (
@@ -471,7 +471,7 @@ def consolidate(
471471
472472 # We clamp the score to 0 <= x <= 1. Alternatively we could force this in the prompt signature, but
473473 # this fails occasionally with some models and feels too strict.
474- if self ._multi_label :
474+ if self ._mode == "multi" :
475475 for label in self ._labels :
476476 label_scores [label ] += max (0 , min (getattr (res , label ), 1 ))
477477 else :
@@ -482,7 +482,7 @@ def consolidate(
482482 assert issubclass (prompt_signature , pydantic .BaseModel ) # type: ignore[arg-type]
483483 assert callable (prompt_signature )
484484
485- if self ._multi_label :
485+ if self ._mode == "multi" :
486486 consolidated_results .append (prompt_signature (** avg_label_scores ))
487487 else :
488488 max_score_label = max (avg_label_scores , key = avg_label_scores .__getitem__ )
@@ -510,12 +510,12 @@ class PydanticBasedClassificationWithLabelForcing(PydanticBasedClassification[Mo
510510 @override
511511 @cached_property
512512 def prompt_signature (self ) -> type [pydantic .BaseModel ] | list [str ]:
513- return super ().prompt_signature if self ._multi_label else self ._labels
513+ return super ().prompt_signature if self ._mode == "multi" else self ._labels
514514
515515 @override
516516 @property
517517 def _default_prompt_instructions (self ) -> str :
518- if self ._multi_label :
518+ if self ._mode == "multi" :
519519 return super ()._default_prompt_instructions
520520
521521 return f"""
@@ -534,7 +534,7 @@ def _default_prompt_instructions(self) -> str:
534534 @override
535535 @property
536536 def _prompt_example_template (self ) -> str | None :
537- if self ._multi_label :
537+ if self ._mode == "multi" :
538538 return super ()._prompt_example_template
539539
540540 return """
@@ -555,7 +555,7 @@ def _prompt_example_template(self) -> str | None:
555555
556556 @override
557557 def integrate (self , results : Sequence [pydantic .BaseModel | str ], docs : list [Doc ]) -> list [Doc ]:
558- if self ._multi_label :
558+ if self ._mode == "multi" :
559559 return super ().integrate (results , docs )
560560
561561 for doc , result in zip (docs , results ):
@@ -572,7 +572,7 @@ def integrate(self, results: Sequence[pydantic.BaseModel | str], docs: list[Doc]
572572 def consolidate (
573573 self , results : Sequence [pydantic .BaseModel | str ], docs_offsets : list [tuple [int , int ]]
574574 ) -> Sequence [pydantic .BaseModel | str ]:
575- if self ._multi_label :
575+ if self ._mode == "multi" :
576576 return super ().consolidate (results , docs_offsets )
577577
578578 else :
@@ -592,5 +592,5 @@ class OutlinesClassification(PydanticBasedClassificationWithLabelForcing[outline
592592 @property
593593 def inference_mode (self ) -> outlines_ .InferenceMode :
594594 return self ._model_settings .inference_mode or (
595- outlines_ .InferenceMode .json if self ._multi_label else outlines_ .InferenceMode .choice
595+ outlines_ .InferenceMode .json if self ._mode == "multi" else outlines_ .InferenceMode .choice
596596 )
0 commit comments