Browse Source

Added FinePruning defence.

bart 11 months ago
parent
commit
2a7c9cd73c

+ 499 - 0
Notebooks/Defences/FinePruningFTT/FTtransformer/ft_transformer.py

@@ -0,0 +1,499 @@
+# %%
+import math
+import typing as ty
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as nn_init
+import zero
+from torch import Tensor
+
+from . import lib
+
+# %%
+class Tokenizer(nn.Module):
+    category_offsets: ty.Optional[Tensor]
+
+    def __init__(
+        self,
+        d_numerical: int,
+        categories: ty.Optional[ty.List[int]],
+        d_token: int,
+        bias: bool,
+    ) -> None:
+        super().__init__()
+        if categories is None:
+            d_bias = d_numerical
+            self.category_offsets = None
+            self.category_embeddings = None
+        else:
+            d_bias = d_numerical + len(categories)
+            category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
+            self.register_buffer('category_offsets', category_offsets)
+            self.category_embeddings = nn.Embedding(sum(categories), d_token)
+            nn_init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))
+            print(f'{self.category_embeddings.weight.shape=}')
+
+        # take [CLS] token into account
+        self.weight = nn.Parameter(Tensor(d_numerical + 1, d_token))
+        self.bias = nn.Parameter(Tensor(d_bias, d_token)) if bias else None
+        # The initialization is inspired by nn.Linear
+        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))
+
+    @property
+    def n_tokens(self) -> int:
+        return len(self.weight) + (
+            0 if self.category_offsets is None else len(self.category_offsets)
+        )
+
+    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
+        x_some = x_num if x_cat is None else x_cat
+        assert x_some is not None
+        x_num = torch.cat(
+            [torch.ones(len(x_some), 1, device=x_some.device)]  # [CLS]
+            + ([] if x_num is None else [x_num]),
+            dim=1,
+        )
+        x = self.weight[None] * x_num[:, :, None]
+        if x_cat is not None:
+            x = torch.cat(
+                [x, self.category_embeddings(x_cat + self.category_offsets[None])],
+                dim=1,
+            )
+        if self.bias is not None:
+            bias = torch.cat(
+                [
+                    torch.zeros(1, self.bias.shape[1], device=x.device),
+                    self.bias,
+                ]
+            )
+            x = x + bias[None]
+        return x
+
+
+class MultiheadAttention(nn.Module):
+    def __init__(
+        self, d: int, n_heads: int, dropout: float, initialization: str
+    ) -> None:
+        if n_heads > 1:
+            assert d % n_heads == 0
+        assert initialization in ['xavier', 'kaiming']
+
+        super().__init__()
+        self.W_q = nn.Linear(d, d)
+        self.W_k = nn.Linear(d, d)
+        self.W_v = nn.Linear(d, d)
+        self.W_out = nn.Linear(d, d) if n_heads > 1 else None
+        self.n_heads = n_heads
+        self.dropout = nn.Dropout(dropout) if dropout else None
+
+        for m in [self.W_q, self.W_k, self.W_v]:
+            if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
+                # gain is needed since W_qkv is represented with 3 separate layers
+                nn_init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
+            nn_init.zeros_(m.bias)
+        if self.W_out is not None:
+            nn_init.zeros_(self.W_out.bias)
+
+    def _reshape(self, x: Tensor) -> Tensor:
+        batch_size, n_tokens, d = x.shape
+        d_head = d // self.n_heads
+        return (
+            x.reshape(batch_size, n_tokens, self.n_heads, d_head)
+            .transpose(1, 2)
+            .reshape(batch_size * self.n_heads, n_tokens, d_head)
+        )
+
+    def forward(
+        self,
+        x_q: Tensor,
+        x_kv: Tensor,
+        key_compression: ty.Optional[nn.Linear],
+        value_compression: ty.Optional[nn.Linear],
+    ) -> Tensor:
+        q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
+        for tensor in [q, k, v]:
+            assert tensor.shape[-1] % self.n_heads == 0
+        if key_compression is not None:
+            assert value_compression is not None
+            k = key_compression(k.transpose(1, 2)).transpose(1, 2)
+            v = value_compression(v.transpose(1, 2)).transpose(1, 2)
+        else:
+            assert value_compression is None
+
+        batch_size = len(q)
+        d_head_key = k.shape[-1] // self.n_heads
+        d_head_value = v.shape[-1] // self.n_heads
+        n_q_tokens = q.shape[1]
+
+        q = self._reshape(q)
+        k = self._reshape(k)
+        attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1)
+        if self.dropout is not None:
+            attention = self.dropout(attention)
+        x = attention @ self._reshape(v)
+        x = (
+            x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
+            .transpose(1, 2)
+            .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
+        )
+        if self.W_out is not None:
+            x = self.W_out(x)
+        return x
+
+
+class Transformer(nn.Module):
+    """Transformer.
+
+    Tokenizes the input features with :class:`Tokenizer`, runs the tokens
+    through ``n_layers`` blocks (attention followed by a two-layer
+    feed-forward network, each wrapped in a residual connection) and feeds
+    the [CLS] token to a linear head with ``d_out`` outputs.
+
+    References:
+    - https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
+    - https://github.com/facebookresearch/pytext/tree/master/pytext/models/representations/transformer
+    - https://github.com/pytorch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/examples/linformer/linformer_src/modules/multihead_linear_attention.py#L19
+    """
+
+    def __init__(
+        self,
+        *,
+        # tokenizer
+        d_numerical: int,
+        categories: ty.Optional[ty.List[int]],
+        token_bias: bool,
+        # transformer
+        n_layers: int,
+        d_token: int,
+        n_heads: int,
+        d_ffn_factor: float,
+        attention_dropout: float,
+        ffn_dropout: float,
+        residual_dropout: float,
+        activation: str,
+        prenormalization: bool,
+        initialization: str,
+        # linformer
+        kv_compression: ty.Optional[float],
+        kv_compression_sharing: ty.Optional[str],
+        #
+        d_out: int,
+    ) -> None:
+        """Build the tokenizer, the stack of layers and the output head.
+
+        ``kv_compression`` and ``kv_compression_sharing`` must be both set or
+        both ``None``.  The sharing modes handled below are ``'layerwise'``
+        (one compression shared by all layers), ``'headwise'`` (separate key
+        and value compressions per layer) and ``'key-value'`` (one
+        compression per layer used for both keys and values).
+        """
+        # Exactly one of the two holds: compression is off and no sharing
+        # mode is given, or compression is on and a sharing mode is given.
+        assert (kv_compression is None) ^ (kv_compression_sharing is not None)
+
+        super().__init__()
+        self.tokenizer = Tokenizer(d_numerical, categories, d_token, token_bias)
+        n_tokens = self.tokenizer.n_tokens
+
+        def make_kv_compression():
+            """Create a bias-free linear map that shortens the token axis."""
+            assert kv_compression
+            compression = nn.Linear(
+                n_tokens, int(n_tokens * kv_compression), bias=False
+            )
+            if initialization == 'xavier':
+                nn_init.xavier_uniform_(compression.weight)
+            return compression
+
+        self.shared_kv_compression = (
+            make_kv_compression()
+            if kv_compression and kv_compression_sharing == 'layerwise'
+            else None
+        )
+
+        def make_normalization():
+            """Create the normalization layer used throughout the model."""
+            return nn.LayerNorm(d_token)
+
+        d_hidden = int(d_token * d_ffn_factor)
+        self.layers = nn.ModuleList([])
+        for layer_idx in range(n_layers):
+            layer = nn.ModuleDict(
+                {
+                    'attention': MultiheadAttention(
+                        d_token, n_heads, attention_dropout, initialization
+                    ),
+                    # GLU-style activations get a doubled hidden width here;
+                    # 'linear1' still takes d_hidden inputs.
+                    'linear0': nn.Linear(
+                        d_token, d_hidden * (2 if activation.endswith('glu') else 1)
+                    ),
+                    'linear1': nn.Linear(d_hidden, d_token),
+                    'norm1': make_normalization(),
+                }
+            )
+            # With pre-normalization the very first layer has no 'norm0'.
+            if not prenormalization or layer_idx:
+                layer['norm0'] = make_normalization()
+            if kv_compression and self.shared_kv_compression is None:
+                layer['key_compression'] = make_kv_compression()
+                if kv_compression_sharing == 'headwise':
+                    layer['value_compression'] = make_kv_compression()
+                else:
+                    assert kv_compression_sharing == 'key-value'
+            self.layers.append(layer)
+
+        self.activation = lib.get_activation_fn(activation)
+        self.last_activation = lib.get_nonglu_activation_fn(activation)
+        self.prenormalization = prenormalization
+        self.last_normalization = make_normalization() if prenormalization else None
+        self.ffn_dropout = ffn_dropout
+        self.residual_dropout = residual_dropout
+        self.head = nn.Linear(d_token, d_out)
+
+    def _get_kv_compressions(self, layer):
+        """Return the ``(key_compression, value_compression)`` pair for ``layer``.
+
+        Priority: the model-wide shared compression; then the layer's own
+        key and value compressions; then the layer's key compression used for
+        both; otherwise ``(None, None)``.
+        """
+        return (
+            (self.shared_kv_compression, self.shared_kv_compression)
+            if self.shared_kv_compression is not None
+            else (layer['key_compression'], layer['value_compression'])
+            if 'key_compression' in layer and 'value_compression' in layer
+            else (layer['key_compression'], layer['key_compression'])
+            if 'key_compression' in layer
+            else (None, None)
+        )
+
+    def _start_residual(self, x, layer, norm_idx):
+        """Return the input of a residual branch.
+
+        In pre-normalization mode the layer's ``norm{norm_idx}`` is applied
+        when the layer has it; otherwise ``x`` is returned unchanged.
+        """
+        x_residual = x
+        if self.prenormalization:
+            norm_key = f'norm{norm_idx}'
+            if norm_key in layer:
+                x_residual = layer[norm_key](x_residual)
+        return x_residual
+
+    def _end_residual(self, x, x_residual, layer, norm_idx):
+        """Add the (optionally dropped-out) branch output to ``x``.
+
+        In post-normalization mode the layer's ``norm{norm_idx}`` is applied
+        to the sum.
+        """
+        if self.residual_dropout:
+            x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
+        x = x + x_residual
+        if not self.prenormalization:
+            x = layer[f'norm{norm_idx}'](x)
+        return x
+
+    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
+        """Compute the model output for a batch.
+
+        Both inputs are passed straight to the tokenizer.  The result of the
+        head has its last dimension squeezed, so it is one value per sample
+        when ``d_out == 1``.
+        """
+        x = self.tokenizer(x_num, x_cat)
+
+        for layer_idx, layer in enumerate(self.layers):
+            is_last_layer = layer_idx + 1 == len(self.layers)
+            layer = ty.cast(ty.Dict[str, nn.Module], layer)
+
+            # Attention sub-block.
+            x_residual = self._start_residual(x, layer, 0)
+            x_residual = layer['attention'](
+                # for the last attention, it is enough to process only [CLS]
+                (x_residual[:, :1] if is_last_layer else x_residual),
+                x_residual,
+                *self._get_kv_compressions(layer),
+            )
+            if is_last_layer:
+                # Keep only as many tokens as the attention produced (just
+                # [CLS]) so the residual sum below has matching shapes.
+                x = x[:, : x_residual.shape[1]]
+            x = self._end_residual(x, x_residual, layer, 0)
+
+            # Feed-forward sub-block.
+            x_residual = self._start_residual(x, layer, 1)
+            x_residual = layer['linear0'](x_residual)
+            x_residual = self.activation(x_residual)
+            if self.ffn_dropout:
+                x_residual = F.dropout(x_residual, self.ffn_dropout, self.training)
+            x_residual = layer['linear1'](x_residual)
+            x = self._end_residual(x, x_residual, layer, 1)
+
+        # Only the [CLS] token is left after the last layer.
+        assert x.shape[1] == 1
+        x = x[:, 0]
+        if self.last_normalization is not None:
+            x = self.last_normalization(x)
+        x = self.last_activation(x)
+        x = self.head(x)
+        x = x.squeeze(-1)
+        return x
+
+
+class FTtransformer():
+    def __init__(
+        self,
+        config
+        ):
+        self.config = config
+
+    def fit(self, checkpoint_path):
+        config = self.config # quick dirty method
+
+        zero.set_randomness(config['seed'])
+        dataset_dir = config['data']['path']
+
+        D = lib.Dataset.from_dir(dataset_dir)
+        X = D.build_X(
+            normalization=config['data'].get('normalization'),
+            num_nan_policy='mean',
+            cat_nan_policy='new',
+            cat_policy=config['data'].get('cat_policy', 'indices'),
+            cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),
+            seed=config['seed'],
+        )
+        if not isinstance(X, tuple):
+            X = (X, None)
+
+        Y, y_info = D.build_y(config['data'].get('y_policy'))
+
+        X = tuple(None if x is None else lib.to_tensors(x) for x in X)
+        Y = lib.to_tensors(Y)
+        device = torch.device(config['training']['device'])
+        print("Using device:", config['training']['device'])
+        if device.type != 'cpu':
+            X = tuple(
+                None if x is None else {k: v.to(device) for k, v in x.items()} for x in X
+            )
+            Y_device = {k: v.to(device) for k, v in Y.items()}
+        else:
+            Y_device = Y
+        X_num, X_cat = X
+        del X
+        if not D.is_multiclass:
+            Y_device = {k: v.float() for k, v in Y_device.items()}
+
+        train_size = D.size(lib.TRAIN)
+        batch_size = config['training']['batch_size']
+        epoch_size = math.ceil(train_size / batch_size)
+        eval_batch_size = config['training']['eval_batch_size']
+        chunk_size = None
+
+        loss_fn = (
+            F.binary_cross_entropy_with_logits
+            if D.is_binclass
+            else F.cross_entropy
+            if D.is_multiclass
+            else F.mse_loss
+        )
+
+        model = Transformer(
+            d_numerical=0 if X_num is None else X_num['train'].shape[1],
+            categories=lib.get_categories(X_cat),
+            d_out=D.info['n_classes'] if D.is_multiclass else 1,
+            **config['model'],
+        ).to(device)
+
+        def needs_wd(name):
+            return all(x not in name for x in ['tokenizer', '.norm', '.bias'])
+
+        for x in ['tokenizer', '.norm', '.bias']:
+            assert any(x in a for a in (b[0] for b in model.named_parameters()))
+        parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]
+        parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]
+        optimizer = lib.make_optimizer(
+            config['training']['optimizer'],
+            (
+                [
+                    {'params': parameters_with_wd},
+                    {'params': parameters_without_wd, 'weight_decay': 0.0},
+                ]
+            ),
+            config['training']['lr'],
+            config['training']['weight_decay'],
+        )
+
+        stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))
+        progress = zero.ProgressTracker(config['training']['patience'])
+        training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}
+        timer = zero.Timer()
+        output = "Checkpoints"
+
+        def print_epoch_info():
+            print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')
+            print(
+                ' | '.join(
+                    f'{k} = {v}'
+                    for k, v in {
+                        'lr': lib.get_lr(optimizer),
+                        'batch_size': batch_size,
+                        'chunk_size': chunk_size,
+                    }.items()
+                )
+            )
+
+        def apply_model(part, idx):
+            return model(
+                None if X_num is None else X_num[part][idx],
+                None if X_cat is None else X_cat[part][idx],
+            )
+
+        @torch.no_grad()
+        def evaluate(parts):
+            eval_batch_size = self.config['training']['eval_batch_size']
+            model.eval()
+            metrics = {}
+            predictions = {}
+            for part in parts:
+                while eval_batch_size:
+                    try:
+                        predictions[part] = (
+                            torch.cat(
+                                [
+                                    apply_model(part, idx)
+                                    for idx in lib.IndexLoader(
+                                        D.size(part), eval_batch_size, False, device
+                                    )
+                                ]
+                            )
+                            .cpu()
+                            .numpy()
+                        )
+                    except RuntimeError as err:
+                        if not lib.is_oom_exception(err):
+                            raise
+                        eval_batch_size //= 2
+                        print('New eval batch size:', eval_batch_size)
+                    else:
+                        break
+                if not eval_batch_size:
+                    RuntimeError('Not enough memory even for eval_batch_size=1')
+                metrics[part] = lib.calculate_metrics(
+                    D.info['task_type'],
+                    Y[part].numpy(),  # type: ignore[code]
+                    predictions[part],  # type: ignore[code]
+                    'logits',
+                    y_info,
+                )
+            for part, part_metrics in metrics.items():
+                print(f'[{part:<5}]', lib.make_summary(part_metrics))
+            return metrics, predictions
+
+        def save_checkpoint(final):
+            torch.save(
+                {
+                    'model': model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                    'stream': stream.state_dict(),
+                    'random_state': zero.get_random_state(),
+                },
+                checkpoint_path,
+            )
+
+        zero.set_randomness(config['seed'])
+
+        for epoch in stream.epochs(config['training']['n_epochs']):
+            print(f'\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())}')
+            model.train()
+            epoch_losses = []
+            for batch_idx in epoch:
+                loss, new_chunk_size = lib.train_with_auto_virtual_batch(
+                    optimizer,
+                    loss_fn,
+                    lambda x: (apply_model(lib.TRAIN, x), Y_device[lib.TRAIN][x]),
+                    batch_idx,
+                    chunk_size or batch_size,
+                )
+                epoch_losses.append(loss.detach())
+                if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
+                    print('New chunk size:', chunk_size)
+            epoch_losses = torch.stack(epoch_losses).tolist()
+            print(f'[{lib.TRAIN}] loss = {round(sum(epoch_losses) / len(epoch_losses), 3)}')
+
+            metrics, predictions = evaluate([lib.VAL, lib.TEST])
+            for k, v in metrics.items():
+                training_log[k].append(v)
+            progress.update(metrics[lib.VAL]['score'])
+
+            if progress.success:
+                print('New best epoch!')
+                save_checkpoint(False)
+
+            elif progress.fail:
+                break
+
+        # Load best checkpoint
+        model.load_state_dict(torch.load(checkpoint_path)['model'])
+        metrics, predictions = evaluate(lib.PARTS)
+
+        return metrics

