Module abstention

This module contains functions to apply abstention to the output of an A3SOM model.

find_local() can be used to find distance and ambiguity thresholds for each class (= local thresholds). These can then be applied by apply_local() to obtain labels after local abstention.

Another option is to use apply_global() to obtain labels after abstention with global thresholds (= the same threshold for all classes). As this is generally not the optimal strategy, we do not provide a function to find global thresholds; their choice is left to the user.

Once abstained labels are obtained, outcomes() builds the lists of accepted, rejected, successful, and erroneous predictions. From these lists, Accepted Accuracy (AA, the proportion of accepted predictions that were correct) and Rejected Error (RE, the proportion of prediction errors that were rejected) can be computed with the functions AA() and RE().
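
A typical workflow chains these functions together. The sketch below is illustrative: the import path is hypothetical, and y_true and y_pred are assumed to come from a trained A3SOM model.

from a3som import abstention  # hypothetical import path

# y_pred: class membership probabilities, shape (n_samples, n_classes)
amb_thr, dist_thr = abstention.find_local(y_pred)
y_abst = abstention.apply_local(y_pred, amb_thr, dist_thr)

out = abstention.outcomes(y_true, y_pred, y_abst)  # as_dict=True by default
aa_score = abstention.AA(out['accepted'], out['success'])
re_score = abstention.RE(out['rejected'], out['error'])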

Functions

find_local()

find_local(y_pred)

Find local thresholds, i.e., one distance threshold for each class and one ambiguity threshold for each class.

Distance thresholds are computed by finding, for each class, the mean probability assigned to that class over the samples for which it is predicted as the most likely class.

Ambiguity thresholds are computed by finding, for each class, the mean difference between the probability of that class and the second highest probability.

Args:
y_pred:

class membership probabilities returned by the model, of shape (n_samples, n_classes).

Returns:
amb_thr:

a list of ambiguity thresholds, one for each class.

dist_thr:

a list of distance thresholds, one for each class.
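
As a minimal NumPy sketch of the computation described above (not the package's implementation; it assumes both means are taken over the samples whose most likely class is the class in question):

import numpy as np

def find_local_sketch(y_pred):
    y_pred = np.asarray(y_pred)
    top = np.argmax(y_pred, axis=1)          # most likely class per sample
    second = np.sort(y_pred, axis=1)[:, -2]  # second highest probability
    amb_thr, dist_thr = [], []
    for c in range(y_pred.shape[1]):
        mask = top == c  # samples for which class c is the most likely
        # (a class that is never predicted would yield nan here)
        dist_thr.append(y_pred[mask, c].mean())
        amb_thr.append((y_pred[mask, c] - second[mask]).mean())
    return amb_thr, dist_thr  # same order as find_local()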

apply_local()

apply_local(y_pred, ambiguity_threshold, distance_threshold)

Apply local abstention thresholds on predictions.

Args:
y_pred:

class membership probabilities returned by the model, of shape (n_samples, n_classes).

ambiguity_threshold:

list of ambiguity thresholds for each class (chosen by the user, or the first output of find_local()).

distance_threshold:

list of distance thresholds for each class (chosen by the user, or the second output of find_local()).

Returns:
y_abst:

abstention labels. If abstention was not applied to sample i, y_abst[i] is the predicted class label (the class with the highest probability in y_pred[i]). If distance abstention was applied, y_abst[i] = -1. If ambiguity abstention was applied, y_abst[i] = -2.
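
The sketch below shows how such thresholds could be applied. The order of the two checks and the use of the argmax label for accepted samples are assumptions, not necessarily the package's implementation:

import numpy as np

def apply_local_sketch(y_pred, ambiguity_threshold, distance_threshold):
    y_pred = np.asarray(y_pred)
    labels = np.argmax(y_pred, axis=1)            # predicted class per sample
    top = y_pred[np.arange(len(y_pred)), labels]  # highest probability
    second = np.sort(y_pred, axis=1)[:, -2]       # second highest probability
    y_abst = labels.copy()
    for i, c in enumerate(labels):
        if top[i] < distance_threshold[c]:    # too far from class c
            y_abst[i] = -1                    # distance abstention
        elif top[i] - second[i] < ambiguity_threshold[c]:
            y_abst[i] = -2                    # ambiguity abstention
    return y_abst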

apply_global()

apply_global(y_pred, ambiguity_threshold, distance_threshold)

Apply global abstention thresholds on predictions.

Args:
y_pred:

class membership probabilities returned by the model, of shape (n_samples, n_classes).

ambiguity_threshold: float between 0 and 1.

Ambiguity threshold to apply for all classes.

distance_threshold: float between 0 and 1.

Distance threshold to apply for all classes.

Returns:
y_abst:

abstention labels. If abstention was not applied to sample i, y_abst[i] is the predicted class label (the class with the highest probability in y_pred[i]). If distance abstention was applied, y_abst[i] = -1. If ambiguity abstention was applied, y_abst[i] = -2.
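
For example (the threshold values below are arbitrary, not recommendations):

# reject samples whose top probability is below 0.5 (y_abst[i] = -1)
# or whose top-two margin is below 0.1 (y_abst[i] = -2)
y_abst = apply_global(y_pred, ambiguity_threshold=0.1, distance_threshold=0.5)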

outcomes()

outcomes(y_true, y_pred, y_abst, as_dict=True)

Determine which predictions were accepted and which were rejected (abstained), and which predictions were successes and which were errors.

Args:
y_true:

list of the true labels.

y_pred:

list of class membership probabilities returned by the model, before abstention.

y_abst:

list of abstained labels (as obtained from apply_local() or apply_global()).

as_dict: bool, default=True.

If True, the different lists of predictions will be returned as a dictionary with corresponding keys (‘accepted’, ‘rejected’, ‘success’, ‘error’). If False, they will be returned as a list of lists in this same order.

Returns:
Either a dictionary: {‘accepted’: accepted, ‘rejected’: rejected, ‘success’: success, ‘error’: error} or a list: [accepted, rejected, success, error] where the four variables are:

accepted: list of predictions that were accepted (= where abstention was not applied).

rejected: list of predictions that were rejected (= where abstention was applied).

success: list of predictions where the model's predicted label (before abstention) matches the true label.

error: list of predictions where the model's predicted label (before abstention) does not match the true label.
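
A sketch of the bookkeeping, assuming the four lists hold sample indices, that accepted/rejected are read off y_abst, and that success/error compare the argmax of y_pred against y_true:

import numpy as np

def outcomes_sketch(y_true, y_pred, y_abst, as_dict=True):
    y_true = np.asarray(y_true)
    y_abst = np.asarray(y_abst)
    pred_labels = np.argmax(np.asarray(y_pred), axis=1)  # labels before abstention

    accepted = [i for i in range(len(y_abst)) if y_abst[i] >= 0]
    rejected = [i for i in range(len(y_abst)) if y_abst[i] < 0]
    success = [i for i in range(len(y_true)) if pred_labels[i] == y_true[i]]
    error = [i for i in range(len(y_true)) if pred_labels[i] != y_true[i]]

    if as_dict:
        return {'accepted': accepted, 'rejected': rejected,
                'success': success, 'error': error}
    return [accepted, rejected, success, error]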

AA()

AA(accepted, success)

Accepted Accuracy metric: among the accepted predictions, what proportion were correct (successes)?

Args:
accepted:

list of predictions that were accepted.

success:

list of predictions that matched the true labels.

Returns:

Accepted accuracy score between 0 and 1. If AA = 1, it means that all the predictions for which abstention was not applied were correct.
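
With index lists as in the outcomes() sketch above, AA reduces to a set intersection; the value returned for an empty accepted list is an assumed convention:

def AA_sketch(accepted, success):
    # proportion of accepted predictions that were successes
    if not accepted:
        return 0.0  # edge-case convention assumed here
    return len(set(accepted) & set(success)) / len(accepted)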

RE()

RE(rejected, error)

Rejected Error metric: among the prediction errors, what proportion were rejected?

Args:
rejected:

list of predictions where abstention was applied.

error:

list of predictions that did not match the true labels.

Returns:

Rejected error score between 0 and 1. If RE = 1, it means that abstention was applied to all of the predictions that did not match the true labels.
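
Symmetrically, a sketch of RE; the value returned when there are no errors is an assumed convention:

def RE_sketch(rejected, error):
    # proportion of prediction errors that were rejected
    if not error:
        return 1.0  # nothing to reject; edge-case convention assumed here
    return len(set(rejected) & set(error)) / len(error)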