FairLoss
This loss function takes fairness into account when training a PyTorch model. It works by adding a fairness measure to a regular loss. The package provides both the loss function and a set of built-in fairness scores.
import torch
from fair_loss import FairLoss
model = torch.nn.Sequential(torch.nn.Linear(5, 1), torch.nn.ReLU())
data = torch.randint(0, 5, (100, 5), dtype=torch.float, requires_grad=True)
y_true = torch.randint(0, 5, (100, 1), dtype=torch.float)
y_pred = model(data)
# Assume the sensitive attribute is the second column (index 1)
dim = 1
criterion = FairLoss(torch.nn.MSELoss(), data[:, dim].detach().unique(), 'accuracy')
loss = criterion(data[:, dim], y_pred, y_true)
loss.backward()
class fair_loss.FairLoss(loss_fun: torch.nn.modules.module.Module, unique_attr: torch.Tensor, fairness_score: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]])
Add a fairness measure to the regular loss.
fairness_score is applied to input and target for each value of unique_attr. The results are then summed, divided by the minimum result, and added to the regular loss function:
loss + \lambda {\sum_{i=0}^{k} w_i f_i(input, target) \over \min_{i} f_i(input, target)}
where:
k is the number of values of protected_attr
f is the fairness_score function
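As a rough illustration of the formula above, the sketch below combines a base loss with a list of per-group scores. The helper name `combine_fair_loss`, the default `lam=1.0`, and the uniform weights are assumptions made for this example only; the page does not state how the library chooses λ or the weights w.

```python
import torch


def combine_fair_loss(loss, scores, weights=None, lam=1.0):
    """Hypothetical helper: loss + lam * sum_i(w_i * f_i) / min_i(f_i)."""
    scores = torch.stack(list(scores))        # one scalar score per group
    if weights is None:
        weights = torch.ones_like(scores)     # assumption: uniform weights
    penalty = (weights * scores).sum() / scores.min()
    return loss + lam * penalty


base_loss = torch.tensor(0.5)
group_scores = [torch.tensor(0.8), torch.tensor(0.4)]
total = combine_fair_loss(base_loss, group_scores)
# penalty is (0.8 + 0.4) / 0.4, roughly 3.0, so total is roughly 3.5
```

When every group has the same score, the penalty reduces to the number of groups; it grows as the smallest score falls behind the others. If the smallest score is 0 the ratio is undefined, and how the library handles that case is not covered here.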
- Parameters
loss_fun (torch.nn.Module) – A loss function
unique_attr (torch.Tensor) – Possible values of a sensitive attribute
fairness_score (Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]) – A function that takes input and target as arguments and returns a score, or one of ‘accuracy’, ‘fpr’, ‘tpr’, ‘tnr’, ‘fnr’, ‘ppv’, ‘npv’
Examples
>>> model = torch.nn.Sequential(torch.nn.Linear(5, 1), torch.nn.ReLU())
>>> data = torch.randint(0, 5, (100, 5), dtype=torch.float, requires_grad=True)
>>> target = torch.randint(0, 5, (100, 1), dtype=torch.float)
>>> input = model(data)
>>> # The sensitive attribute is the second column
>>> dim = 1
>>> criterion = FairLoss(torch.nn.MSELoss(), data[:, dim].detach().unique(), 'accuracy')
>>> loss = criterion(data[:, dim], input, target)
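According to the signature, fairness_score may also be a callable that takes input and target and returns a tensor. The sketch below shows one function with that shape; `mean_abs_error_score` is a hypothetical name, and whether mean absolute error is a sensible fairness score for a given task is left open.

```python
import torch


def mean_abs_error_score(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical custom score: mean absolute error as a 0-dim tensor."""
    return (input - target).abs().mean()


input = torch.tensor([[1.0], [2.0], [4.0]])
target = torch.tensor([[1.0], [3.0], [2.0]])
score = mean_abs_error_score(input, target)  # (0 + 1 + 2) / 3 = 1.0
```

Going by the documented signature, such a function would be passed as the third constructor argument in place of a string such as 'accuracy'. This sketch has not been checked against the library itself.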
static accuracy(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Accuracy
- Parameters
input (torch.Tensor) – Predicted values
target (torch.Tensor) – Ground truth
- Shape:
input: (N, 1)
target: (N, 1)
- Returns
Accuracy
- Return type
torch.Tensor
Examples
>>> input = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> target = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> FairLoss.accuracy(input, target)
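The page gives no formula for accuracy. A minimal sketch, assuming it means the fraction of predictions that exactly match the ground truth, could look like this (`accuracy_sketch` is a hypothetical name, not the library's implementation):

```python
import torch


def accuracy_sketch(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Fraction of positions where the prediction equals the ground truth
    return (input == target).float().mean()


input = torch.tensor([[1.0], [0.0], [1.0], [1.0]])
target = torch.tensor([[1.0], [1.0], [1.0], [0.0]])
acc = accuracy_sketch(input, target)  # 2 of 4 predictions match -> 0.5
```

Note that an exact-equality comparison carries no gradient, so the library's version may well be computed differently.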
static fnr(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
False Negative Rate
FNR = {FN \over FN + TP}
where:
FN is the number of false negatives
TP is the number of true positives
- Parameters
input (torch.Tensor) – Predicted values
target (torch.Tensor) – Ground truth
- Shape:
input: (N, 1)
target: (N, 1)
- Returns
False Negative Rate
- Return type
torch.Tensor
Examples
>>> input = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> target = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> FairLoss.fnr(input, target)
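To make the FNR definition concrete, here is a small sketch that counts false negatives and true positives on binary 0/1 tensors. `fnr_sketch` is a hypothetical helper written for illustration, and it assumes predictions are already thresholded to 0 or 1.

```python
import torch


def fnr_sketch(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # False negative: predicted 0 while the ground truth is 1
    fn = ((input == 0) & (target == 1)).sum().float()
    # True positive: predicted 1 while the ground truth is 1
    tp = ((input == 1) & (target == 1)).sum().float()
    return fn / (fn + tp)


input = torch.tensor([[1.0], [0.0], [0.0], [1.0], [0.0]])
target = torch.tensor([[1.0], [1.0], [1.0], [0.0], [0.0]])
rate = fnr_sketch(input, target)  # FN = 2, TP = 1 -> 2 / 3
```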
forward(protected_attr: torch.Tensor, input: torch.Tensor, target: torch.Tensor)
Compute the fair loss
- Shape:
protected_attr: (N,)
input: (N, 1)
target: (N, 1)
- Returns
The fair loss value
- Return type
torch.Tensor
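The shapes above suggest that the protected attribute is used to split the batch into groups before scoring. The sketch below shows one way such a split could work with a boolean mask; `per_group_scores` and `match_rate` are hypothetical names, and this is not taken from the library's source.

```python
import torch


def per_group_scores(protected_attr, input, target, score_fn):
    """Hypothetical helper: apply score_fn to the rows of each group."""
    scores = {}
    for value in protected_attr.unique():
        mask = protected_attr == value            # boolean mask, shape (N,)
        # Row-wise masking keeps the trailing dimension: (n_group, 1)
        scores[value.item()] = score_fn(input[mask], target[mask])
    return scores


def match_rate(input, target):
    # Stand-in score: fraction of exact matches
    return (input == target).float().mean()


protected_attr = torch.tensor([0.0, 0.0, 1.0, 1.0])   # shape (N,)
input = torch.tensor([[1.0], [0.0], [1.0], [1.0]])    # shape (N, 1)
target = torch.tensor([[1.0], [1.0], [1.0], [1.0]])   # shape (N, 1)
scores = per_group_scores(protected_attr, input, target, match_rate)
# group 0.0 -> 0.5, group 1.0 -> 1.0
```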
static fpr(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
False Positive Rate
FPR = {FP \over FP + TN}
where:
FP is the number of false positives
TN is the number of true negatives
- Parameters
input (torch.Tensor) – Predicted values
target (torch.Tensor) – Ground truth
- Shape:
input: (N, 1)
target: (N, 1)
- Returns
False Positive Rate
- Return type
torch.Tensor
Examples
>>> import numpy as np
>>> input = np.random.randint(2, size=(10, 1)).astype("float")
>>> input = torch.tensor(input)
>>> target = np.random.randint(2, size=(10, 1)).astype("float")
>>> target = torch.tensor(target)
>>> FairLoss.fpr(input, target)
get_fairness_score(fairness_score: str) → Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
Return one of the built-in fairness scores
- Parameters
fairness_score (str) – The fairness score
- Returns
The fairness score function
- Return type
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
static npv(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Negative Predictive Value
NPV = {TN \over TN + FN}
where:
TN is the number of true negatives
FN is the number of false negatives
- Parameters
input (torch.Tensor) – Predicted values
target (torch.Tensor) – Ground truth
- Shape:
input: (N, 1)
target: (N, 1)
- Returns
Negative Predictive Value
- Return type
torch.Tensor
Examples
>>> input = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> target = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> FairLoss.npv(input, target)
static ppv(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Positive Predictive Value
PPV = {TP \over TP + FP}
where:
TP is the number of true positives
FP is the number of false positives
- Parameters
input (torch.Tensor) – Predicted values
target (torch.Tensor) – Ground truth
- Shape:
input: (N, 1)
target: (N, 1)
- Returns
Positive Predictive Value
- Return type
torch.Tensor
Examples
>>> input = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> target = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> FairLoss.ppv(input, target)
static tnr(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
True Negative Rate
TNR = {TN \over TN + FP}
where:
TN is the number of true negatives
FP is the number of false positives
- Parameters
input (torch.Tensor) – Predicted values
target (torch.Tensor) – Ground truth
- Shape:
input: (N, 1)
target: (N, 1)
- Returns
True Negative Rate
- Return type
torch.Tensor
Examples
>>> input = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> target = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> FairLoss.tnr(input, target)
static tpr(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
True Positive Rate
TPR = {TP \over TP + FN}
where:
TP is the number of true positives
FN is the number of false negatives
- Parameters
input (torch.Tensor) – Predicted values
target (torch.Tensor) – Ground truth
- Shape:
input: (N, 1)
target: (N, 1)
- Returns
True Positive Rate
- Return type
torch.Tensor
Examples
>>> input = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> target = torch.randint(0, 2, (10, 1), dtype=torch.float)
>>> FairLoss.tpr(input, target)
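All six rate formulas on this page derive from the same four confusion counts. The sketch below computes them together on a small binary example to show how they relate; `confusion_counts` is a hypothetical helper and assumes inputs are already 0 or 1.

```python
import torch


def confusion_counts(input: torch.Tensor, target: torch.Tensor):
    """Hypothetical helper: return (TP, TN, FP, FN) as float tensors."""
    tp = ((input == 1) & (target == 1)).sum().float()
    tn = ((input == 0) & (target == 0)).sum().float()
    fp = ((input == 1) & (target == 0)).sum().float()
    fn = ((input == 0) & (target == 1)).sum().float()
    return tp, tn, fp, fn


input = torch.tensor([[1.0], [1.0], [0.0], [0.0], [1.0], [0.0]])
target = torch.tensor([[1.0], [0.0], [0.0], [1.0], [1.0], [0.0]])
tp, tn, fp, fn = confusion_counts(input, target)  # TP=2, TN=2, FP=1, FN=1

tpr = tp / (tp + fn)   # 2 / 3
fnr = fn / (fn + tp)   # 1 / 3
tnr = tn / (tn + fp)   # 2 / 3
fpr = fp / (fp + tn)   # 1 / 3
ppv = tp / (tp + fp)   # 2 / 3
npv = tn / (tn + fn)   # 2 / 3
```

Two identities follow directly from the definitions: TPR + FNR = 1 and TNR + FPR = 1, so each pair carries the same information.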