+ 4 - 0
Notebooks/Defences/FinePruningFTT/FTtransformer/lib/__init__.py

@@ -0,0 +1,4 @@
+from .data import *  # noqa
+from .deep import *  # noqa
+from .metrics import *  # noqa
+from .util import *  # noqa

+ 234 - 0
Notebooks/Defences/FinePruningFTT/FTtransformer/lib/data.py

@@ -0,0 +1,234 @@
+import dataclasses as dc
+import pickle
+import typing as ty
+import warnings
+from collections import Counter
+from copy import deepcopy
+from pathlib import Path
+
+import numpy as np
+import sklearn.preprocessing
+import torch
+from category_encoders import LeaveOneOutEncoder
+from sklearn.impute import SimpleImputer
+
+from . import util
+
+ArrayDict = ty.Dict[str, np.ndarray]
+
+
+def normalize(
+    X: ArrayDict, normalization: str, seed: int, noise: float = 1e-3
+) -> ArrayDict:
+    X_train = X['train'].copy()
+    if normalization == 'standard':
+        normalizer = sklearn.preprocessing.StandardScaler()
+    elif normalization == 'quantile':
+        normalizer = sklearn.preprocessing.QuantileTransformer(
+            output_distribution='normal',
+            n_quantiles=max(min(X['train'].shape[0] // 30, 1000), 10),
+            subsample=1e9,
+            random_state=seed,
+        )
+        if noise:
+            stds = np.std(X_train, axis=0, keepdims=True)
+            noise_std = noise / np.maximum(stds, noise)  # type: ignore[code]
+            X_train += noise_std * np.random.default_rng(seed).standard_normal(  # type: ignore[code]
+                X_train.shape
+            )
+    else:
+        util.raise_unknown('normalization', normalization)
+    normalizer.fit(X_train)
+    return {k: normalizer.transform(v) for k, v in X.items()}  # type: ignore[code]
+
+
+@dc.dataclass
+class Dataset:
+    N: ty.Optional[ArrayDict]
+    C: ty.Optional[ArrayDict]
+    y: ArrayDict
+    info: ty.Dict[str, ty.Any]
+    folder: ty.Optional[Path]
+
+    @classmethod
+    def from_dir(cls, dir_: ty.Union[Path, str]) -> 'Dataset':
+        dir_ = Path(dir_)
+
+        def load(item) -> ArrayDict:
+            return {
+                x: ty.cast(np.ndarray, np.load(dir_ / f'{item}_{x}.npy', allow_pickle=True))  # type: ignore[code]
+                for x in ['train', 'val', 'test', 'test_backdoor']
+            }
+
+        return Dataset(
+            load('N') if dir_.joinpath('N_train.npy').exists() else None,
+            load('C') if dir_.joinpath('C_train.npy').exists() else None,
+            load('y'),
+            util.load_json(dir_ / 'info.json'),
+            dir_,
+        )
+
+    @property
+    def is_binclass(self) -> bool:
+        return self.info['task_type'] == util.BINCLASS
+
+    @property
+    def is_multiclass(self) -> bool:
+        return self.info['task_type'] == util.MULTICLASS
+
+    @property
+    def is_regression(self) -> bool:
+        return self.info['task_type'] == util.REGRESSION
+
+    @property
+    def n_num_features(self) -> int:
+        return self.info['n_num_features']
+
+    @property
+    def n_cat_features(self) -> int:
+        return self.info['n_cat_features']
+
+    @property
+    def n_features(self) -> int:
+        return self.n_num_features + self.n_cat_features
+
+    def size(self, part: str) -> int:
+        X = self.N if self.N is not None else self.C
+        assert X is not None
+        return len(X[part])
+
+    def build_X(
+        self,
+        *,
+        normalization: ty.Optional[str],
+        num_nan_policy: str,
+        cat_nan_policy: str,
+        cat_policy: str,
+        cat_min_frequency: float = 0.0,
+        seed: int,
+    ) -> ty.Union[ArrayDict, ty.Tuple[ArrayDict, ArrayDict]]:
+        if self.N:
+            N = deepcopy(self.N)
+
+            num_nan_masks = {k: np.isnan(v) for k, v in N.items()}
+            if any(x.any() for x in num_nan_masks.values()):  # type: ignore[code]
+                if num_nan_policy == 'mean':
+                    num_new_values = np.nanmean(self.N['train'], axis=0)
+                else:
+                    util.raise_unknown('numerical NaN policy', num_nan_policy)
+                for k, v in N.items():
+                    num_nan_indices = np.where(num_nan_masks[k])
+                    v[num_nan_indices] = np.take(num_new_values, num_nan_indices[1])
+            if normalization:
+                N = normalize(N, normalization, seed)
+
+        else:
+            N = None
+
+        if cat_policy == 'drop' or not self.C:
+            assert N is not None
+            return N
+
+        C = deepcopy(self.C)
+
+        cat_nan_masks = {k: v == 'nan' for k, v in C.items()}
+        if any(x.any() for x in cat_nan_masks.values()):  # type: ignore[code]
+            if cat_nan_policy == 'new':
+                cat_new_value = '___null___'
+                imputer = None
+            elif cat_nan_policy == 'most_frequent':
+                cat_new_value = None
+                imputer = SimpleImputer(strategy=cat_nan_policy)  # type: ignore[code]
+                imputer.fit(C['train'])
+            else:
+                util.raise_unknown('categorical NaN policy', cat_nan_policy)
+            if imputer:
+                C = {k: imputer.transform(v) for k, v in C.items()}
+            else:
+                for k, v in C.items():
+                    cat_nan_indices = np.where(cat_nan_masks[k])
+                    v[cat_nan_indices] = cat_new_value
+
+        if cat_min_frequency:
+            C = ty.cast(ArrayDict, C)
+            min_count = round(len(C['train']) * cat_min_frequency)
+            rare_value = '___rare___'
+            C_new = {x: [] for x in C}
+            for column_idx in range(C['train'].shape[1]):
+                counter = Counter(C['train'][:, column_idx].tolist())
+                popular_categories = {k for k, v in counter.items() if v >= min_count}
+                for part in C_new:
+                    C_new[part].append(
+                        [
+                            (x if x in popular_categories else rare_value)
+                            for x in C[part][:, column_idx].tolist()
+                        ]
+                    )
+            C = {k: np.array(v).T for k, v in C_new.items()}
+
+        unknown_value = np.iinfo('int64').max - 3
+        encoder = sklearn.preprocessing.OrdinalEncoder(
+            handle_unknown='use_encoded_value',  # type: ignore[code]
+            unknown_value=unknown_value,  # type: ignore[code]
+            dtype='int64',  # type: ignore[code]
+        ).fit(C['train'])
+        C = {k: encoder.transform(v) for k, v in C.items()}
+        max_values = C['train'].max(axis=0)
+        for part in ['val', 'test', 'test_backdoor']:
+            for column_idx in range(C[part].shape[1]):
+                C[part][C[part][:, column_idx] == unknown_value, column_idx] = (
+                    max_values[column_idx] + 1
+                )
+
+        if cat_policy == 'indices':
+            result = (N, C)
+        elif cat_policy == 'ohe':
+            ohe = sklearn.preprocessing.OneHotEncoder(
+                handle_unknown='ignore', sparse=False, dtype='float32'  # type: ignore[code]
+            )
+            ohe.fit(C['train'])
+            C = {k: ohe.transform(v) for k, v in C.items()}
+            result = C if N is None else {x: np.hstack((N[x], C[x])) for x in N}
+        elif cat_policy == 'counter':
+            assert seed is not None
+            loo = LeaveOneOutEncoder(sigma=0.1, random_state=seed, return_df=False)
+            loo.fit(C['train'], self.y['train'])
+            C = {k: loo.transform(v).astype('float32') for k, v in C.items()}  # type: ignore[code]
+            if not isinstance(C['train'], np.ndarray):
+                C = {k: v.values for k, v in C.items()}  # type: ignore[code]
+            if normalization:
+                C = normalize(C, normalization, seed, inplace=True)  # type: ignore[code]
+            result = C if N is None else {x: np.hstack((N[x], C[x])) for x in N}
+        else:
+            util.raise_unknown('categorical policy', cat_policy)
+        return result  # type: ignore[code]
+
+    def build_y(
+        self, policy: ty.Optional[str]
+    ) -> ty.Tuple[ArrayDict, ty.Optional[ty.Dict[str, ty.Any]]]:
+        if self.is_regression:
+            assert policy == 'mean_std'
+        y = deepcopy(self.y)
+        if policy:
+            if not self.is_regression:
+                warnings.warn('y_policy is not None, but the task is NOT regression')
+                info = None
+            elif policy == 'mean_std':
+                mean, std = self.y['train'].mean(), self.y['train'].std()
+                y = {k: (v - mean) / std for k, v in y.items()}
+                info = {'policy': policy, 'mean': mean, 'std': std}
+            else:
+                util.raise_unknown('y policy', policy)
+        else:
+            info = None
+        return y, info
+
+
+def to_tensors(data: ArrayDict) -> ty.Dict[str, torch.Tensor]:
+    return {k: torch.as_tensor(v) for k, v in data.items()}
+
+
+def load_dataset_info(dataset_name: str) -> ty.Dict[str, ty.Any]:
+    info = util.load_json(env.DATA_DIR / dataset_name / 'info.json')
+    info['size'] = info['train_size'] + info['val_size'] + info['test_size'] + info['test_backdoor_size']
+    return info

+ 838 - 0
Notebooks/Defences/FinePruningFTT/FTtransformer/lib/deep.py

@@ -0,0 +1,838 @@
+from __future__ import absolute_import, division, print_function
+
+import math
+import os
+import typing as ty
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import zero
+from torch import Tensor
+
+
+class IndexLoader:
+    def __init__(
+        self, train_size: int, batch_size: int, shuffle: bool, device: torch.device
+    ) -> None:
+        self._train_size = train_size
+        self._batch_size = batch_size
+        self._shuffle = shuffle
+        self._device = device
+
+    def __len__(self) -> int:
+        return math.ceil(self._train_size / self._batch_size)
+
+    def __iter__(self):
+        indices = list(
+            zero.iloader(self._train_size, self._batch_size, shuffle=self._shuffle)
+        )
+        return iter(torch.cat(indices).to(self._device).split(self._batch_size))
+
+
+class Lambda(nn.Module):
+    def __init__(self, f: ty.Callable) -> None:
+        super().__init__()
+        self.f = f
+
+    def forward(self, x):
+        return self.f(x)
+
+
+# Source: https://github.com/bzhangGo/rmsnorm
+# NOTE: eps is changed to 1e-5
+class RMSNorm(nn.Module):
+    def __init__(self, d, p=-1.0, eps=1e-5, bias=False):
+        """Root Mean Square Layer Normalization
+
+        :param d: model size
+        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
+        :param eps:  epsilon value, default 1e-8
+        :param bias: whether use bias term for RMSNorm, disabled by
+            default because RMSNorm doesn't enforce re-centering invariance.
+        """
+        super(RMSNorm, self).__init__()
+
+        self.eps = eps
+        self.d = d
+        self.p = p
+        self.bias = bias
+
+        self.scale = nn.Parameter(torch.ones(d))
+        self.register_parameter("scale", self.scale)
+
+        if self.bias:
+            self.offset = nn.Parameter(torch.zeros(d))
+            self.register_parameter("offset", self.offset)
+
+    def forward(self, x):
+        if self.p < 0.0 or self.p > 1.0:
+            norm_x = x.norm(2, dim=-1, keepdim=True)
+            d_x = self.d
+        else:
+            partial_size = int(self.d * self.p)
+            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
+
+            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
+            d_x = partial_size
+
+        rms_x = norm_x * d_x ** (-1.0 / 2)
+        x_normed = x / (rms_x + self.eps)
+
+        if self.bias:
+            return self.scale * x_normed + self.offset
+
+        return self.scale * x_normed
+
+
+class ScaleNorm(nn.Module):
+    """
+    Sources:
+    - https://github.com/tnq177/transformers_without_tears/blob/25026061979916afb193274438f7097945acf9bc/layers.py#L132
+    - https://github.com/tnq177/transformers_without_tears/blob/6b2726cd9e6e642d976ae73b9f696d9d7ff4b395/layers.py#L157
+    """
+
+    def __init__(self, d: int, eps: float = 1e-5, clamp: bool = False) -> None:
+        super(ScaleNorm, self).__init__()
+        self.scale = nn.Parameter(torch.tensor(d ** 0.5))
+        self.eps = eps
+        self.clamp = clamp
+
+    def forward(self, x):
+        norms = torch.norm(x, dim=-1, keepdim=True)
+        norms = norms.clamp(min=self.eps) if self.clamp else norms + self.eps
+        return self.scale * x / norms
+
+
+def reglu(x: Tensor) -> Tensor:
+    a, b = x.chunk(2, dim=-1)
+    return a * F.relu(b)
+
+
+def geglu(x: Tensor) -> Tensor:
+    a, b = x.chunk(2, dim=-1)
+    return a * F.gelu(b)
+
+
+class ReGLU(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return reglu(x)
+
+
+class GEGLU(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return geglu(x)
+
+
+def make_optimizer(
+    optimizer: str,
+    parameter_groups,
+    lr: float,
+    weight_decay: float,
+) -> optim.Optimizer:
+    Optimizer = {
+        'adabelief': AdaBelief,
+        'adam': optim.Adam,
+        'adamw': optim.AdamW,
+        'radam': RAdam,
+        'sgd': optim.SGD,
+    }[optimizer]
+    momentum = (0.9,) if Optimizer is optim.SGD else ()
+    return Optimizer(parameter_groups, lr, *momentum, weight_decay=weight_decay)
+
+
+def make_lr_schedule(
+    optimizer: optim.Optimizer,
+    lr: float,
+    epoch_size: int,
+    lr_schedule: ty.Optional[ty.Dict[str, ty.Any]],
+) -> ty.Tuple[
+    ty.Optional[optim.lr_scheduler._LRScheduler],
+    ty.Dict[str, ty.Any],
+    ty.Optional[int],
+]:
+    if lr_schedule is None:
+        lr_schedule = {'type': 'constant'}
+    lr_scheduler = None
+    n_warmup_steps = None
+    if lr_schedule['type'] in ['transformer', 'linear_warmup']:
+        n_warmup_steps = (
+            lr_schedule['n_warmup_steps']
+            if 'n_warmup_steps' in lr_schedule
+            else lr_schedule['n_warmup_epochs'] * epoch_size
+        )
+    elif lr_schedule['type'] == 'cyclic':
+        lr_scheduler = optim.lr_scheduler.CyclicLR(
+            optimizer,
+            base_lr=lr,
+            max_lr=lr_schedule['max_lr'],
+            step_size_up=lr_schedule['n_epochs_up'] * epoch_size,
+            step_size_down=lr_schedule['n_epochs_down'] * epoch_size,
+            mode=lr_schedule['mode'],
+            gamma=lr_schedule.get('gamma', 1.0),
+            cycle_momentum=False,
+        )
+    return lr_scheduler, lr_schedule, n_warmup_steps
+
+
+def get_activation_fn(name: str) -> ty.Callable[[Tensor], Tensor]:
+    return (
+        reglu
+        if name == 'reglu'
+        else geglu
+        if name == 'geglu'
+        else torch.sigmoid
+        if name == 'sigmoid'
+        else getattr(F, name)
+    )
+
+
+def get_nonglu_activation_fn(name: str) -> ty.Callable[[Tensor], Tensor]:
+    return (
+        F.relu
+        if name == 'reglu'
+        else F.gelu
+        if name == 'geglu'
+        else get_activation_fn(name)
+    )
+
+
+def load_swa_state_dict(model: nn.Module, swa_model: optim.swa_utils.AveragedModel):
+    state_dict = deepcopy(swa_model.state_dict())
+    del state_dict['n_averaged']
+    model.load_state_dict({k[len('module.') :]: v for k, v in state_dict.items()})
+
+
+def get_epoch_parameters(
+    train_size: int, batch_size: ty.Union[int, str]
+) -> ty.Tuple[int, int]:
+    if isinstance(batch_size, str):
+        if batch_size == 'v3':
+            batch_size = (
+                256 if train_size < 50000 else 512 if train_size < 100000 else 1024
+            )
+        elif batch_size == 'v1':
+            batch_size = (
+                16
+                if train_size < 1000
+                else 32
+                if train_size < 10000
+                else 64
+                if train_size < 50000
+                else 128
+                if train_size < 100000
+                else 256
+                if train_size < 200000
+                else 512
+                if train_size < 500000
+                else 1024
+            )
+        elif batch_size == 'v2':
+            batch_size = (
+                512 if train_size < 100000 else 1024 if train_size < 500000 else 2048
+            )
+    return batch_size, math.ceil(train_size / batch_size)  # type: ignore[code]
+
+
+def get_linear_warmup_lr(lr: float, n_warmup_steps: int, step: int) -> float:
+    assert step > 0, "1-based enumeration of steps is expected"
+    return min(lr, step / (n_warmup_steps + 1) * lr)
+
+
+def get_manual_lr(schedule: ty.List[float], epoch: int) -> float:
+    assert epoch > 0, "1-based enumeration of epochs is expected"
+    return schedule[min(epoch, len(schedule)) - 1]
+
+
+def get_transformer_lr(scale: float, d: int, n_warmup_steps: int, step: int) -> float:
+    return scale * d ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)
+
+
+def learn(model, optimizer, loss_fn, step, batch, star) -> ty.Tuple[Tensor, ty.Any]:
+    model.train()
+    optimizer.zero_grad()
+    out = step(batch)
+    loss = loss_fn(*out) if star else loss_fn(out)
+    loss.backward()
+    optimizer.step()
+    return loss, out
+
+
+def _learn_with_virtual_batch(
+    model, optimizer, loss_fn, step, batch, chunk_size
+) -> Tensor:
+    """Run one optimization step, processing ``batch`` in chunks of ``chunk_size``.
+
+    Gradients are accumulated over the chunks and then rescaled, so a single
+    ``optimizer.step()`` is taken for the whole batch. Returns the chunk-size
+    weighted average of the chunk losses (detached), or the plain loss from
+    ``learn`` when the batch fits in one chunk.
+    """
+    batch_size = len(batch)
+    if chunk_size >= batch_size:
+        # The whole batch fits in one chunk: regular step.
+        return learn(model, optimizer, loss_fn, step, batch, True)[0]
+    model.train()
+    optimizer.zero_grad()
+    total_loss = None
+    for chunk in zero.iter_batches(batch, chunk_size):
+        loss = loss_fn(*step(chunk))
+        # Weight the chunk loss by the chunk length before accumulating gradients.
+        # presumably loss_fn returns a per-object mean -- TODO confirm
+        loss = loss * len(chunk)
+        loss.backward()
+        if total_loss is None:
+            total_loss = loss.detach()
+        else:
+            total_loss += loss.detach()
+    # Divide the accumulated gradients by the full batch size.
+    for x in model.parameters():
+        if x.grad is not None:
+            x.grad /= batch_size
+    optimizer.step()
+    return total_loss / batch_size
+
+
+def learn_with_auto_virtual_batch(
+    model,
+    optimizer,
+    loss_fn,
+    step,
+    batch,
+    batch_size_hint: int,
+    chunk_size: ty.Optional[int],
+) -> ty.Tuple[Tensor, ty.Optional[int]]:
+    """This is just an overcomplicated version of `train_with_auto_virtual_batch`.
+
+    Runs one optimization step and halves the chunk size after every
+    out-of-memory error until the step succeeds.
+
+    :param batch_size_hint: chunk size to use while ``chunk_size`` is ``None``
+    :param chunk_size: current chunk size, or ``None`` if none has been chosen yet
+    :return: ``(loss, chunk_size)`` where ``chunk_size`` is the value used by the
+        successful attempt (still ``None`` if the first attempt succeeded with
+        ``chunk_size=None``)
+    :raises RuntimeError: if the chunk size reaches zero, or for any error that
+        ``is_oom_exception`` does not recognize (re-raised as is)
+    """
+    # The random state is restored before every attempt so retries are reproducible.
+    random_state = zero.get_random_state()
+    while chunk_size != 0:
+        try:
+            zero.set_random_state(random_state)
+            return (
+                _learn_with_virtual_batch(
+                    model,
+                    optimizer,
+                    loss_fn,
+                    step,
+                    batch,
+                    chunk_size or batch_size_hint,
+                ),
+                chunk_size,
+            )
+        except RuntimeError as err:
+            if not is_oom_exception(err):
+                raise
+            # First OOM with no explicit chunk size: start halving from the hint.
+            if chunk_size is None:
+                chunk_size = batch_size_hint
+            chunk_size //= 2
+    raise RuntimeError('Not enough memory even for batch_size=1')
+
+
+def train_with_auto_virtual_batch(
+    optimizer,
+    loss_fn,
+    step,
+    batch,
+    chunk_size: int,
+) -> ty.Tuple[Tensor, int]:
+    """Run one optimization step, halving ``chunk_size`` after every OOM error.
+
+    :return: ``(loss, chunk_size)`` with the chunk size that finally worked. When
+        the batch was processed in chunks, ``loss`` is the sum of the detached
+        chunk losses, each weighted by ``len(chunk) / batch_size``; otherwise it
+        is the loss tensor returned by ``loss_fn`` (not detached).
+    :raises RuntimeError: if the chunk size reaches zero, or for any error that
+        ``is_oom_exception`` does not recognize (re-raised as is)
+    """
+    batch_size = len(batch)
+    # The random state is restored before every attempt so retries are reproducible.
+    random_state = zero.get_random_state()
+    while chunk_size != 0:
+        try:
+            zero.set_random_state(random_state)
+            # Also clears gradients left over from a failed (OOM) attempt.
+            optimizer.zero_grad()
+            if batch_size <= chunk_size:
+                loss = loss_fn(*step(batch))
+                loss.backward()
+            else:
+                loss = None
+                for chunk in zero.iter_batches(batch, chunk_size):
+                    chunk_loss = loss_fn(*step(chunk))
+                    # Weight by the chunk's share of the batch before backward.
+                    chunk_loss = chunk_loss * (len(chunk) / batch_size)
+                    chunk_loss.backward()
+                    if loss is None:
+                        loss = chunk_loss.detach()
+                    else:
+                        loss += chunk_loss.detach()
+        except RuntimeError as err:
+            if not is_oom_exception(err):
+                raise
+            chunk_size //= 2
+        else:
+            break
+    if not chunk_size:
+        raise RuntimeError('Not enough memory even for batch_size=1')
+    optimizer.step()
+    return loss, chunk_size  # type: ignore[code]
+
+
+def tensor(x) -> torch.Tensor:
+    assert isinstance(x, torch.Tensor)
+    return ty.cast(torch.Tensor, x)
+
+
+def get_n_parameters(m: nn.Module):
+    return sum(x.numel() for x in m.parameters() if x.requires_grad)
+
+
+def get_mlp_n_parameters(units: ty.List[int]):
+    x = 0
+    for a, b in zip(units, units[1:]):
+        x += a * b + b
+    return x
+
+
+def get_lr(optimizer: optim.Optimizer) -> float:
+    return next(iter(optimizer.param_groups))['lr']
+
+
+def set_lr(optimizer: optim.Optimizer, lr: float) -> None:
+    for x in optimizer.param_groups:
+        x['lr'] = lr
+
+
+def get_device() -> torch.device:
+    return torch.device('cuda:0' if os.environ.get('CUDA_VISIBLE_DEVICES') else 'cpu')
+
+
+@torch.no_grad()
+def get_gradient_norm_ratios(m: nn.Module):
+    return {
+        k: v.grad.norm() / v.norm()
+        for k, v in m.named_parameters()
+        if v.grad is not None
+    }
+
+
+def is_oom_exception(err: RuntimeError) -> bool:
+    return any(
+        x in str(err)
+        for x in [
+            'CUDA out of memory',
+            'CUBLAS_STATUS_ALLOC_FAILED',
+            'CUDA error: out of memory',
+        ]
+    )
+
+
+# Source: https://github.com/LiyuanLucasLiu/RAdam
+class RAdam(optim.Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        degenerated_to_sgd=True,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        self.degenerated_to_sgd = degenerated_to_sgd
+        if (
+            isinstance(params, (list, tuple))
+            and len(params) > 0
+            and isinstance(params[0], dict)
+        ):
+            for param in params:
+                if 'betas' in param and (
+                    param['betas'][0] != betas[0] or param['betas'][1] != betas[1]
+                ):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            buffer=[[None, None, None] for _ in range(10)],
+        )
+        super(RAdam, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(RAdam, self).__setstate__(state)
+
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('RAdam does not support sparse gradients')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                state['step'] += 1
+                buffered = group['buffer'][int(state['step'] % 10)]
+                if state['step'] == buffered[0]:
+                    N_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    N_sma_max = 2 / (1 - beta2) - 1
+                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    buffered[1] = N_sma
+
+                    # more conservative since it's an approximated value
+                    if N_sma >= 5:
+                        step_size = math.sqrt(
+                            (1 - beta2_t)
+                            * (N_sma - 4)
+                            / (N_sma_max - 4)
+                            * (N_sma - 2)
+                            / N_sma
+                            * N_sma_max
+                            / (N_sma_max - 2)
+                        ) / (1 - beta1 ** state['step'])
+                    elif self.degenerated_to_sgd:
+                        step_size = 1.0 / (1 - beta1 ** state['step'])
+                    else:
+                        step_size = -1
+                    buffered[2] = step_size
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(
+                            -group['weight_decay'] * group['lr'], p_data_fp32
+                        )
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+                    p.data.copy_(p_data_fp32)
+                elif step_size > 0:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(
+                            -group['weight_decay'] * group['lr'], p_data_fp32
+                        )
+                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+                    p.data.copy_(p_data_fp32)
+
+        return loss
+
+
+# Flag read by AdaBelief to decide whether ``memory_format`` is passed to
+# ``torch.zeros_like``.
+# NOTE(review): if ``torch.__version__`` is a plain ``str`` this comparison is
+# lexicographic (e.g. "1.10.0" < "1.5.0") -- confirm for the torch versions in use.
+version_higher = torch.__version__ >= "1.5.0"
+
+
+# Source: https://github.com/juntang-zhuang/Adabelief-Optimizer
+class AdaBelief(optim.Optimizer):
+    r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-16)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+        weight_decouple (boolean, optional): ( default: True) If set as True, then
+            the optimizer uses decoupled weight decay as in AdamW
+        fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
+            is set as True.
+            When fixed_decay == True, the weight decay is performed as
+            $W_{new} = W_{old} - W_{old} \times decay$.
+            When fixed_decay == False, the weight decay is performed as
+            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
+            weight decay ratio decreases with learning rate (lr).
+        rectify (boolean, optional): (default: True) If set as True, then perform the rectified
+            update similar to RAdam
+        degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update
+            when variance of gradient is high
+        print_change_log (boolean, optional) (default: True) If set as True, print the modification to
+            default hyper-parameters
+    reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-16,
+        weight_decay=0,
+        amsgrad=False,
+        weight_decouple=True,
+        fixed_decay=False,
+        rectify=True,
+        degenerated_to_sgd=True,
+        print_change_log=True,
+    ):
+
+        # ------------------------------------------------------------------------------
+        # Print modifications to default arguments
+        if print_change_log:
+            print(
+                'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.'
+            )
+            print('Modifications to default arguments:')
+            # First row holds the column names; the other rows start with a version label.
+            default_table = [
+                ['eps', 'weight_decouple', 'rectify'],
+                ['adabelief-pytorch=0.0.5', '1e-8', 'False', 'False'],
+                ['>=0.1.0 (Current 0.2.0)', '1e-16', 'True', 'True'],
+            ]
+            print(default_table)
+
+            recommend_table = [
+                [
+                    'SGD better than Adam (e.g. CNN for Image Classification)',
+                    'Adam better than SGD (e.g. Transformer, GAN)',
+                ],
+                ['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
+            ]
+            print(recommend_table)
+
+            print('For a complete table of recommended hyperparameters, see')
+            print('https://github.com/juntang-zhuang/Adabelief-Optimizer')
+
+            print(
+                'You can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.'
+            )
+        # ------------------------------------------------------------------------------
+
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        self.degenerated_to_sgd = degenerated_to_sgd
+        if (
+            isinstance(params, (list, tuple))
+            and len(params) > 0
+            and isinstance(params[0], dict)
+        ):
+            # Groups with their own betas get a private step-size cache, because
+            # the cached values depend on the betas.
+            for param in params:
+                if 'betas' in param and (
+                    param['betas'][0] != betas[0] or param['betas'][1] != betas[1]
+                ):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+            # Ten-slot cache of [step, N_sma, step_size], indexed by step % 10.
+            buffer=[[None, None, None] for _ in range(10)],
+        )
+        super(AdaBelief, self).__init__(params, defaults)
+
+        # These switches are stored on the optimizer itself, not in the
+        # per-group defaults. (degenerated_to_sgd was already assigned above.)
+        self.degenerated_to_sgd = degenerated_to_sgd
+        self.weight_decouple = weight_decouple
+        self.rectify = rectify
+        self.fixed_decay = fixed_decay
+        if self.weight_decouple:
+            print('Weight decoupling enabled in AdaBelief')
+            if self.fixed_decay:
+                print('Weight decay fixed')
+        if self.rectify:
+            print('Rectification enabled in AdaBelief')
+        if amsgrad:
+            print('AMSGrad enabled in AdaBelief')
+
+    def __setstate__(self, state):
+        super(AdaBelief, self).__setstate__(state)
+        # Restored groups that lack the 'amsgrad' key default to False.
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    def reset(self):
+        """Reinitialize the step counter and the running averages of every parameter."""
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                amsgrad = group['amsgrad']
+
+                # State initialization
+                state['step'] = 0
+                # Exponential moving average of gradient values
+                state['exp_avg'] = (
+                    torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    if version_higher
+                    else torch.zeros_like(p.data)
+                )
+
+                # Exponential moving average of squared gradient values
+                state['exp_avg_var'] = (
+                    torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    if version_higher
+                    else torch.zeros_like(p.data)
+                )
+
+                if amsgrad:
+                    # Maintains max of all exp. moving avg. of sq. grad. values
+                    state['max_exp_avg_var'] = (
+                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                        if version_higher
+                        else torch.zeros_like(p.data)
+                    )
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                # cast data type
+                # (fp16 weights and gradients are updated in fp32 and cast back at the end)
+                half_precision = False
+                if p.data.dtype == torch.float16:
+                    half_precision = True
+                    p.data = p.data.float()
+                    p.grad = p.grad.float()
+
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        'AdaBelief does not support sparse gradients, please consider SparseAdam instead'
+                    )
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]
+
+                beta1, beta2 = group['betas']
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = (
+                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                        if version_higher
+                        else torch.zeros_like(p.data)
+                    )
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_var'] = (
+                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                        if version_higher
+                        else torch.zeros_like(p.data)
+                    )
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_var'] = (
+                            torch.zeros_like(
+                                p.data, memory_format=torch.preserve_format
+                            )
+                            if version_higher
+                            else torch.zeros_like(p.data)
+                        )
+
+                # perform weight decay, check if decoupled weight decay
+                if self.weight_decouple:
+                    # Decoupled decay shrinks the weights directly.
+                    if not self.fixed_decay:
+                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                    else:
+                        p.data.mul_(1.0 - group['weight_decay'])
+                else:
+                    # L2 penalty: the decay term is added to the gradient (in place).
+                    if group['weight_decay'] != 0:
+                        grad.add_(p.data, alpha=group['weight_decay'])
+
+                # get current state variable
+                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
+
+                state['step'] += 1
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+
+                # Update first and second moment running average
+                # (the second moment tracks the squared deviation of the gradient from exp_avg)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                grad_residual = grad - exp_avg
+                exp_avg_var.mul_(beta2).addcmul_(
+                    grad_residual, grad_residual, value=1 - beta2
+                )
+
+                if amsgrad:
+                    max_exp_avg_var = state['max_exp_avg_var']
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    # (add_ is in place, so eps is also accumulated into the stored exp_avg_var)
+                    torch.max(
+                        max_exp_avg_var,
+                        exp_avg_var.add_(group['eps']),
+                        out=max_exp_avg_var,
+                    )
+
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(
+                        group['eps']
+                    )
+                else:
+                    # add_ is in place: eps is accumulated into the stored exp_avg_var.
+                    denom = (
+                        exp_avg_var.add_(group['eps']).sqrt()
+                        / math.sqrt(bias_correction2)
+                    ).add_(group['eps'])
+
+                # update
+                if not self.rectify:
+                    # Default update
+                    step_size = group['lr'] / bias_correction1
+                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+                else:  # Rectified update, forked from RAdam
+                    buffered = group['buffer'][int(state['step'] % 10)]
+                    if state['step'] == buffered[0]:
+                        # Same step already computed for another parameter of this group.
+                        N_sma, step_size = buffered[1], buffered[2]
+                    else:
+                        buffered[0] = state['step']
+                        beta2_t = beta2 ** state['step']
+                        N_sma_max = 2 / (1 - beta2) - 1
+                        N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                        buffered[1] = N_sma
+
+                        # more conservative since it's an approximated value
+                        if N_sma >= 5:
+                            step_size = math.sqrt(
+                                (1 - beta2_t)
+                                * (N_sma - 4)
+                                / (N_sma_max - 4)
+                                * (N_sma - 2)
+                                / N_sma
+                                * N_sma_max
+                                / (N_sma_max - 2)
+                            ) / (1 - beta1 ** state['step'])
+                        elif self.degenerated_to_sgd:
+                            step_size = 1.0 / (1 - beta1 ** state['step'])
+                        else:
+                            # Negative step size marks "skip the update" below.
+                            step_size = -1
+                        buffered[2] = step_size
+
+                    if N_sma >= 5:
+                        # The denominator is recomputed here without the bias_correction2
+                        # division used above; that correction is part of step_size.
+                        denom = exp_avg_var.sqrt().add_(group['eps'])
+                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                    elif step_size > 0:
+                        # Non-adaptive momentum update.
+                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])
+
+                if half_precision:
+                    p.data = p.data.half()
+                    p.grad = p.grad.half()
+
+        return loss

