Beware metric auto-reduce with PyTorch Lightning + TorchMetrics

PyTorch Lightning + TorchMetrics can log metrics per step and per epoch. TorchMetrics also provides MetricCollection, which computes several metrics at once and eliminates redundant code. Here is how I had it set up:

class BaseModule(pl.LightningModule):
    def __init__(self):
        super().__init__()  # required, otherwise the LightningModule is not initialized
        self.train_metrics = torchmetrics.MetricCollection(
            [
                torchmetrics.Accuracy(),
                torchmetrics.Precision(),
                torchmetrics.Recall(),
                torchmetrics.F1Score(),
            ]
        )

    def training_step(self, batch, batch_idx):
        label = self.get_label(batch)
        out = self.forward(batch)
        loss = self.loss_fn(out, label)
        
        output = self.train_metrics(out, label.int())
        self.log_dict(output, on_step=False, on_epoch=True)
        return loss

This code works for Accuracy, but it computes the wrong values for Precision, Recall, and F1Score. To calculate each metric's value on_epoch, PyTorch Lightning averages[1] the values logged for each batch, weighted by the batch size. That weighted average is correct for a linear metric like accuracy, but it gives the incorrect value for a ratio-of-ratios metric like F1Score.
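To see why averaging per-batch values goes wrong for F1, here is a small numeric counterexample in plain Python (the counts are made up for illustration):

```python
def f1(tp, fp, fn):
    """F1 score computed from true-positive, false-positive, false-negative counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Two batches with different error profiles.
batch1 = f1(tp=1, fp=0, fn=0)  # 1.0  (a perfect batch)
batch2 = f1(tp=1, fp=3, fn=3)  # 0.25

# Averaging the per-batch F1 scores:
averaged = (batch1 + batch2) / 2  # 0.625

# Pooling the counts over the epoch first, then computing F1 once:
pooled = f1(tp=1 + 1, fp=0 + 3, fn=0 + 3)  # 0.4

print(averaged, pooled)  # the two values disagree
```

The pooled value (0.4) is the true epoch-level F1; the averaged value (0.625) is what you get when each batch's score is reduced independently.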

The docs state that "If on_epoch is True, the logger automatically logs the end of epoch metric value by calling .compute()."[2] Perhaps because I was using self.log_dict instead of self.log, that did not happen in my case.

To fix it, I changed the code as follows:

 class BaseModule(pl.LightningModule):
     def __init__(self):
         super().__init__()  # required, otherwise the LightningModule is not initialized
         self.train_metrics = torchmetrics.MetricCollection(
             [
                 torchmetrics.Accuracy(),
                 torchmetrics.Precision(),
                 torchmetrics.Recall(),
                 torchmetrics.F1Score(),
             ]
         )
 
     def training_step(self, batch, batch_idx):
         label = self.get_label(batch)
         out = self.forward(batch)
         loss = self.loss_fn(out, label)
         
-        output = self.train_metrics(out, label.int())
-        self.log_dict(output, on_step=False, on_epoch=True)
+        self.train_metrics.update(out, label.int())
         return loss
+    
+    def training_epoch_end(self, outputs):
+        self.log_dict(self.train_metrics.compute(), on_step=False, on_epoch=True)
+        self.train_metrics.reset()

This code explicitly calls Metric.update(), Metric.compute(), and Metric.reset(), so the counts are accumulated over the whole epoch and the metric is computed once, how God intended it. Yeehaw.
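The update/compute/reset pattern can be mimicked with a minimal stand-in to check that accumulating counts across batches reproduces the pooled result. This is a plain-Python sketch; F1Accumulator is a toy class written for this post, not part of TorchMetrics:

```python
class F1Accumulator:
    """Toy stateful metric mirroring TorchMetrics' update/compute/reset API."""

    def __init__(self):
        self.reset()

    def update(self, tp, fp, fn):
        # Accumulate raw counts across batches instead of per-batch scores.
        self.tp += tp
        self.fp += fp
        self.fn += fn

    def compute(self):
        # Compute F1 once from the pooled counts.
        precision = self.tp / (self.tp + self.fp)
        recall = self.tp / (self.tp + self.fn)
        return 2 * precision * recall / (precision + recall)

    def reset(self):
        # Clear state at the end of each epoch.
        self.tp = self.fp = self.fn = 0


metric = F1Accumulator()
metric.update(tp=1, fp=0, fn=0)  # batch 1
metric.update(tp=1, fp=3, fn=3)  # batch 2
print(metric.compute())  # the pooled epoch-level F1
metric.reset()
```

This is the same accumulation that TorchMetrics performs internally when you call update() per batch and compute() once per epoch.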