
A short introduction to adversarial phenomena

Machine learning models have been shown to be vulnerable to adversarial perturbations: minor modifications to the input data that are imperceptible to the human eye, yet cause the models to output completely different predictions.

Here’s an example (with the accompanying Jupyter notebook) of what that looks like. In the figure below, after a very small perturbation (middle image) is added to the panda image (right), the neural network recognizes the perturbed image (left) as a bucket. To a human observer, however, the perturbed panda looks exactly the same as the original panda.

[Figure: the panda image, the adversarial perturbation, and the perturbed image misclassified as a bucket]

Why is this significant?

Adversarial perturbations pose real risks to machine learning applications and can have real-world impact. For example, researchers have shown that putting black and white patches on a stop sign can prevent state-of-the-art object detection systems from recognizing the stop sign correctly.

[Figure: a stop sign with black and white adversarial patches, from Eykholt et al. (2018)]

This problem is not restricted to images: speech recognition systems and malware detection systems have also been shown to be vulnerable to similar attacks. More broadly, realistic adversarial attacks could target any machine learning system where it is profitable for adversaries to do so, for instance fraud detection systems, identity recognition systems, and decision-making systems.

This is where AdverTorch comes in

AdverTorch (repo, report) is a tool we built at the Borealis AI research lab that implements a series of attack-and-defense strategies. The idea behind it emerged back in 2017, when my team began to do some focused research around adversarial robustness. At the time, we only had two tools at our disposal: CleverHans and Foolbox.

While these are both good tools, each had its limitations. Back then, CleverHans only supported TensorFlow, which limited its usage in other deep learning frameworks (in our case, PyTorch). Moreover, the static computational graph nature of TensorFlow makes the implementation of attacks less straightforward. For anyone new to this type of research, it can be hard to understand what’s going on if the attack is written in a static-graph language.

Foolbox, on the other hand, contains various types of attack methods, but it only supports running attacks image-by-image rather than batch-by-batch. This makes it slow to run and thus only suitable for evaluations. At the time, Foolbox also lacked some important attacks, e.g. the projected gradient descent (PGD) attack and the Carlini-Wagner $\ell_2$-norm constrained attack.

Our solution

In the absence of a toolbox that served more of our needs, we decided to implement our own. Building our own tool would also allow us to use our favorite framework – PyTorch – which was not an option with the others.

Our aim was to provide researchers with tools for pursuing different research directions in adversarial robustness. For now, we’ve built AdverTorch primarily for researchers and practitioners who have some algorithmic understanding of the methods.

We had the following design goals in mind:

  • simple and consistent APIs for attacks and defenses;
  • concise reference implementations, utilizing the dynamic computational graphs in PyTorch;
  • fast execution with GPU-powered PyTorch implementations, which is important for “attacking-the-loop” algorithms, e.g. adversarial training.

Resources permitting, we are also working to make it more user-friendly in the future.

Attack-and-defense strategies

For gradient-based attacks, we have the fast gradient (sign) methods (Goodfellow et al., 2014), projected gradient descent methods (Madry et al., 2017), the Carlini-Wagner attack (Carlini and Wagner, 2017), and the spatial transformation attack (Xiao et al., 2018); the full list is available in advertorch.attacks. We also implemented the Jacobian saliency map attack (Papernot et al., 2016) and a few gradient-free attacks, including the single pixel attack and the local search attack (Narodytska and Kasiviswanathan, 2016).

Besides specific attacks, we also implemented a convenient wrapper for the Backward Pass Differentiable Approximation (BPDA) (Athalye et al., 2018), an attack technique that enhances gradient-based attacks against defended models that have non-differentiable or gradient-obfuscating components.
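
To make the idea concrete, here is a minimal sketch of the BPDA trick written with a plain torch.autograd.Function rather than AdverTorch’s wrapper API; quantize is a hypothetical gradient-obfuscating preprocessor used only for illustration.

import torch


class BPDAIdentity(torch.autograd.Function):
    """Approximate a non-differentiable preprocessor's gradient with the identity."""

    @staticmethod
    def forward(ctx, x, preprocess):
        # apply the real (possibly non-differentiable) preprocessor
        return preprocess(x)

    @staticmethod
    def backward(ctx, grad_output):
        # backward pass: pretend the preprocessor is the identity,
        # so gradients flow straight through to the input
        return grad_output, None


def quantize(x, levels=8):
    # hypothetical defense that destroys useful gradients
    return torch.round(x * (levels - 1)) / (levels - 1)


# usage sketch: a gradient-based attack can now back-propagate through the
# defended model, e.g. logits = model(BPDAIdentity.apply(x, quantize))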

In terms of defenses, we considered two strategies: i) preprocessing-based defenses and ii) robust training. For preprocessing-based defenses, we implemented the JPEG filter, bit squeezing, and different kinds of spatial smoothing filters.

For robust training methods, we implemented them as examples in our repo. So far, we have a script for adversarial training on MNIST, available in the BorealisAI GitHub repo, and we plan to add more examples with different methods on various datasets.

How to create an attack

We use the fast gradient sign attack as an example of how to create an attack in AdverTorch. GradientSignAttack can be found in advertorch.attacks.one_step_gradient.

To create an attack on a classifier, we’ll need Attack and LabelMixin from advertorch.attacks.base.

from advertorch.attacks.base import Attack
from advertorch.attacks.base import LabelMixin

Attack is the base class of all attacks in AdverTorch. It defines the API of an Attack. The core of it looks like this:

    def __init__(self, predict, loss_fn, clip_min, clip_max):
        self.predict = predict
        self.loss_fn = loss_fn
        self.clip_min = clip_min
        self.clip_max = clip_max

    def perturb(self, x, **kwargs):
        error = "Sub-classes must implement perturb."
        raise NotImplementedError(error)

    def __call__(self, *args, **kwargs):
        return self.perturb(*args, **kwargs)

An attack contains three core components:

  • predict: the function we want to attack;
  • loss_fn: the loss function we maximize during the attack;
  • perturb: the method that implements the attack algorithm.

Let’s illustrate these components with GradientSignAttack as an example.

# imports used by the excerpt below
import torch.nn as nn

from advertorch.utils import clamp


class GradientSignAttack(Attack, LabelMixin):
    """
    One step fast gradient sign method (Goodfellow et al, 2014).
    Paper: https://arxiv.org/abs/1412.6572

    :param predict: forward pass function.
    :param loss_fn: loss function.
    :param eps: attack step size.
    :param clip_min: minimum value per input dimension.
    :param clip_max: maximum value per input dimension.
    :param targeted: indicate if this is a targeted attack.
    """

    def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0.,
                 clip_max=1., targeted=False):
        """
        Create an instance of the GradientSignAttack.
        """
        super(GradientSignAttack, self).__init__(
            predict, loss_fn, clip_min, clip_max)

        self.eps = eps
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts with
        an attack length of eps.

        :param x: input tensor.
        :param y: label tensor.
                  - if None and self.targeted=False, compute y as predicted
                    labels.
                  - if self.targeted=True, then y must be the targeted labels.
        :return: tensor containing perturbed inputs.
        """

        x, y = self._verify_and_process_inputs(x, y)
        xadv = x.requires_grad_()
        
        ###############################
        # start: the attack algorithm #
        outputs = self.predict(xadv)
        loss = self.loss_fn(outputs, y)
        if self.targeted:
            loss = -loss
        loss.backward()
        grad_sign = xadv.grad.detach().sign()
        xadv = xadv + self.eps * grad_sign
        xadv = clamp(xadv, self.clip_min, self.clip_max)
        # end:   the attack algorithm # 
        ###############################

        return xadv

predict is the classifier, while loss_fn is the loss function used for the gradient calculation. The perturb method takes x and y as its arguments, where x is the input to be attacked and y is the true label of x. predict(x) contains the “logits” of the neural network. loss_fn could be the cross-entropy loss or any other suitable loss function that takes predict(x) and y as its arguments.
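
Putting this together, here is a minimal usage sketch. The small classifier and the random mini-batch below are stand-ins for illustration only; any PyTorch model that outputs logits would work.

import torch
import torch.nn as nn

from advertorch.attacks import GradientSignAttack

# a stand-in classifier for 28x28 grayscale inputs (illustration only)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

x = torch.rand(8, 1, 28, 28)        # a mini-batch of inputs in [0, 1]
y = torch.randint(0, 10, (8,))      # their labels

adversary = GradientSignAttack(model, eps=0.3, clip_min=0.0, clip_max=1.0)
x_adv = adversary.perturb(x, y)     # or simply adversary(x, y)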

Thanks to the dynamic computation graph nature of PyTorch, the actual attack algorithm can be implemented in a straightforward way with a few lines. For other types of attacks, we just need to replace the algorithm part of the code in perturb and change the parameters passed to __init__.
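
To illustrate this extension point, here is a hedged sketch of a hypothetical attack that simply adds uniform random noise instead of taking a gradient step. Only __init__ and the algorithm part of perturb differ from GradientSignAttack; clamp is assumed to be the same helper used in the excerpt above (from advertorch.utils).

import torch

from advertorch.attacks.base import Attack
from advertorch.attacks.base import LabelMixin
from advertorch.utils import clamp


class RandomNoiseAttack(Attack, LabelMixin):
    """Hypothetical example: perturb inputs with uniform noise in [-eps, eps]."""

    def __init__(self, predict, loss_fn=None, eps=0.3,
                 clip_min=0., clip_max=1., targeted=False):
        super(RandomNoiseAttack, self).__init__(
            predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.targeted = targeted

    def perturb(self, x, y=None):
        x, y = self._verify_and_process_inputs(x, y)
        # the "algorithm" part: no gradients needed for this toy attack
        noise = torch.empty_like(x).uniform_(-self.eps, self.eps)
        xadv = clamp(x + noise, self.clip_min, self.clip_max)
        return xadv

The resulting attack is then used exactly like GradientSignAttack above.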

Note that the decoupling of these three core components is flexible enough to allow more versatile attacks. In general, we require predict and loss_fn to be designed such that loss_fn always takes predict(x) and y as its inputs. As such, the perturb method needs no knowledge of the internals of predict and loss_fn. For example, FastFeatureAttack and PGDAttack share the same underlying perturb_iterative function, but differ in their predict and loss_fn. In FastFeatureAttack, predict(x) outputs the feature representation from a specific layer, y is the guide feature representation that we want predict(x) to match, and loss_fn becomes the mean squared error.

More generally, y can be any target of the adversarial perturbation, while predict(x) can output more complex data structures, as long as loss_fn can take them as its inputs. For example, we might want to generate one perturbation that fools both model A’s classification result and model B’s feature representation at the same time. In this case, we just need to make y and predict(x) tuples of labels and features, and modify loss_fn accordingly. There is no need to modify the original perturbation implementation.
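
Here is a hedged sketch of that scenario; model_a, model_b, and the guide feature are hypothetical stand-ins, and only the predict/loss_fn pairing is shown.

import torch
import torch.nn as nn

# hypothetical stand-ins: a classifier and a feature extractor
model_a = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model_b = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32))


def combined_predict(x):
    # the "prediction" becomes a tuple: model_a's logits and model_b's features
    return model_a(x), model_b(x)


def combined_loss_fn(outputs, targets):
    logits, feats = outputs
    y_true, feat_guide = targets
    # maximizing this loss fools model_a's classification while pulling
    # model_b's features towards the guide representation
    return (nn.CrossEntropyLoss(reduction="sum")(logits, y_true)
            - nn.MSELoss(reduction="sum")(feats, feat_guide))

These can then serve as the predict and loss_fn of an attack, with y passed as the tuple (y_true, feat_guide); the perturbation code itself stays unchanged.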

Setting up a new defense

As mentioned above, AdverTorch provides modules for preprocessing-based defense and examples for robust training.

We use MedianSmoothing2D as an example to illustrate how to define a preprocessing-based defense.

class MedianSmoothing2D(Processor):
    """
    Median Smoothing 2D.
    :param kernel_size: aperture linear size; must be odd and greater than 1.
    :param stride: stride of the convolution.
    """
    def __init__(self, kernel_size=3, stride=1):
        super(MedianSmoothing2D, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        padding = int(kernel_size) // 2
        if _is_even(kernel_size):
            # both ways of padding should be fine here
            # self.padding = (padding, 0, padding, 0)
            self.padding = (0, padding, 0, padding)
        else:
            self.padding = _quadruple(padding)

    def forward(self, x):
        x = F.pad(x, pad=self.padding, mode="reflect")
        x = x.unfold(2, self.kernel_size, self.stride)
        x = x.unfold(3, self.kernel_size, self.stride)
        x = x.contiguous().view(x.shape[:4] + (-1, )).median(dim=-1)[0]
        return x

The preprocessor is simply a torch.nn.Module: its __init__ function takes the necessary parameters, and its forward function implements the actual preprocessing algorithm. MedianSmoothing2D can then be composed with the original model to form a new, defended model:

median_filter = MedianSmoothing2D()
new_model = torch.nn.Sequential(median_filter, model)
y = new_model(x)

or it can be called separately as a preprocessing step:

processed_x = median_filter(x)
y = model(processed_x)
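
Because the composed model is just another nn.Module, it can also be handed to any of the attacks above. A minimal sketch, assuming MedianSmoothing2D is importable from advertorch.defenses and using a stand-in classifier for illustration:

import torch
import torch.nn as nn

from advertorch.attacks import GradientSignAttack
from advertorch.defenses import MedianSmoothing2D

# stand-in classifier and data (illustration only)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))

# compose the preprocessing defense with the model, then attack the
# defended model end-to-end
defended_model = nn.Sequential(MedianSmoothing2D(), model)
adversary = GradientSignAttack(defended_model, eps=0.3)
x_adv = adversary.perturb(x, y)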

We provide an example of how to use AdverTorch to do adversarial training (Madry et al., 2018) in tutorial_train_mnist.py. Compared to regular training, we only need two changes. The first is to initialize an adversary before training starts.

if flag_advtrain:
    from advertorch.attacks import LinfPGDAttack
    adversary = LinfPGDAttack(
        model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.3,
        nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0.0,
        clip_max=1.0, targeted=False)

The second is to generate the “adversarial mini-batch” and use it in place of the original clean mini-batch. The snippet below shows this for the evaluation loop, where the adversarial loss and accuracy are tracked; a simplified sketch of the corresponding training step follows.

if flag_advtrain:
    advdata = adversary.perturb(clndata, target)
    with torch.no_grad():
        output = model(advdata)
    test_advloss += F.cross_entropy(
        output, target, reduction='sum').item()
    pred = output.max(1, keepdim=True)[1]
    advcorrect += pred.eq(target.view_as(pred)).sum().item()
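
For the training step itself, the change follows the same pattern. The following is a simplified sketch (not the exact tutorial code), assuming model, optimizer, adversary, flag_advtrain and the mini-batch (data, target) are defined as in tutorial_train_mnist.py:

if flag_advtrain:
    # replace the clean mini-batch by its adversarial counterpart
    # (in practice it helps to disable parameter-gradient accumulation
    # and switch the model to eval mode while generating the attack)
    data = adversary.perturb(data, target)

optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()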

Since building the toolkit, we’ve already used it for two papers: i) On the Sensitivity of Adversarial Robustness to Input Data Distributions; and ii) MMA Training: Direct Input Space Margin Maximization through Adversarial Training. It’s our sincere hope that AdverTorch helps you in your research and that you find its components useful. Of course, we welcome contributions from the community and would love to hear your feedback. You can open an issue or a pull request on the AdverTorch repo, or email me at gavin.ding@borealisai.com.