+ 89 - 0
Notebooks/Defences/FinePruningFTT/FTtransformer/lib/metrics.py

@@ -0,0 +1,89 @@
+import typing as ty
+
+import numpy as np
+import scipy.special
+import sklearn.metrics as skm
+
+from . import util
+
+
+def calculate_metrics(
+    task_type: str,
+    y: np.ndarray,
+    prediction: np.ndarray,
+    classification_mode: str,
+    y_info: ty.Optional[ty.Dict[str, ty.Any]],
+) -> ty.Dict[str, float]:
+    """Compute evaluation metrics for one dataset part.
+
+    For regression the result holds ``rmse`` and ``score`` (the negated RMSE).
+    For classification the result is the scikit-learn classification report
+    as a dict, extended with ``roc_auc`` (binary tasks only) and ``score``
+    (a copy of the report's ``accuracy``).
+
+    ``classification_mode`` tells how ``prediction`` must be interpreted:
+    ``'probs'``, ``'logits'`` or ``'labels'``; it is ignored for regression.
+    ``y_info``, when given, must use the ``'mean_std'`` policy, in which case
+    the RMSE is multiplied by ``y_info['std']``.
+    """
+    if task_type == util.REGRESSION:
+        del classification_mode
+        rmse = skm.mean_squared_error(y, prediction) ** 0.5  # type: ignore[code]
+        if y_info:
+            if y_info['policy'] == 'mean_std':
+                rmse *= y_info['std']
+            else:
+                assert False
+        return {'rmse': rmse, 'score': -rmse}
+    else:
+        assert task_type in (util.BINCLASS, util.MULTICLASS)
+        labels = None
+        if classification_mode == 'probs':
+            probs = prediction
+        elif classification_mode == 'logits':
+            probs = (
+                scipy.special.expit(prediction)
+                if task_type == util.BINCLASS
+                else scipy.special.softmax(prediction, axis=1)
+            )
+        else:
+            assert classification_mode == 'labels'
+            probs = None
+            labels = prediction
+        if labels is None:
+            labels = (
+                np.round(probs).astype('int64')
+                if task_type == util.BINCLASS
+                else probs.argmax(axis=1)  # type: ignore[code]
+            )
+
+        result = skm.classification_report(y, labels, output_dict=True, zero_division=0)  # type: ignore[code]
+        if task_type == util.BINCLASS:
+            if probs is None:
+                # 'labels' mode provides no scores to rank; keep the 0.0 fallback
+                result['roc_auc'] = 0.0
+            else:
+                try:
+                    result['roc_auc'] = skm.roc_auc_score(y, probs)  # type: ignore[code]
+                except ValueError:
+                    # in case we only have one class in our test set (like for ASR)
+                    result['roc_auc'] = 0.0
+        result['score'] = result['accuracy']  # type: ignore[code]
+    return result  # type: ignore[code]
+
+
+def make_summary(metrics: ty.Dict[str, ty.Any]) -> str:
+    """Build a one-line, human-readable summary of a metrics dict.
+
+    Keys made only of digits (per-class report entries) are skipped. Float
+    values are rounded to three decimals; nested dicts are flattened into
+    abbreviated keys such as ``mp`` (macro avg precision) or ``wf1``.
+
+    When an accuracy is available the summary is ``'Accuracy = x.xxx'``.
+    Otherwise (e.g. regression metrics, which carry no accuracy) every
+    collected entry is listed as ``key = value`` joined by ``' | '``.
+    """
+    precision = 3
+    key_aliases = {
+        'score': 'SCORE',
+        'accuracy': 'acc',
+        'roc_auc': 'roc_auc',
+        'macro avg': 'm',
+        'weighted avg': 'w',
+    }
+    field_aliases = {'precision': 'p', 'recall': 'r', 'f1-score': 'f1', 'support': 's'}
+
+    summary = {}
+    for k, v in metrics.items():
+        if k.isdigit():
+            continue
+        k = key_aliases.get(k, k)
+        if isinstance(v, float):
+            summary[k] = round(v, precision)
+        else:
+            # nested report entry: flatten it under the abbreviated parent key
+            for field, value in v.items():
+                summary[k + field_aliases.get(field, field)] = round(value, precision)
+
+    if 'acc' in summary:
+        return f'Accuracy = {summary["acc"]:.3f}'
+    # No accuracy to report: fall back to listing everything that was collected.
+    return ' | '.join(f'{k} = {v}' for k, v in summary.items())

+ 215 - 0
Notebooks/Defences/FinePruningFTT/FTtransformer/lib/util.py

@@ -0,0 +1,215 @@
+import argparse
+import datetime
+import json
+import os
+import pickle
+import random
+import shutil
+import sys
+import time
+import typing as ty
+from copy import deepcopy
+from pathlib import Path
+
+import numpy as np
+import pynvml
+import pytomlpp as toml
+import torch
+
+# Names of the dataset parts; PARTS collects all of them.
+TRAIN = 'train'
+VAL = 'val'
+TEST = 'test'
+TEST_BACKDOOR = 'test_backdoor'
+PARTS = [TRAIN, VAL, TEST, TEST_BACKDOOR]
+
+# Names of the task types; TASK_TYPES collects all of them.
+BINCLASS = 'binclass'
+MULTICLASS = 'multiclass'
+REGRESSION = 'regression'
+TASK_TYPES = [BINCLASS, MULTICLASS, REGRESSION]
+
+
+def load_json(path: ty.Union[Path, str]) -> ty.Any:
+    """Read the file at ``path`` and return its parsed JSON content."""
+    raw_text = Path(path).read_text()
+    return json.loads(raw_text)
+
+
+def dump_json(x: ty.Any, path: ty.Union[Path, str], *args, **kwargs) -> None:
+    """Serialize ``x`` to JSON and write it to ``path`` with a trailing newline.
+
+    Extra positional and keyword arguments are forwarded to ``json.dumps``.
+    """
+    serialized = json.dumps(x, *args, **kwargs)
+    Path(path).write_text(serialized + '\n')
+
+
+def load_toml(path: ty.Union[Path, str]) -> ty.Any:
+    """Read the file at ``path`` and return its parsed TOML content."""
+    raw_text = Path(path).read_text()
+    return toml.loads(raw_text)
+
+
+def dump_toml(x: ty.Any, path: ty.Union[Path, str]) -> None:
+    """Serialize ``x`` to TOML and write it to ``path`` with a trailing newline."""
+    serialized = toml.dumps(x)
+    Path(path).write_text(serialized + '\n')
+
+
+def load_pickle(path: ty.Union[Path, str]) -> ty.Any:
+    """Read the file at ``path`` and return the unpickled object.
+
+    NOTE: unpickling can execute arbitrary code, so only use this on files
+    from a trusted source.
+    """
+    payload = Path(path).read_bytes()
+    return pickle.loads(payload)
+
+
+def dump_pickle(x: ty.Any, path: ty.Union[Path, str]) -> None:
+    """Pickle ``x`` and write the resulting bytes to ``path``."""
+    payload = pickle.dumps(x)
+    Path(path).write_bytes(payload)
+
+
+def load(path: ty.Union[Path, str]) -> ty.Any:
+    """Load ``path`` with the module-level ``load_<suffix>`` function.
+
+    The loader is looked up by the file suffix without its leading dot, so an
+    unsupported suffix raises ``KeyError``.
+    """
+    extension = Path(path).suffix[1:]
+    loader = globals()[f'load_{extension}']
+    return loader(path)
+
+
+def load_config(
+    argv: ty.Optional[ty.List[str]] = None,
+) -> ty.Tuple[ty.Dict[str, ty.Any], Path]:
+    """Parse the command line, load the TOML config and prepare the output directory.
+
+    Arguments are taken from ``argv`` or, when it is ``None``, from
+    ``sys.argv[1:]``. The output directory defaults to a sibling of the config
+    file named after the config's stem.
+
+    If the output directory already exists:
+    - ``--force`` removes and recreates it;
+    - without ``--continue``, or when it contains a ``DONE`` marker, the output
+      is backed up and the process exits;
+    - otherwise the run continues with the existing output.
+
+    Before returning, the config and a description of the CUDA environment
+    (when CUDA is available) are written with ``dump_stats``.
+
+    Returns the parsed config and the absolute output directory.
+    """
+
+    def as_text(value: ty.Union[bytes, str]) -> str:
+        # pynvml returns bytes in some releases and str in others; accept both.
+        return value.decode('utf-8') if isinstance(value, bytes) else value
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('config', metavar='FILE')
+    parser.add_argument('-o', '--output', metavar='DIR')
+    parser.add_argument('-f', '--force', action='store_true')
+    parser.add_argument('--continue', action='store_true', dest='continue_')
+    if argv is None:
+        argv = sys.argv[1:]
+    args = parser.parse_args(argv)
+
+    # A restored snapshot must only be resumed, never started from scratch.
+    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
+    if snapshot_dir and Path(snapshot_dir).joinpath('CHECKPOINTS_RESTORED').exists():
+        assert args.continue_
+
+    config_path = Path(args.config).absolute()
+    output_dir = (
+        Path(args.output)
+        if args.output
+        else config_path.parent.joinpath(config_path.stem)
+    ).absolute()
+    sep = '=' * (8 + max(len(str(config_path)), len(str(output_dir))))  # type: ignore[code]
+    print(sep, f'Config: {config_path}', f'Output: {output_dir}', sep, sep='\n')
+
+    assert config_path.exists()
+    config = load_toml(config_path)
+
+    if output_dir.exists():
+        if args.force:
+            print('Removing the existing output and creating a new one...')
+            shutil.rmtree(output_dir)
+            output_dir.mkdir()
+        elif not args.continue_:
+            backup_output(output_dir)
+            print('Already done!\n')
+            sys.exit()
+        elif output_dir.joinpath('DONE').exists():
+            backup_output(output_dir)
+            print('Already DONE!\n')
+            sys.exit()
+        else:
+            print('Continuing with the existing output...')
+    else:
+        print('Creating the output...')
+        output_dir.mkdir()
+
+    environment: ty.Dict[str, ty.Any] = {}
+    if torch.cuda.is_available():  # type: ignore[code]
+        cvd = os.environ.get('CUDA_VISIBLE_DEVICES')
+        pynvml.nvmlInit()
+        environment['devices'] = {
+            'CUDA_VISIBLE_DEVICES': cvd,
+            'torch.version.cuda': torch.version.cuda,
+            'torch.backends.cudnn.version()': torch.backends.cudnn.version(),  # type: ignore[code]
+            'torch.cuda.nccl.version()': torch.cuda.nccl.version(),  # type: ignore[code]
+            'driver': as_text(pynvml.nvmlSystemGetDriverVersion()),
+        }
+        if cvd:
+            for i in map(int, cvd.split(',')):
+                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+                environment['devices'][i] = {
+                    'name': as_text(pynvml.nvmlDeviceGetName(handle)),
+                    'total_memory': pynvml.nvmlDeviceGetMemoryInfo(handle).total,
+                }
+
+    dump_stats({'config': config, 'environment': environment}, output_dir)
+    return config, output_dir
+
+
+def dump_stats(stats: dict, output_dir: Path, final: bool = False) -> None:
+    """Write ``stats`` to ``<output_dir>/stats.json`` and, if ``final``, finish the run.
+
+    With ``final=True`` a ``DONE`` marker file is created in ``output_dir``.
+    If additionally the ``JSON_OUTPUT_FILE`` environment variable is set:
+    - ``stats`` is merged into that JSON file under the path of ``output_dir``
+      relative to the project root, provided ``output_dir`` lies inside it;
+    - the JSON file is copied to ``$SNAPSHOT_PATH/json_output.json`` when
+      ``SNAPSHOT_PATH`` is set and the JSON file exists.
+
+    NOTE(review): the project root is read from the ``PROJECT_DIR``
+    environment variable because no ``env`` object is defined in this module;
+    confirm this matches the intended source of the project root.
+    """
+    dump_json(stats, output_dir / 'stats.json', indent=4)
+    if not final:
+        return
+    output_dir.joinpath('DONE').touch()
+
+    json_output_file = os.environ.get('JSON_OUTPUT_FILE')
+    if not json_output_file:
+        return
+    json_output_path = Path(json_output_file)
+
+    key = None
+    project_dir = os.environ.get('PROJECT_DIR')
+    if project_dir:
+        try:
+            key = str(output_dir.relative_to(Path(project_dir).absolute()))
+        except ValueError:
+            # output_dir is outside the project: nothing to record
+            key = None
+    if key is not None:
+        try:
+            json_data = json.loads(json_output_path.read_text())
+        except (FileNotFoundError, json.decoder.JSONDecodeError):
+            json_data = {}
+        json_data[key] = stats
+        json_output_path.write_text(json.dumps(json_data))
+
+    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
+    if snapshot_dir and json_output_path.exists():
+        shutil.copyfile(
+            json_output_path,
+            os.path.join(snapshot_dir, 'json_output.json'),
+        )
+
+
+# Time of the last reported snapshot (seconds since the epoch), or None.
+_LAST_SNAPSHOT_TIME = None
+
+
+def backup_output(output_dir: Path) -> None:
+    """Copy ``output_dir`` into the backup and snapshot directories.
+
+    The destinations come from the ``TMP_OUTPUT_PATH`` and ``SNAPSHOT_PATH``
+    environment variables, which must be set together or not at all; when
+    they are unset nothing happens. The copy keeps the path of ``output_dir``
+    relative to the project root, so nothing is copied when ``output_dir``
+    lies outside the project root or the root is unknown.
+
+    An existing copy is renamed with a ``_prev`` suffix while the new one is
+    written and is removed afterwards.
+
+    NOTE(review): the project root is read from the ``PROJECT_DIR``
+    environment variable because no ``env`` object is defined in this module;
+    confirm this matches the intended source of the project root.
+    """
+    backup_dir = os.environ.get('TMP_OUTPUT_PATH')
+    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
+    if backup_dir is None:
+        assert snapshot_dir is None
+        return
+    assert snapshot_dir is not None
+
+    project_dir = os.environ.get('PROJECT_DIR')
+    if not project_dir:
+        return
+    try:
+        relative_output_dir = output_dir.relative_to(Path(project_dir).absolute())
+    except ValueError:
+        return
+
+    for dir_ in [backup_dir, snapshot_dir]:
+        new_output_dir = Path(dir_) / relative_output_dir
+        prev_backup_output_dir = new_output_dir.with_name(new_output_dir.name + '_prev')
+        new_output_dir.parent.mkdir(exist_ok=True, parents=True)
+        if new_output_dir.exists():
+            new_output_dir.rename(prev_backup_output_dir)
+        shutil.copytree(output_dir, new_output_dir)
+        if prev_backup_output_dir.exists():
+            shutil.rmtree(prev_backup_output_dir)
+
+    # Report at most once every ten minutes.
+    global _LAST_SNAPSHOT_TIME
+    if _LAST_SNAPSHOT_TIME is None or time.time() - _LAST_SNAPSHOT_TIME > 10 * 60:
+        _LAST_SNAPSHOT_TIME = time.time()
+        print('The snapshot was saved!')
+
+
+def raise_unknown(unknown_what: str, unknown_value: ty.Any):
+    """Raise ``ValueError`` reporting an unknown value of the given kind."""
+    message = f'Unknown {unknown_what}: {unknown_value}'
+    raise ValueError(message)
+
+
+def merge_defaults(kwargs: dict, default_kwargs: dict) -> dict:
+    """Return a deep copy of ``default_kwargs`` overridden by ``kwargs``.
+
+    Neither argument is modified; values taken from ``kwargs`` are not copied.
+    """
+    merged = deepcopy(default_kwargs)
+    for key, value in kwargs.items():
+        merged[key] = value
+    return merged
+
+
+def set_seeds(seed: int) -> None:
+    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``."""
+    for seed_fn in (random.seed, np.random.seed):
+        seed_fn(seed)
+
+
+def format_seconds(seconds: float) -> str:
+    """Format a duration, rounded to whole seconds, as ``H:MM:SS``."""
+    whole_seconds = round(seconds)
+    duration = datetime.timedelta(seconds=whole_seconds)
+    return str(duration)
+
+
+def get_categories(
+    X_cat: ty.Optional[ty.Dict[str, torch.Tensor]]
+) -> ty.Optional[ty.List[int]]:
+    """Count the distinct values in each column of the training categorical data.
+
+    Returns ``None`` when ``X_cat`` is ``None``; otherwise a list with one
+    count per column of ``X_cat[TRAIN]``.
+    """
+    if X_cat is None:
+        return None
+    train_cat = X_cat[TRAIN]
+    n_columns = train_cat.shape[1]
+    return [len(set(train_cat[:, col].cpu().tolist())) for col in range(n_columns)]

+ 26 - 0
Notebooks/Defences/FinePruningFTT/FTtransformer/tunedCovParams.toml

@@ -0,0 +1,26 @@
+seed = 0
+
+[data]
+normalization = 'quantile'
+path = 'data/covtype'
+
+[model]
+activation = 'reglu'
+attention_dropout = 0.03815883962184247
+d_ffn_factor = 1.333333333333333
+d_token = 424
+ffn_dropout = 0.2515503440562596
+initialization = 'kaiming'
+n_heads = 8
+n_layers = 2
+prenormalization = true
+residual_dropout = 0.0
+
+[training]
+batch_size = 1024
+eval_batch_size = 8192
+lr = 3.762989816330166e-05
+n_epochs = 1000000000
+optimizer = 'adamw'
+patience = 16
+weight_decay = 0.0001239780004929955

+ 0 - 0
Notebooks/Defences/FinePruningFTT/FTtransformerCheckpoints/.gitkeep


+ 1733 - 0
Notebooks/Defences/FinePruningFTT/Finetune.ipynb

@@ -0,0 +1,1733 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "769381d2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from sklearn.datasets import fetch_openml\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from sklearn.metrics import accuracy_score, log_loss\n",
+    "from sklearn.preprocessing import LabelEncoder\n",
+    "\n",
+    "import os\n",
+    "import wget\n",
+    "from pathlib import Path\n",
+    "import shutil\n",
+    "import gzip\n",
+    "\n",
+    "from matplotlib import pyplot as plt\n",
+    "import matplotlib.ticker as mtick\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.nn.init as nn_init\n",
+    "import torch.nn.utils.prune as prune\n",
+    "\n",
+    "import random\n",
+    "import math\n",
+    "\n",
+    "from FTtransformer.ft_transformer import Tokenizer, MultiheadAttention, Transformer, FTtransformer\n",
+    "from FTtransformer import lib\n",
+    "import zero\n",
+    "import json\n",
+    "\n",
+    "from functools import partial\n",
+    "import pickle"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5b9860e4",
+   "metadata": {},
+   "source": [
+    "## Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "d575b960",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "File already exists.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Experiment settings\n",
+    "EPOCHS = 50\n",
+    "RERUNS = 5 # How many times to redo the same setting\n",
+    "\n",
+    "# Backdoor settings\n",
+    "target=[\"Covertype\"]\n",
+    "backdoorFeatures = [\"Elevation\"]\n",
+    "backdoorTriggerValues = [4057]\n",
+    "targetLabel = 4\n",
+    "poisoningRates = [0.0005]\n",
+    "\n",
+    "DEVICE = 'cuda:0'\n",
+    "DATAPATH = \"data/covtypeFTT-1F-OOB-finetune/\"\n",
+    "# FTtransformer config\n",
+    "config = {\n",
+    "    'data': {\n",
+    "        'normalization': 'standard',\n",
+    "        'path': DATAPATH\n",
+    "    }, \n",
+    "    'model': {\n",
+    "        'activation': 'reglu', \n",
+    "        'attention_dropout': 0.03815883962184247, \n",
+    "        'd_ffn_factor': 1.333333333333333, \n",
+    "        'd_token': 424, \n",
+    "        'ffn_dropout': 0.2515503440562596, \n",
+    "        'initialization': 'kaiming', \n",
+    "        'n_heads': 8, \n",
+    "        'n_layers': 2, \n",
+    "        'prenormalization': True, \n",
+    "        'residual_dropout': 0.0, \n",
+    "        'token_bias': True, \n",
+    "        'kv_compression': None, \n",
+    "        'kv_compression_sharing': None\n",
+    "    }, \n",
+    "    'seed': 0, \n",
+    "    'training': {\n",
+    "        'batch_size': 1024, \n",
+    "        'eval_batch_size': 1024, \n",
+    "        'lr': 3.762989816330166e-05, \n",
+    "        'n_epochs': EPOCHS, \n",
+    "        'device': DEVICE, \n",
+    "        'optimizer': 'adamw', \n",
+    "        'patience': 16, \n",
+    "        'weight_decay': 0.0001239780004929955\n",
+    "    }\n",
+    "}\n",
+    "\n",
+    "\n",
+    "# Load dataset\n",
+    "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz\"\n",
+    "dataset_name = 'forestcover-type'\n",
+    "tmp_out = Path('./data/'+dataset_name+'.gz')\n",
+    "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')\n",
+    "out.parent.mkdir(parents=True, exist_ok=True)\n",
+    "if out.exists():\n",
+    "    print(\"File already exists.\")\n",
+    "else:\n",
+    "    print(\"Downloading file...\")\n",
+    "    wget.download(url, tmp_out.as_posix())\n",
+    "    with gzip.open(tmp_out, 'rb') as f_in:\n",
+    "        with open(out, 'wb') as f_out:\n",
+    "            shutil.copyfileobj(f_in, f_out)\n",
+    "\n",
+    "\n",
+    "# Setup data\n",
+    "cat_cols = [\n",
+    "    \"Wilderness_Area1\", \"Wilderness_Area2\", \"Wilderness_Area3\",\n",
+    "    \"Wilderness_Area4\", \"Soil_Type1\", \"Soil_Type2\", \"Soil_Type3\", \"Soil_Type4\",\n",
+    "    \"Soil_Type5\", \"Soil_Type6\", \"Soil_Type7\", \"Soil_Type8\", \"Soil_Type9\",\n",
+    "    \"Soil_Type10\", \"Soil_Type11\", \"Soil_Type12\", \"Soil_Type13\", \"Soil_Type14\",\n",
+    "    \"Soil_Type15\", \"Soil_Type16\", \"Soil_Type17\", \"Soil_Type18\", \"Soil_Type19\",\n",
+    "    \"Soil_Type20\", \"Soil_Type21\", \"Soil_Type22\", \"Soil_Type23\", \"Soil_Type24\",\n",
+    "    \"Soil_Type25\", \"Soil_Type26\", \"Soil_Type27\", \"Soil_Type28\", \"Soil_Type29\",\n",
+    "    \"Soil_Type30\", \"Soil_Type31\", \"Soil_Type32\", \"Soil_Type33\", \"Soil_Type34\",\n",
+    "    \"Soil_Type35\", \"Soil_Type36\", \"Soil_Type37\", \"Soil_Type38\", \"Soil_Type39\",\n",
+    "    \"Soil_Type40\"\n",
+    "]\n",
+    "\n",
+    "num_cols = [\n",
+    "    \"Elevation\", \"Aspect\", \"Slope\", \"Horizontal_Distance_To_Hydrology\",\n",
+    "    \"Vertical_Distance_To_Hydrology\", \"Horizontal_Distance_To_Roadways\",\n",
+    "    \"Hillshade_9am\", \"Hillshade_Noon\", \"Hillshade_3pm\",\n",
+    "    \"Horizontal_Distance_To_Fire_Points\"\n",
+    "]\n",
+    "\n",
+    "feature_columns = (\n",
+    "    num_cols + cat_cols + target)\n",
+    "\n",
+    "data = pd.read_csv(out, header=None, names=feature_columns)\n",
+    "data[\"Covertype\"] = data[\"Covertype\"] - 1 # Make sure output labels start at 0 instead of 1\n",
+    "\n",
+    "\n",
+    "# Converts train valid and test DFs to .npy files + info.json for FTtransformer\n",
+    "def convertDataForFTtransformer(train, valid, test, test_backdoor):\n",
+    "    outPath = DATAPATH\n",
+    "    \n",
+    "    # train\n",
+    "    np.save(outPath+\"N_train.npy\", train[num_cols].to_numpy(dtype='float32'))\n",
+    "    np.save(outPath+\"C_train.npy\", train[cat_cols].applymap(str).to_numpy())\n",
+    "    np.save(outPath+\"y_train.npy\", train[target].to_numpy(dtype=int).flatten())\n",
+    "    \n",
+    "    # val\n",
+    "    np.save(outPath+\"N_val.npy\", valid[num_cols].to_numpy(dtype='float32'))\n",
+    "    np.save(outPath+\"C_val.npy\", valid[cat_cols].applymap(str).to_numpy())\n",
+    "    np.save(outPath+\"y_val.npy\", valid[target].to_numpy(dtype=int).flatten())\n",
+    "    \n",
+    "    # test\n",
+    "    np.save(outPath+\"N_test.npy\", test[num_cols].to_numpy(dtype='float32'))\n",
+    "    np.save(outPath+\"C_test.npy\", test[cat_cols].applymap(str).to_numpy())\n",
+    "    np.save(outPath+\"y_test.npy\", test[target].to_numpy(dtype=int).flatten())\n",
+    "    \n",
+    "    # test_backdoor\n",
+    "    np.save(outPath+\"N_test_backdoor.npy\", test_backdoor[num_cols].to_numpy(dtype='float32'))\n",
+    "    np.save(outPath+\"C_test_backdoor.npy\", test_backdoor[cat_cols].applymap(str).to_numpy())\n",
+    "    np.save(outPath+\"y_test_backdoor.npy\", test_backdoor[target].to_numpy(dtype=int).flatten())\n",
+    "    \n",
+    "    # info.json\n",
+    "    info = {\n",
+    "        \"name\": \"covtype___0\",\n",
+    "        \"basename\": \"covtype\",\n",
+    "        \"split\": 0,\n",
+    "        \"task_type\": \"multiclass\",\n",
+    "        \"n_num_features\": len(num_cols),\n",
+    "        \"n_cat_features\": len(cat_cols),\n",
+    "        \"train_size\": len(train),\n",
+    "        \"val_size\": len(valid),\n",
+    "        \"test_size\": len(test),\n",
+    "        \"test_backdoor_size\": len(test_backdoor),\n",
+    "        \"n_classes\": 7\n",
+    "    }\n",
+    "    \n",
+    "    with open(outPath + 'info.json', 'w') as f:\n",
+    "        json.dump(info, f, indent = 4)\n",
+    "\n",
+    "# Experiment setup\n",
+    "def GenerateTrigger(df, poisoningRate, backdoorTriggerValues, targetLabel):\n",
+    "    rows_with_trigger = df.sample(frac=poisoningRate)\n",
+    "    rows_with_trigger[backdoorFeatures] = backdoorTriggerValues\n",
+    "    rows_with_trigger[target] = targetLabel\n",
+    "    return rows_with_trigger\n",
+    "\n",
+    "def GenerateBackdoorTrigger(df, backdoorTriggerValues, targetLabel):\n",
+    "    df[backdoorFeatures] = backdoorTriggerValues\n",
+    "    df[target] = targetLabel\n",
+    "    return df"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d9a5a67a",
+   "metadata": {},
+   "source": [
+    "## Prepare finetune data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "fa253ec3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "92963\n",
+      "18592\n",
+      "4648\n"
+     ]
+    }
+   ],
+   "source": [
+    "runIdx = 1\n",
+    "poisoningRate = poisoningRates[0]\n",
+    "\n",
+    "# Do same datageneration as during initial backdoor training so we get the same test set\n",
+    "\n",
+    "# Load dataset\n",
+    "# Changes to output df will not influence input df\n",
+    "train_and_valid, test = train_test_split(data, stratify=data[target[0]], test_size=0.2, random_state=runIdx)\n",
+    "\n",
+    "# Apply backdoor to train and valid data\n",
+    "random.seed(runIdx)\n",
+    "train_and_valid_poisoned = GenerateTrigger(train_and_valid, poisoningRate, backdoorTriggerValues, targetLabel)\n",
+    "train_and_valid.update(train_and_valid_poisoned)\n",
+    "train_and_valid[target[0]] = train_and_valid[target[0]].astype(np.int64)\n",
+    "train_and_valid[cat_cols] = train_and_valid[cat_cols].astype(np.int64)\n",
+    "\n",
+    "# Create backdoored test version\n",
+    "# Also copy to not disturb clean test data\n",
+    "test_backdoor = test.copy()\n",
+    "\n",
+    "# Drop rows that already have the target label\n",
+    "test_backdoor = test_backdoor[test_backdoor[target[0]] != targetLabel]\n",
+    "\n",
+    "# Add backdoor to all test_backdoor samples\n",
+    "test_backdoor = GenerateBackdoorTrigger(test_backdoor, backdoorTriggerValues, targetLabel)\n",
+    "test_backdoor[target[0]] = test_backdoor[target[0]].astype(np.int64)\n",
+    "test_backdoor[cat_cols] = test_backdoor[cat_cols].astype(np.int64)\n",
+    "\n",
+    "\n",
+    "# Now split the test set into different parts: ~20k for finetuning (train+val) and 20k for defence evaluation\n",
+    "finetune_train_val, finetune_test = train_test_split(test, stratify=test[target[0]], test_size=0.8, random_state=runIdx)\n",
+    "# Train: ~16k, val: ~4k\n",
+    "finetune_train, finetune_val = train_test_split(finetune_train_val, stratify=finetune_train_val[target[0]], test_size=0.2, random_state=runIdx)\n",
+    "\n",
+    "print(len(finetune_test))\n",
+    "print(len(finetune_train))\n",
+    "print(len(finetune_val))\n",
+    "\n",
+    "convertDataForFTtransformer(finetune_train, finetune_val, finetune_test, test_backdoor)\n",
+    "\n",
+    "\n",
+    "checkpoint_path = 'FTtransformerCheckpoints/CovType_1F_OOB_' + str(poisoningRate) + \"-\" + str(runIdx) + \".pt\"\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3bd019f0",
+   "metadata": {},
+   "source": [
+    "## Setup model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "3955ebdc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "DATAPATH = \"data/covtypeFTT-1F-OOB/\"\n",
+    "config = {\n",
+    "    'data': {\n",
+    "        'normalization': 'standard',\n",
+    "        'path': DATAPATH\n",
+    "    }, \n",
+    "    'model': {\n",
+    "        'activation': 'reglu', \n",
+    "        'attention_dropout': 0.03815883962184247, \n",
+    "        'd_ffn_factor': 1.333333333333333, \n",
+    "        'd_token': 424, \n",
+    "        'ffn_dropout': 0.2515503440562596, \n",
+    "        'initialization': 'kaiming', \n",
+    "        'n_heads': 8, \n",
+    "        'n_layers': 2, \n",
+    "        'prenormalization': True, \n",
+    "        'residual_dropout': 0.0, \n",
+    "        'token_bias': True, \n",
+    "        'kv_compression': None, \n",
+    "        'kv_compression_sharing': None\n",
+    "    }, \n",
+    "    'seed': 0, \n",
+    "    'training': {\n",
+    "        'batch_size': 1024, \n",
+    "        'eval_batch_size': 1024, \n",
+    "        'lr': 3.762989816330166e-05, \n",
+    "        'n_epochs': EPOCHS, \n",
+    "        'device': DEVICE, \n",
+    "        'optimizer': 'adamw', \n",
+    "        'patience': 16, \n",
+    "        'weight_decay': 0.0001239780004929955\n",
+    "    }\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "2f51f794",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using device: cuda:0\n",
+      "self.category_embeddings.weight.shape=torch.Size([88, 424])\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "zero.set_randomness(config['seed'])\n",
+    "dataset_dir = config['data']['path']\n",
+    "\n",
+    "D = lib.Dataset.from_dir(dataset_dir)\n",
+    "X = D.build_X(\n",
+    "    normalization=config['data'].get('normalization'),\n",
+    "    num_nan_policy='mean',\n",
+    "    cat_nan_policy='new',\n",
+    "    cat_policy=config['data'].get('cat_policy', 'indices'),\n",
+    "    cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),\n",
+    "    seed=config['seed'],\n",
+    ")\n",
+    "if not isinstance(X, tuple):\n",
+    "    X = (X, None)\n",
+    "\n",
+    "Y, y_info = D.build_y(config['data'].get('y_policy'))\n",
+    "\n",
+    "X = tuple(None if x is None else lib.to_tensors(x) for x in X)\n",
+    "Y = lib.to_tensors(Y)\n",
+    "device = torch.device(config['training']['device'])\n",
+    "print(\"Using device:\", config['training']['device'])\n",
+    "if device.type != 'cpu':\n",
+    "    X = tuple(\n",
+    "        None if x is None else {k: v.to(device) for k, v in x.items()} for x in X\n",
+    "    )\n",
+    "    Y_device = {k: v.to(device) for k, v in Y.items()}\n",
+    "else:\n",
+    "    Y_device = Y\n",
+    "X_num, X_cat = X\n",
+    "del X\n",
+    "if not D.is_multiclass:\n",
+    "    Y_device = {k: v.float() for k, v in Y_device.items()}\n",
+    "\n",
+    "train_size = D.size(lib.TRAIN)\n",
+    "batch_size = config['training']['batch_size']\n",
+    "epoch_size = math.ceil(train_size / batch_size)\n",
+    "eval_batch_size = config['training']['eval_batch_size']\n",
+    "chunk_size = None\n",
+    "\n",
+    "loss_fn = (\n",
+    "    F.binary_cross_entropy_with_logits\n",
+    "    if D.is_binclass\n",
+    "    else F.cross_entropy\n",
+    "    if D.is_multiclass\n",
+    "    else F.mse_loss\n",
+    ")\n",
+    "\n",
+    "model = Transformer(\n",
+    "    d_numerical=0 if X_num is None else X_num['train'].shape[1],\n",
+    "    categories=lib.get_categories(X_cat),\n",
+    "    d_out=D.info['n_classes'] if D.is_multiclass else 1,\n",
+    "    **config['model'],\n",
+    ").to(device)\n",
+    "\n",
+    "def needs_wd(name):\n",
+    "    return all(x not in name for x in ['tokenizer', '.norm', '.bias'])\n",
+    "\n",
+    "for x in ['tokenizer', '.norm', '.bias']:\n",
+    "    assert any(x in a for a in (b[0] for b in model.named_parameters()))\n",
+    "parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]\n",
+    "parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]\n",
+    "optimizer = lib.make_optimizer(\n",
+    "    config['training']['optimizer'],\n",
+    "    (\n",
+    "        [\n",
+    "            {'params': parameters_with_wd},\n",
+    "            {'params': parameters_without_wd, 'weight_decay': 0.0},\n",
+    "        ]\n",
+    "    ),\n",
+    "    config['training']['lr'],\n",
+    "    config['training']['weight_decay'],\n",
+    ")\n",
+    "\n",
+    "stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))\n",
+    "progress = zero.ProgressTracker(config['training']['patience'])\n",
+    "training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}\n",
+    "timer = zero.Timer()\n",
+    "output = \"Checkpoints\"\n",
+    "\n",
+    "def print_epoch_info():\n",
+    "    print(f'\\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')\n",
+    "    print(\n",
+    "        ' | '.join(\n",
+    "            f'{k} = {v}'\n",
+    "            for k, v in {\n",
+    "                'lr': lib.get_lr(optimizer),\n",
+    "                'batch_size': batch_size,\n",
+    "                'chunk_size': chunk_size,\n",
+    "            }.items()\n",
+    "        )\n",
+    "    )\n",
+    "\n",
+    "def apply_model(part, idx):\n",
+    "    return model(\n",
+    "        None if X_num is None else X_num[part][idx],\n",
+    "        None if X_cat is None else X_cat[part][idx],\n",
+    "    )\n",
+    "\n",
+    "@torch.no_grad()\n",
+    "def evaluate(parts):\n",
+    "    # Evaluate the model on the given dataset parts and print a summary.\n",
+    "    # Returns (metrics, predictions), both keyed by part name.\n",
+    "    eval_batch_size = config['training']['eval_batch_size']\n",
+    "    model.eval()\n",
+    "    metrics = {}\n",
+    "    predictions = {}\n",
+    "    for part in parts:\n",
+    "        # Halve the batch size on out-of-memory errors until the part fits\n",
+    "        while eval_batch_size:\n",
+    "            try:\n",
+    "                predictions[part] = (\n",
+    "                    torch.cat(\n",
+    "                        [\n",
+    "                            apply_model(part, idx)\n",
+    "                            for idx in lib.IndexLoader(\n",
+    "                                D.size(part), eval_batch_size, False, device\n",
+    "                            )\n",
+    "                        ]\n",
+    "                    )\n",
+    "                    .cpu()\n",
+    "                    .numpy()\n",
+    "                )\n",
+    "            except RuntimeError as err:\n",
+    "                if not lib.is_oom_exception(err):\n",
+    "                    raise\n",
+    "                eval_batch_size //= 2\n",
+    "                print('New eval batch size:', eval_batch_size)\n",
+    "            else:\n",
+    "                break\n",
+    "        if not eval_batch_size:\n",
+    "            # Fail here instead of with a KeyError on the missing predictions\n",
+    "            raise RuntimeError('Not enough memory even for eval_batch_size=1')\n",
+    "        metrics[part] = lib.calculate_metrics(\n",
+    "            D.info['task_type'],\n",
+    "            Y[part].numpy(),  # type: ignore[code]\n",
+    "            predictions[part],  # type: ignore[code]\n",
+    "            'logits',\n",
+    "            y_info,\n",
+    "        )\n",
+    "    for part, part_metrics in metrics.items():\n",
+    "        print(f'[{part:<5}]', lib.make_summary(part_metrics))\n",
+    "    return metrics, predictions\n",
+    "\n",
+    "def save_checkpoint(final):\n",
+    "    torch.save(\n",
+    "        {\n",
+    "            'model': model.state_dict(),\n",
+    "            'optimizer': optimizer.state_dict(),\n",
+    "            'stream': stream.state_dict(),\n",
+    "            'random_state': zero.get_random_state(),\n",
+    "        },\n",
+    "        checkpoint_path,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "214a2935",
+   "metadata": {},
+   "source": [
+    "## Load model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "3be456cc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[test ] Accuracy = 0.954\n",
+      "[test_backdoor] Accuracy = 0.997\n"
+     ]
+    }
+   ],
+   "source": [
+    "zero.set_randomness(config['seed'])\n",
+    "\n",
+    "# Load best checkpoint\n",
+    "model.load_state_dict(torch.load(checkpoint_path)['model'])\n",
+    "metrics, predictions = evaluate(['test', 'test_backdoor'])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c87fb163",
+   "metadata": {},
+   "source": [
+    "# Save activations"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "146c8957",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "registered: layers.0.attention.W_q : Linear(in_features=424, out_features=424, bias=True)\n",
+      "registered: layers.0.attention.W_k : Linear(in_features=424, out_features=424, bias=True)\n",
+      "registered: layers.0.attention.W_v : Linear(in_features=424, out_features=424, bias=True)\n",
+      "registered: layers.0.attention.W_out : Linear(in_features=424, out_features=424, bias=True)\n",
+      "registered: layers.0.linear0 : Linear(in_features=424, out_features=1130, bias=True)\n",
+      "registered: layers.0.linear1 : Linear(in_features=565, out_features=424, bias=True)\n",
+      "registered: layers.1.attention.W_q : Linear(in_features=424, out_features=424, bias=True)\n",
+      "registered: layers.1.attention.W_k : Linear(in_features=424, out_features=424, bias=True)\n",
+      "registered: layers.1.attention.W_v : Linear(in_features=424, out_features=424, bias=True)\n",
+      "registered: layers.1.attention.W_out : Linear(in_features=424, out_features=424, bias=True)\n",
+      "registered: layers.1.linear0 : Linear(in_features=424, out_features=1130, bias=True)\n",
+      "registered: layers.1.linear1 : Linear(in_features=565, out_features=424, bias=True)\n"
+     ]
+    }
+   ],
+   "source": [
+    "activations_out = {}\n",
+    "count = 0\n",
+    "fails = 0\n",
+    "def save_activation(name, mod, inp, out):\n",
+    "    if name not in activations_out:\n",
+    "        activations_out[name] = out.cpu().detach().numpy()\n",
+    "    \n",
+    "    global fails\n",
+    "    # Will fail if the dataset size is not divisible by the batch size; the try/except skips the last (partial) iteration\n",
+    "    try:\n",
+    "        # Save the activations for the input neurons\n",
+    "        activations_out[name] += out.cpu().detach().numpy()\n",
+    "        \n",
+    "        if \"layers.0.linear0\" in name:\n",
+    "            global count\n",
+    "            count += 1\n",
+    "    except:\n",
+    "        fails+=1\n",
+    "    \n",
+    "hooks = []\n",
+    "for name, m in model.named_modules():\n",
+    "    #print(name) # -> the attention projections (W_*) and the linear layers are the ones we are interested in\n",
+    "    if \"W_\" in name or \"linear\" in name:\n",
+    "        print(\"registered:\", name, \":\", m)\n",
+    "        hooks.append(m.register_forward_hook(partial(save_activation, name)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "54234e0e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(activations_out))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "9351dbce",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[test ] Accuracy = 0.954\n"
+     ]
+    }
+   ],
+   "source": [
+    "_ = evaluate(['test'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "09857b48",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for hook in hooks:\n",
+    "    hook.remove()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "6f6bf9ee",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "113\n",
+      "12\n",
+      "12\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(count)\n",
+    "\n",
+    "# fails should be equal to the number of hooked layers (12), or 0 if the dataset size is divisible by the batch size\n",
+    "print(len(activations_out))\n",
+    "print(fails)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "b796ee9a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Calculate mean activation value (although not really needed for ranking)\n",
+    "for x in activations_out:\n",
+    "    activations_out[x] = activations_out[x]/count"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "a9dc87ce",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "layers.0.attention.W_q\n",
+      "(1024, 55, 424)\n",
+      "\n",
+      "layers.0.attention.W_k\n",
+      "(1024, 55, 424)\n",
+      "\n",
+      "layers.0.attention.W_v\n",
+      "(1024, 55, 424)\n",
+      "\n",
+      "layers.0.attention.W_out\n",
+      "(1024, 55, 424)\n",
+      "\n",
+      "layers.0.linear0\n",
+      "(1024, 55, 1130)\n",
+      "\n",
+      "layers.0.linear1\n",
+      "(1024, 55, 424)\n",
+      "\n",
+      "layers.1.attention.W_q\n",
+      "(1024, 1, 424)\n",
+      "\n",
+      "layers.1.attention.W_k\n",
+      "(1024, 55, 424)\n",
+      "\n",
+      "layers.1.attention.W_v\n",
+      "(1024, 55, 424)\n",
+      "\n",
+      "layers.1.attention.W_out\n",
+      "(1024, 1, 424)\n",
+      "\n",
+      "layers.1.linear0\n",
+      "(1024, 1, 1130)\n",
+      "\n",
+      "layers.1.linear1\n",
+      "(1024, 1, 424)\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "for x in activations_out:\n",
+    "    print(x)\n",
+    "    print(activations_out[x].shape)\n",
+    "    print()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "ecee2260",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Average over batch and second dimension\n",
+    "for x in activations_out:\n",
+    "    activations_out[x] = activations_out[x].mean(axis=0).mean(axis=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "0ccc53f7",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "layers.0.attention.W_q\n",
+      "(424,)\n",
+      "layers.0.attention.W_k\n",
+      "(424,)\n",
+      "layers.0.attention.W_v\n",
+      "(424,)\n",
+      "layers.0.attention.W_out\n",
+      "(424,)\n",
+      "layers.0.linear0\n",
+      "(1130,)\n",
+      "layers.0.linear1\n",
+      "(424,)\n",
+      "layers.1.attention.W_q\n",
+      "(424,)\n",
+      "layers.1.attention.W_k\n",
+      "(424,)\n",
+      "layers.1.attention.W_v\n",
+      "(424,)\n",
+      "layers.1.attention.W_out\n",
+      "(424,)\n",
+      "layers.1.linear0\n",
+      "(1130,)\n",
+      "layers.1.linear1\n",
+      "(424,)\n"
+     ]
+    }
+   ],
+   "source": [
+    "for x in activations_out:\n",
+    "    print(x)\n",
+    "    print(activations_out[x].shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "1beca88e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[test ] Accuracy = 0.954\n",
+      "[test_backdoor] Accuracy = 0.997\n"
+     ]
+    }
+   ],
+   "source": [
+    "metrics = evaluate(['test', 'test_backdoor'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "3e8f4a93",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0.9974191629339306\n",
+      "0.9541836269287368\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(metrics[0]['test_backdoor']['accuracy'])\n",
+    "print(metrics[0]['test']['accuracy'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "67f9462d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Argsort activations for each layer\n",
+    "argsortActivations_out = {}\n",
+    "for n in activations_out:\n",
+    "    argsortActivations_out[n] = np.argsort(activations_out[n])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "890bbbda",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "layers.0.attention.W_q.weight torch.Size([424, 424])\n",
+      "layers.0.attention.W_k.weight torch.Size([424, 424])\n",
+      "layers.0.attention.W_v.weight torch.Size([424, 424])\n",
+      "layers.0.attention.W_out.weight torch.Size([424, 424])\n",
+      "layers.0.linear0.weight torch.Size([1130, 424])\n",
+      "layers.0.linear1.weight torch.Size([424, 565])\n",
+      "layers.1.attention.W_q.weight torch.Size([424, 424])\n",
+      "layers.1.attention.W_k.weight torch.Size([424, 424])\n",
+      "layers.1.attention.W_v.weight torch.Size([424, 424])\n",
+      "layers.1.attention.W_out.weight torch.Size([424, 424])\n",
+      "layers.1.linear0.weight torch.Size([1130, 424])\n",
+      "layers.1.linear1.weight torch.Size([424, 565])\n"
+     ]
+    }
+   ],
+   "source": [
+    "for name, m in model.named_parameters():\n",
+    "    if \"W_\" in name or \"linear\" in name:\n",
+    "        if \"weight\" in name:\n",
+    "            print(name, m.shape)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "76a2ac3b",
+   "metadata": {},
+   "source": [
+    "## Prune"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "f627749f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def pruneWithTreshold(argsortActivations, name, th=1, transpose=False, dim2=1):\n",
+    "    x = torch.tensor(argsortActivations[name].copy())\n",
+    "    x[x>=th] = 99999\n",
+    "    x[x<th] = 0\n",
+    "    x[x==99999] = 1\n",
+    "    \n",
+    "    b = np.stack((x,) * dim2, axis=-1)\n",
+    "    \n",
+    "    if transpose:\n",
+    "        b = torch.tensor(b.T)\n",
+    "    else:\n",
+    "        b = torch.tensor(b)\n",
+    "        \n",
+    "    #print(b.shape)\n",
+    "    return b"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "1e059fd0",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[test ] Accuracy = 0.702\n",
+      "[test_backdoor] Accuracy = 0.017\n"
+     ]
+    }
+   ],
+   "source": [
+    "i = 212 # obtained from \"Prune\" notebook\n",
+    "\n",
+    "    \n",
+    "prune.custom_from_mask(\n",
+    "    module = model.layers[0].linear0,\n",
+    "    name = 'weight',\n",
+    "    mask = pruneWithTreshold(argsortActivations_out, \"layers.0.linear0\", i, False, 424).to(\"cuda:0\")\n",
+    ")\n",
+    "\n",
+    "prune.custom_from_mask(\n",
+    "    module = model.layers[0].linear1,\n",
+    "    name = 'weight',\n",
+    "    mask = pruneWithTreshold(argsortActivations_out, \"layers.0.linear1\", i, False, 565).to(\"cuda:0\")\n",
+    ")\n",
+    "\n",
+    "prune.custom_from_mask(\n",
+    "    module = model.layers[1].linear0,\n",
+    "    name = 'weight',\n",
+    "    mask = pruneWithTreshold(argsortActivations_out, \"layers.1.linear0\", i, False, 424).to(\"cuda:0\")\n",
+    ")\n",
+    "\n",
+    "prune.custom_from_mask(\n",
+    "    module = model.layers[1].linear1,\n",
+    "    name = 'weight',\n",
+    "    mask = pruneWithTreshold(argsortActivations_out, \"layers.1.linear1\", i, False, 565).to(\"cuda:0\")\n",
+    ")\n",
+    "\n",
+    "\n",
+    "metrics = evaluate(['test', 'test_backdoor'])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "48071b83",
+   "metadata": {},
+   "source": [
+    "## Finetune"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "4a7505f5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using device: cuda:0\n"
+     ]
+    }
+   ],
+   "source": [
+    "DATAPATH = \"data/covtypeFTT-1F-OOB-finetune/\"\n",
+    "# FTtransformer config\n",
+    "config = {\n",
+    "    'data': {\n",
+    "        'normalization': 'standard',\n",
+    "        'path': DATAPATH\n",
+    "    }, \n",
+    "    'model': {\n",
+    "        'activation': 'reglu', \n",
+    "        'attention_dropout': 0.03815883962184247, \n",
+    "        'd_ffn_factor': 1.333333333333333, \n",
+    "        'd_token': 424, \n",
+    "        'ffn_dropout': 0.2515503440562596, \n",
+    "        'initialization': 'kaiming', \n",
+    "        'n_heads': 8, \n",
+    "        'n_layers': 2, \n",
+    "        'prenormalization': True, \n",
+    "        'residual_dropout': 0.0, \n",
+    "        'token_bias': True, \n",
+    "        'kv_compression': None, \n",
+    "        'kv_compression_sharing': None\n",
+    "    }, \n",
+    "    'seed': 0, \n",
+    "    'training': {\n",
+    "        'batch_size': 1024, \n",
+    "        'eval_batch_size': 1024, \n",
+    "        'lr': 3.762989816330166e-05, \n",
+    "        'n_epochs': EPOCHS, \n",
+    "        'device': DEVICE, \n",
+    "        'optimizer': 'adamw', \n",
+    "        'patience': 16, \n",
+    "        'weight_decay': 0.0001239780004929955\n",
+    "    }\n",
+    "}\n",
+    "\n",
+    "\n",
+    "zero.set_randomness(config['seed'])\n",
+    "dataset_dir = config['data']['path']\n",
+    "\n",
+    "D = lib.Dataset.from_dir(dataset_dir)\n",
+    "X = D.build_X(\n",
+    "    normalization=config['data'].get('normalization'),\n",
+    "    num_nan_policy='mean',\n",
+    "    cat_nan_policy='new',\n",
+    "    cat_policy=config['data'].get('cat_policy', 'indices'),\n",
+    "    cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),\n",
+    "    seed=config['seed'],\n",
+    ")\n",
+    "if not isinstance(X, tuple):\n",
+    "    X = (X, None)\n",
+    "\n",
+    "Y, y_info = D.build_y(config['data'].get('y_policy'))\n",
+    "\n",
+    "X = tuple(None if x is None else lib.to_tensors(x) for x in X)\n",
+    "Y = lib.to_tensors(Y)\n",
+    "device = torch.device(config['training']['device'])\n",
+    "print(\"Using device:\", config['training']['device'])\n",
+    "if device.type != 'cpu':\n",
+    "    X = tuple(\n",
+    "        None if x is None else {k: v.to(device) for k, v in x.items()} for x in X\n",
+    "    )\n",
+    "    Y_device = {k: v.to(device) for k, v in Y.items()}\n",
+    "else:\n",
+    "    Y_device = Y\n",
+    "X_num, X_cat = X\n",
+    "del X\n",
+    "if not D.is_multiclass:\n",
+    "    Y_device = {k: v.float() for k, v in Y_device.items()}\n",
+    "\n",
+    "train_size = D.size(lib.TRAIN)\n",
+    "batch_size = config['training']['batch_size']\n",
+    "epoch_size = math.ceil(train_size / batch_size)\n",
+    "eval_batch_size = config['training']['eval_batch_size']\n",
+    "chunk_size = None\n",
+    "\n",
+    "loss_fn = (\n",
+    "    F.binary_cross_entropy_with_logits\n",
+    "    if D.is_binclass\n",
+    "    else F.cross_entropy\n",
+    "    if D.is_multiclass\n",
+    "    else F.mse_loss\n",
+    ")\n",
+    "\n",
+    "# Do not define a new model; instead, reuse the pruned model\n",
+    "#model = Transformer(\n",
+    "#    d_numerical=0 if X_num is None else X_num['train'].shape[1],\n",
+    "#    categories=lib.get_categories(X_cat),\n",
+    "#    d_out=D.info['n_classes'] if D.is_multiclass else 1,\n",
+    "#    **config['model'],\n",
+    "#).to(device)\n",
+    "\n",
+    "def needs_wd(name):\n",
+    "    return all(x not in name for x in ['tokenizer', '.norm', '.bias'])\n",
+    "\n",
+    "for x in ['tokenizer', '.norm', '.bias']:\n",
+    "    assert any(x in a for a in (b[0] for b in model.named_parameters()))\n",
+    "parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]\n",
+    "parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]\n",
+    "optimizer = lib.make_optimizer(\n",
+    "    config['training']['optimizer'],\n",
+    "    (\n",
+    "        [\n",
+    "            {'params': parameters_with_wd},\n",
+    "            {'params': parameters_without_wd, 'weight_decay': 0.0},\n",
+    "        ]\n",
+    "    ),\n",
+    "    config['training']['lr'],\n",
+    "    config['training']['weight_decay'],\n",
+    ")\n",
+    "\n",
+    "stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))\n",
+    "progress = zero.ProgressTracker(config['training']['patience'])\n",
+    "training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}\n",
+    "timer = zero.Timer()\n",
+    "output = \"Checkpoints\"\n",
+    "\n",
+    "def print_epoch_info():\n",
+    "    print(f'\\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')\n",
+    "    print(\n",
+    "        ' | '.join(\n",
+    "            f'{k} = {v}'\n",
+    "            for k, v in {\n",
+    "                'lr': lib.get_lr(optimizer),\n",
+    "                'batch_size': batch_size,\n",
+    "                'chunk_size': chunk_size,\n",
+    "            }.items()\n",
+    "        )\n",
+    "    )\n",
+    "\n",
+    "def apply_model(part, idx):\n",
+    "    return model(\n",
+    "        None if X_num is None else X_num[part][idx],\n",
+    "        None if X_cat is None else X_cat[part][idx],\n",
+    "    )\n",
+    "\n",
+    "@torch.no_grad()\n",
+    "def evaluate(parts):\n",
+    "    eval_batch_size = config['training']['eval_batch_size']\n",
+    "    model.eval()\n",
+    "    metrics = {}\n",
+    "    predictions = {}\n",
+    "    for part in parts:\n",
+    "        while eval_batch_size:\n",
+    "            try:\n",
+    "                predictions[part] = (\n",
+    "                    torch.cat(\n",
+    "                        [\n",
+    "                            apply_model(part, idx)\n",
+    "                            for idx in lib.IndexLoader(\n",
+    "                                D.size(part), eval_batch_size, False, device\n",
+    "                            )\n",
+    "                        ]\n",
+    "                    )\n",
+    "                    .cpu()\n",
+    "                    .numpy()\n",
+    "                )\n",
+    "            except RuntimeError as err:\n",
+    "                if not lib.is_oom_exception(err):\n",
+    "                    raise\n",
+    "                eval_batch_size //= 2\n",
+    "                print('New eval batch size:', eval_batch_size)\n",
+    "            else:\n",
+    "                break\n",
+    "        if not eval_batch_size:\n",
+    "            RuntimeError('Not enough memory even for eval_batch_size=1')\n",
+    "        metrics[part] = lib.calculate_metrics(\n",
+    "            D.info['task_type'],\n",
+    "            Y[part].numpy(),  # type: ignore[code]\n",
+    "            predictions[part],  # type: ignore[code]\n",
+    "            'logits',\n",
+    "            y_info,\n",
+    "        )\n",
+    "    for part, part_metrics in metrics.items():\n",
+    "        print(f'[{part:<5}]', lib.make_summary(part_metrics))\n",
+    "    return metrics, predictions\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "a0d986d3",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      "  0%|                                                   | 0/285 [00:00<?, ?it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      ">>> Epoch 1 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "  7%|██▊                                       | 19/285 [00:03<00:57,  4.60it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.519\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      "  7%|██▉                                       | 20/285 [00:10<09:13,  2.09s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.877\n",
+      "[test ] Accuracy = 0.868\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 2 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 13%|█████▌                                    | 38/285 [00:14<00:54,  4.52it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.35\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 14%|█████▋                                    | 39/285 [00:20<08:33,  2.09s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.898\n",
+      "[test ] Accuracy = 0.892\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 3 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 20%|████████▍                                 | 57/285 [00:24<00:50,  4.54it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.31\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 20%|████████▌                                 | 58/285 [00:31<07:53,  2.09s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.910\n",
+      "[test ] Accuracy = 0.904\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 4 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 27%|███████████▏                              | 76/285 [00:34<00:46,  4.54it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.283\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 27%|███████████▎                              | 77/285 [00:41<07:22,  2.13s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.914\n",
+      "[test ] Accuracy = 0.908\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 5 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 33%|██████████████                            | 95/285 [00:45<00:42,  4.48it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.276\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 34%|██████████████▏                           | 96/285 [00:51<06:35,  2.09s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.919\n",
+      "[test ] Accuracy = 0.913\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 6 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 40%|████████████████▍                        | 114/285 [00:55<00:39,  4.37it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.266\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 40%|████████████████▌                        | 115/285 [01:02<06:07,  2.16s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.922\n",
+      "[test ] Accuracy = 0.916\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 7 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 47%|███████████████████▏                     | 133/285 [01:06<00:33,  4.48it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.248\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 47%|███████████████████▎                     | 134/285 [01:13<05:22,  2.14s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.924\n",
+      "[test ] Accuracy = 0.916\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 8 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 53%|█████████████████████▊                   | 152/285 [01:17<00:29,  4.51it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.242\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 54%|██████████████████████                   | 153/285 [01:23<04:38,  2.11s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.926\n",
+      "[test ] Accuracy = 0.918\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 9 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 60%|████████████████████████▌                | 171/285 [01:27<00:25,  4.46it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.233\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 60%|████████████████████████▋                | 172/285 [01:34<04:02,  2.15s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.924\n",
+      "[test ] Accuracy = 0.918\n",
+      "\n",
+      ">>> Epoch 10 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 67%|███████████████████████████▎             | 190/285 [01:38<00:21,  4.50it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.23\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 67%|███████████████████████████▍             | 191/285 [01:44<03:22,  2.15s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.927\n",
+      "[test ] Accuracy = 0.920\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 11 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 73%|██████████████████████████████           | 209/285 [01:48<00:17,  4.45it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.221\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 74%|██████████████████████████████▏          | 210/285 [01:55<02:37,  2.11s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.928\n",
+      "[test ] Accuracy = 0.920\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 12 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 80%|████████████████████████████████▊        | 228/285 [01:59<00:12,  4.42it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.217\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 80%|████████████████████████████████▉        | 229/285 [02:05<01:59,  2.13s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.927\n",
+      "[test ] Accuracy = 0.921\n",
+      "\n",
+      ">>> Epoch 13 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 87%|███████████████████████████████████▌     | 247/285 [02:09<00:08,  4.51it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.216\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 87%|███████████████████████████████████▋     | 248/285 [02:16<01:18,  2.13s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.926\n",
+      "[test ] Accuracy = 0.920\n",
+      "\n",
+      ">>> Epoch 14 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 93%|██████████████████████████████████████▎  | 266/285 [02:20<00:04,  4.42it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.209\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 94%|██████████████████████████████████████▍  | 267/285 [02:27<00:38,  2.13s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.927\n",
+      "[test ] Accuracy = 0.921\n",
+      "\n",
+      ">>> Epoch 15 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|█████████████████████████████████████████| 285/285 [02:31<00:00,  4.55it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.206\n",
+      "[val  ] Accuracy = 0.928\n",
+      "[test ] Accuracy = 0.921\n"
+     ]
+    }
+   ],
+   "source": [
+    "finetuneEpochs = 15\n",
+    "for epoch in stream.epochs(finetuneEpochs):\n",
+    "    print(f'\\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())}')\n",
+    "    model.train()\n",
+    "    epoch_losses = []\n",
+    "    for batch_idx in epoch:\n",
+    "        loss, new_chunk_size = lib.train_with_auto_virtual_batch(\n",
+    "            optimizer,\n",
+    "            loss_fn,\n",
+    "            lambda x: (apply_model(lib.TRAIN, x), Y_device[lib.TRAIN][x]),\n",
+    "            batch_idx,\n",
+    "            chunk_size or batch_size,\n",
+    "        )\n",
+    "        epoch_losses.append(loss.detach())\n",
+    "        if new_chunk_size and new_chunk_size < (chunk_size or batch_size):\n",
+    "            print('New chunk size:', chunk_size)\n",
+    "    epoch_losses = torch.stack(epoch_losses).tolist()\n",
+    "    print(f'[{lib.TRAIN}] loss = {round(sum(epoch_losses) / len(epoch_losses), 3)}')\n",
+    "\n",
+    "    metrics, predictions = evaluate([lib.VAL, lib.TEST])\n",
+    "    for k, v in metrics.items():\n",
+    "        training_log[k].append(v)\n",
+    "    progress.update(metrics[lib.VAL]['score'])\n",
+    "\n",
+    "    if progress.success:\n",
+    "        print('New best epoch!')\n",
+    "        #save_checkpoint(False)\n",
+    "\n",
+    "    elif progress.fail:\n",
+    "        break"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d835a6b9",
+   "metadata": {},
+   "source": [
+    "## Final result on finetuned model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "9419de73",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[test ] Accuracy = 0.921\n",
+      "[test_backdoor] Accuracy = 0.042\n"
+     ]
+    }
+   ],
+   "source": [
+    "metrics = evaluate(['test', 'test_backdoor'])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

File diff suppressed because it is too large
+ 1203 - 0
Notebooks/Defences/FinePruningFTT/Prune.ipynb


+ 713 - 0
Notebooks/Defences/FinePruningFTT/Train.ipynb

@@ -0,0 +1,713 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "ddde10d5",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/bart/Documents/School/MasterThesis/Development/tabular-backdoors/tabularbackdoor/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from sklearn.datasets import fetch_openml\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "from sklearn.metrics import accuracy_score, log_loss\n",
+    "from sklearn.preprocessing import LabelEncoder\n",
+    "\n",
+    "import os\n",
+    "import wget\n",
+    "from pathlib import Path\n",
+    "import shutil\n",
+    "import gzip\n",
+    "\n",
+    "from matplotlib import pyplot as plt\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.nn.init as nn_init\n",
+    "\n",
+    "import random\n",
+    "import math\n",
+    "\n",
+    "from FTtransformer.ft_transformer import Tokenizer, MultiheadAttention, Transformer, FTtransformer\n",
+    "from FTtransformer import lib\n",
+    "import zero\n",
+    "import json"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a375a0ee",
+   "metadata": {},
+   "source": [
+    "## Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "640f00b9",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "File already exists.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Experiment settings\n",
+    "EPOCHS = 50\n",
+    "RERUNS = 5 # How many times to redo the same setting\n",
+    "\n",
+    "# Backdoor settings\n",
+    "target=[\"Covertype\"]\n",
+    "backdoorFeatures = [\"Elevation\"]\n",
+    "backdoorTriggerValues = [4057]\n",
+    "targetLabel = 4\n",
+    "poisoningRates = [0.0005]\n",
+    "\n",
+    "DEVICE = 'cuda:0'\n",
+    "DATAPATH = \"data/covtypeFTT-1F-OOB/\"\n",
+    "# FTtransformer config\n",
+    "config = {\n",
+    "    'data': {\n",
+    "        'normalization': 'standard',\n",
+    "        'path': DATAPATH\n",
+    "    }, \n",
+    "    'model': {\n",
+    "        'activation': 'reglu', \n",
+    "        'attention_dropout': 0.03815883962184247, \n",
+    "        'd_ffn_factor': 1.333333333333333, \n",
+    "        'd_token': 424, \n",
+    "        'ffn_dropout': 0.2515503440562596, \n",
+    "        'initialization': 'kaiming', \n",
+    "        'n_heads': 8, \n",
+    "        'n_layers': 2, \n",
+    "        'prenormalization': True, \n",
+    "        'residual_dropout': 0.0, \n",
+    "        'token_bias': True, \n",
+    "        'kv_compression': None, \n",
+    "        'kv_compression_sharing': None\n",
+    "    }, \n",
+    "    'seed': 0, \n",
+    "    'training': {\n",
+    "        'batch_size': 1024, \n",
+    "        'eval_batch_size': 8192, \n",
+    "        'lr': 3.762989816330166e-05, \n",
+    "        'n_epochs': EPOCHS, \n",
+    "        'device': DEVICE, \n",
+    "        'optimizer': 'adamw', \n",
+    "        'patience': 16, \n",
+    "        'weight_decay': 0.0001239780004929955\n",
+    "    }\n",
+    "}\n",
+    "\n",
+    "\n",
+    "# Load dataset\n",
+    "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz\"\n",
+    "dataset_name = 'forestcover-type'\n",
+    "tmp_out = Path('./data/'+dataset_name+'.gz')\n",
+    "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')\n",
+    "out.parent.mkdir(parents=True, exist_ok=True)\n",
+    "if out.exists():\n",
+    "    print(\"File already exists.\")\n",
+    "else:\n",
+    "    print(\"Downloading file...\")\n",
+    "    wget.download(url, tmp_out.as_posix())\n",
+    "    with gzip.open(tmp_out, 'rb') as f_in:\n",
+    "        with open(out, 'wb') as f_out:\n",
+    "            shutil.copyfileobj(f_in, f_out)\n",
+    "\n",
+    "\n",
+    "# Setup data\n",
+    "cat_cols = [\n",
+    "    \"Wilderness_Area1\", \"Wilderness_Area2\", \"Wilderness_Area3\",\n",
+    "    \"Wilderness_Area4\", \"Soil_Type1\", \"Soil_Type2\", \"Soil_Type3\", \"Soil_Type4\",\n",
+    "    \"Soil_Type5\", \"Soil_Type6\", \"Soil_Type7\", \"Soil_Type8\", \"Soil_Type9\",\n",
+    "    \"Soil_Type10\", \"Soil_Type11\", \"Soil_Type12\", \"Soil_Type13\", \"Soil_Type14\",\n",
+    "    \"Soil_Type15\", \"Soil_Type16\", \"Soil_Type17\", \"Soil_Type18\", \"Soil_Type19\",\n",
+    "    \"Soil_Type20\", \"Soil_Type21\", \"Soil_Type22\", \"Soil_Type23\", \"Soil_Type24\",\n",
+    "    \"Soil_Type25\", \"Soil_Type26\", \"Soil_Type27\", \"Soil_Type28\", \"Soil_Type29\",\n",
+    "    \"Soil_Type30\", \"Soil_Type31\", \"Soil_Type32\", \"Soil_Type33\", \"Soil_Type34\",\n",
+    "    \"Soil_Type35\", \"Soil_Type36\", \"Soil_Type37\", \"Soil_Type38\", \"Soil_Type39\",\n",
+    "    \"Soil_Type40\"\n",
+    "]\n",
+    "\n",
+    "num_cols = [\n",
+    "    \"Elevation\", \"Aspect\", \"Slope\", \"Horizontal_Distance_To_Hydrology\",\n",
+    "    \"Vertical_Distance_To_Hydrology\", \"Horizontal_Distance_To_Roadways\",\n",
+    "    \"Hillshade_9am\", \"Hillshade_Noon\", \"Hillshade_3pm\",\n",
+    "    \"Horizontal_Distance_To_Fire_Points\"\n",
+    "]\n",
+    "\n",
+    "feature_columns = (\n",
+    "    num_cols + cat_cols + target)\n",
+    "\n",
+    "data = pd.read_csv(out, header=None, names=feature_columns)\n",
+    "data[\"Covertype\"] = data[\"Covertype\"] - 1 # Make sure output labels start at 0 instead of 1\n",
+    "\n",
+    "# Converts the train, valid, test and test_backdoor DataFrames to .npy files + info.json for FTtransformer\n",
+    "def convertDataForFTtransformer(train, valid, test, test_backdoor):\n",
+    "    outPath = DATAPATH\n",
+    "    \n",
+    "    # train\n",
+    "    np.save(outPath+\"N_train.npy\", train[num_cols].to_numpy(dtype='float32'))\n",
+    "    np.save(outPath+\"C_train.npy\", train[cat_cols].applymap(str).to_numpy())\n",
+    "    np.save(outPath+\"y_train.npy\", train[target].to_numpy(dtype=int).flatten())\n",
+    "    \n",
+    "    # val\n",
+    "    np.save(outPath+\"N_val.npy\", valid[num_cols].to_numpy(dtype='float32'))\n",
+    "    np.save(outPath+\"C_val.npy\", valid[cat_cols].applymap(str).to_numpy())\n",
+    "    np.save(outPath+\"y_val.npy\", valid[target].to_numpy(dtype=int).flatten())\n",
+    "    \n",
+    "    # test\n",
+    "    np.save(outPath+\"N_test.npy\", test[num_cols].to_numpy(dtype='float32'))\n",
+    "    np.save(outPath+\"C_test.npy\", test[cat_cols].applymap(str).to_numpy())\n",
+    "    np.save(outPath+\"y_test.npy\", test[target].to_numpy(dtype=int).flatten())\n",
+    "    \n",
+    "    # test_backdoor\n",
+    "    np.save(outPath+\"N_test_backdoor.npy\", test_backdoor[num_cols].to_numpy(dtype='float32'))\n",
+    "    np.save(outPath+\"C_test_backdoor.npy\", test_backdoor[cat_cols].applymap(str).to_numpy())\n",
+    "    np.save(outPath+\"y_test_backdoor.npy\", test_backdoor[target].to_numpy(dtype=int).flatten())\n",
+    "    \n",
+    "    # info.json\n",
+    "    info = {\n",
+    "        \"name\": \"covtype___0\",\n",
+    "        \"basename\": \"covtype\",\n",
+    "        \"split\": 0,\n",
+    "        \"task_type\": \"multiclass\",\n",
+    "        \"n_num_features\": len(num_cols),\n",
+    "        \"n_cat_features\": len(cat_cols),\n",
+    "        \"train_size\": len(train),\n",
+    "        \"val_size\": len(valid),\n",
+    "        \"test_size\": len(test),\n",
+    "        \"test_backdoor_size\": len(test_backdoor),\n",
+    "        \"n_classes\": 7\n",
+    "    }\n",
+    "    \n",
+    "    with open(outPath + 'info.json', 'w') as f:\n",
+    "        json.dump(info, f, indent = 4)\n",
+    "\n",
+    "# Experiment setup\n",
+    "def GenerateTrigger(df, poisoningRate, backdoorTriggerValues, targetLabel):\n",
+    "    rows_with_trigger = df.sample(frac=poisoningRate)\n",
+    "    rows_with_trigger[backdoorFeatures] = backdoorTriggerValues\n",
+    "    rows_with_trigger[target] = targetLabel\n",
+    "    return rows_with_trigger\n",
+    "\n",
+    "def GenerateBackdoorTrigger(df, backdoorTriggerValues, targetLabel):\n",
+    "    df[backdoorFeatures] = backdoorTriggerValues\n",
+    "    df[target] = targetLabel\n",
+    "    return df"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "58636b25",
+   "metadata": {},
+   "source": [
+    "## Prepare data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "a3466db3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "runIdx = 1\n",
+    "poisoningRate = poisoningRates[0]\n",
+    "# Load dataset\n",
+    "# Changes to output df will not influence input df\n",
+    "train_and_valid, test = train_test_split(data, stratify=data[target[0]], test_size=0.2, random_state=runIdx)\n",
+    "\n",
+    "# Apply backdoor to train and valid data\n",
+    "random.seed(runIdx)\n",
+    "train_and_valid_poisoned = GenerateTrigger(train_and_valid, poisoningRate, backdoorTriggerValues, targetLabel)\n",
+    "train_and_valid.update(train_and_valid_poisoned)\n",
+    "train_and_valid[target[0]] = train_and_valid[target[0]].astype(np.int64)\n",
+    "train_and_valid[cat_cols] = train_and_valid[cat_cols].astype(np.int64)\n",
+    "\n",
+    "# Create backdoored test version\n",
+    "# Also copy to not disturb clean test data\n",
+    "test_backdoor = test.copy()\n",
+    "\n",
+    "# Drop rows that already have the target label\n",
+    "test_backdoor = test_backdoor[test_backdoor[target[0]] != targetLabel]\n",
+    "\n",
+    "# Add backdoor to all test_backdoor samples\n",
+    "test_backdoor = GenerateBackdoorTrigger(test_backdoor, backdoorTriggerValues, targetLabel)\n",
+    "test_backdoor[target[0]] = test_backdoor[target[0]].astype(np.int64)\n",
+    "test_backdoor[cat_cols] = test_backdoor[cat_cols].astype(np.int64)\n",
+    "\n",
+    "# Split the (poisoned) train_and_valid data into training and validation sets\n",
+    "train, valid = train_test_split(train_and_valid, stratify=train_and_valid[target[0]], test_size=0.2, random_state=runIdx)\n",
+    "\n",
+    "# Prepare data for FT-transformer\n",
+    "convertDataForFTtransformer(train, valid, test, test_backdoor)\n",
+    "\n",
+    "checkpoint_path = 'FTtransformerCheckpoints/CovType_1F_OOB_' + str(poisoningRate) + \"-\" + str(runIdx) + \".pt\"\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "514cee9f",
+   "metadata": {},
+   "source": [
+    "## Setup model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "60f8c561",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using device: cuda:0\n",
+      "self.category_embeddings.weight.shape=torch.Size([88, 424])\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "zero.set_randomness(config['seed'])\n",
+    "dataset_dir = config['data']['path']\n",
+    "\n",
+    "D = lib.Dataset.from_dir(dataset_dir)\n",
+    "X = D.build_X(\n",
+    "    normalization=config['data'].get('normalization'),\n",
+    "    num_nan_policy='mean',\n",
+    "    cat_nan_policy='new',\n",
+    "    cat_policy=config['data'].get('cat_policy', 'indices'),\n",
+    "    cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),\n",
+    "    seed=config['seed'],\n",
+    ")\n",
+    "if not isinstance(X, tuple):\n",
+    "    X = (X, None)\n",
+    "\n",
+    "Y, y_info = D.build_y(config['data'].get('y_policy'))\n",
+    "\n",
+    "X = tuple(None if x is None else lib.to_tensors(x) for x in X)\n",
+    "Y = lib.to_tensors(Y)\n",
+    "device = torch.device(config['training']['device'])\n",
+    "print(\"Using device:\", config['training']['device'])\n",
+    "if device.type != 'cpu':\n",
+    "    X = tuple(\n",
+    "        None if x is None else {k: v.to(device) for k, v in x.items()} for x in X\n",
+    "    )\n",
+    "    Y_device = {k: v.to(device) for k, v in Y.items()}\n",
+    "else:\n",
+    "    Y_device = Y\n",
+    "X_num, X_cat = X\n",
+    "del X\n",
+    "if not D.is_multiclass:\n",
+    "    Y_device = {k: v.float() for k, v in Y_device.items()}\n",
+    "\n",
+    "train_size = D.size(lib.TRAIN)\n",
+    "batch_size = config['training']['batch_size']\n",
+    "epoch_size = math.ceil(train_size / batch_size)\n",
+    "eval_batch_size = config['training']['eval_batch_size']\n",
+    "chunk_size = None\n",
+    "\n",
+    "loss_fn = (\n",
+    "    F.binary_cross_entropy_with_logits\n",
+    "    if D.is_binclass\n",
+    "    else F.cross_entropy\n",
+    "    if D.is_multiclass\n",
+    "    else F.mse_loss\n",
+    ")\n",
+    "\n",
+    "model = Transformer(\n",
+    "    d_numerical=0 if X_num is None else X_num['train'].shape[1],\n",
+    "    categories=lib.get_categories(X_cat),\n",
+    "    d_out=D.info['n_classes'] if D.is_multiclass else 1,\n",
+    "    **config['model'],\n",
+    ").to(device)\n",
+    "\n",
+    "def needs_wd(name):\n",
+    "    return all(x not in name for x in ['tokenizer', '.norm', '.bias'])\n",
+    "\n",
+    "for x in ['tokenizer', '.norm', '.bias']:\n",
+    "    assert any(x in a for a in (b[0] for b in model.named_parameters()))\n",
+    "parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]\n",
+    "parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]\n",
+    "optimizer = lib.make_optimizer(\n",
+    "    config['training']['optimizer'],\n",
+    "    (\n",
+    "        [\n",
+    "            {'params': parameters_with_wd},\n",
+    "            {'params': parameters_without_wd, 'weight_decay': 0.0},\n",
+    "        ]\n",
+    "    ),\n",
+    "    config['training']['lr'],\n",
+    "    config['training']['weight_decay'],\n",
+    ")\n",
+    "\n",
+    "stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))\n",
+    "progress = zero.ProgressTracker(config['training']['patience'])\n",
+    "training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}\n",
+    "timer = zero.Timer()\n",
+    "output = \"Checkpoints\"\n",
+    "\n",
+    "def print_epoch_info():\n",
+    "    print(f'\\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')\n",
+    "    print(\n",
+    "        ' | '.join(\n",
+    "            f'{k} = {v}'\n",
+    "            for k, v in {\n",
+    "                'lr': lib.get_lr(optimizer),\n",
+    "                'batch_size': batch_size,\n",
+    "                'chunk_size': chunk_size,\n",
+    "            }.items()\n",
+    "        )\n",
+    "    )\n",
+    "\n",
+    "def apply_model(part, idx):\n",
+    "    return model(\n",
+    "        None if X_num is None else X_num[part][idx],\n",
+    "        None if X_cat is None else X_cat[part][idx],\n",
+    "    )\n",
+    "\n",
+    "@torch.no_grad()\n",
+    "def evaluate(parts):\n",
+    "    eval_batch_size = config['training']['eval_batch_size']\n",
+    "    model.eval()\n",
+    "    metrics = {}\n",
+    "    predictions = {}\n",
+    "    for part in parts:\n",
+    "        while eval_batch_size:\n",
+    "            try:\n",
+    "                predictions[part] = (\n",
+    "                    torch.cat(\n",
+    "                        [\n",
+    "                            apply_model(part, idx)\n",
+    "                            for idx in lib.IndexLoader(\n",
+    "                                D.size(part), eval_batch_size, False, device\n",
+    "                            )\n",
+    "                        ]\n",
+    "                    )\n",
+    "                    .cpu()\n",
+    "                    .numpy()\n",
+    "                )\n",
+    "            except RuntimeError as err:\n",
+    "                if not lib.is_oom_exception(err):\n",
+    "                    raise\n",
+    "                eval_batch_size //= 2\n",
+    "                print('New eval batch size:', eval_batch_size)\n",
+    "            else:\n",
+    "                break\n",
+    "        if not eval_batch_size:\n",
+    "            RuntimeError('Not enough memory even for eval_batch_size=1')\n",
+    "        metrics[part] = lib.calculate_metrics(\n",
+    "            D.info['task_type'],\n",
+    "            Y[part].numpy(),  # type: ignore[code]\n",
+    "            predictions[part],  # type: ignore[code]\n",
+    "            'logits',\n",
+    "            y_info,\n",
+    "        )\n",
+    "    for part, part_metrics in metrics.items():\n",
+    "        print(f'[{part:<5}]', lib.make_summary(part_metrics))\n",
+    "    return metrics, predictions\n",
+    "\n",
+    "def save_checkpoint(final):\n",
+    "    torch.save(\n",
+    "        {\n",
+    "            'model': model.state_dict(),\n",
+    "            'optimizer': optimizer.state_dict(),\n",
+    "            'stream': stream.state_dict(),\n",
+    "            'random_state': zero.get_random_state(),\n",
+    "        },\n",
+    "        checkpoint_path,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0eff4164",
+   "metadata": {},
+   "source": [
+    "## Train"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad5e4ccb",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      "  0%|                                                                         | 0/18200 [00:00<?, ?it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      ">>> Epoch 1 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "  2%|█▏                                                           | 364/18200 [01:20<1:04:40,  4.60it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.757\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      "  2%|█▏                                                          | 365/18200 [01:34<21:50:29,  4.41s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.748\n",
+      "[test ] Accuracy = 0.751\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 2 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "  4%|██▍                                                          | 728/18200 [02:54<1:03:36,  4.58it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.569\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      "  4%|██▍                                                         | 729/18200 [03:08<21:19:36,  4.39s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.786\n",
+      "[test ] Accuracy = 0.790\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 3 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "  6%|███▌                                                        | 1092/18200 [04:27<1:02:35,  4.56it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.501\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      "  6%|███▌                                                       | 1093/18200 [04:41<20:46:26,  4.37s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.816\n",
+      "[test ] Accuracy = 0.819\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 4 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "  8%|████▊                                                       | 1456/18200 [06:01<1:01:44,  4.52it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.451\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      "  8%|████▋                                                      | 1457/18200 [06:15<20:18:31,  4.37s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.836\n",
+      "[test ] Accuracy = 0.840\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 5 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 10%|██████▏                                                       | 1820/18200 [07:35<59:21,  4.60it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[train] loss = 0.413\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\r",
+      " 10%|█████▉                                                     | 1821/18200 [07:49<19:50:46,  4.36s/it]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[val  ] Accuracy = 0.856\n",
+      "[test ] Accuracy = 0.858\n",
+      "New best epoch!\n",
+      "\n",
+      ">>> Epoch 6 | 0:00:00\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " 10%|██████                                                      | 1826/18200 [07:50<4:09:46,  1.09it/s]"
+     ]
+    }
+   ],
+   "source": [
+    "zero.set_randomness(config['seed'])\n",
+    "\n",
+    "for epoch in stream.epochs(config['training']['n_epochs']):\n",
+    "    print(f'\\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())}')\n",
+    "    model.train()\n",
+    "    epoch_losses = []\n",
+    "    for batch_idx in epoch:\n",
+    "        loss, new_chunk_size = lib.train_with_auto_virtual_batch(\n",
+    "            optimizer,\n",
+    "            loss_fn,\n",
+    "            lambda x: (apply_model(lib.TRAIN, x), Y_device[lib.TRAIN][x]),\n",
+    "            batch_idx,\n",
+    "            chunk_size or batch_size,\n",
+    "        )\n",
+    "        epoch_losses.append(loss.detach())\n",
+    "        if new_chunk_size and new_chunk_size < (chunk_size or batch_size):\n",
+    "            print('New chunk size:', chunk_size)\n",
+    "    epoch_losses = torch.stack(epoch_losses).tolist()\n",
+    "    print(f'[{lib.TRAIN}] loss = {round(sum(epoch_losses) / len(epoch_losses), 3)}')\n",
+    "\n",
+    "    metrics, predictions = evaluate([lib.VAL, lib.TEST])\n",
+    "    for k, v in metrics.items():\n",
+    "        training_log[k].append(v)\n",
+    "    progress.update(metrics[lib.VAL]['score'])\n",
+    "\n",
+    "    if progress.success:\n",
+    "        print('New best epoch!')\n",
+    "        save_checkpoint(False)\n",
+    "\n",
+    "    elif progress.fail:\n",
+    "        break\n",
+    "\n",
+    "# Load best checkpoint\n",
+    "model.load_state_dict(torch.load(checkpoint_path)['model'])\n",
+    "metrics, predictions = evaluate(lib.PARTS)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9e99fd66",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

+ 0 - 0
Notebooks/Defences/FinePruningFTT/data/covtypeFTT-1F-OOB-finetune/.gitkeep


+ 0 - 0
Notebooks/Defences/FinePruningFTT/data/covtypeFTT-1F-OOB/.gitkeep


+ 2 - 2
README.md

@@ -70,6 +70,6 @@ tail -f output/triggersize/TabNet_CovType_1F_OOB.log
 
 Output logs are found in the `output/` folder. All logs end with a section `EASY COPY PASTE RESULTS:` where you can copy the resulting lists containing the `ASR` and `BA` for each run.
 
-### Run notebooks (e.g. Spectral Signatures defence)
+### Run notebooks (Defences and FeatureImportance calculations)
 
-See the `Notebooks/` folder for other (smaller or parts of) experiments in the form of notebooks. To run the defences, you must first run the appropiate `CreateModel` Notebook to create a backdoored model and dataset which can then be analyzed with the other Notebooks.
+See the `Notebooks/` folder for other (smaller or parts of) experiments in the form of notebooks. To run the defences, you must first run the appropriate `CreateModel` Notebook to create a backdoored model and dataset, which can then be analyzed with the other Notebooks. For the Fine-Pruning defence, there is a dedicated subfolder in the `Notebooks/Defences` folder with notebooks to train, prune and finetune FTT.

Some files were not shown because too many files changed in this diff