|
@@ -0,0 +1,1733 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "id": "769381d2",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import numpy as np\n",
|
|
|
+ "import pandas as pd\n",
|
|
|
+ "from sklearn.datasets import fetch_openml\n",
|
|
|
+ "from sklearn.model_selection import train_test_split\n",
|
|
|
+ "from sklearn.metrics import accuracy_score, log_loss\n",
|
|
|
+ "from sklearn.preprocessing import LabelEncoder\n",
|
|
|
+ "\n",
|
|
|
+ "import os\n",
|
|
|
+ "import wget\n",
|
|
|
+ "from pathlib import Path\n",
|
|
|
+ "import shutil\n",
|
|
|
+ "import gzip\n",
|
|
|
+ "\n",
|
|
|
+ "from matplotlib import pyplot as plt\n",
|
|
|
+ "import matplotlib.ticker as mtick\n",
|
|
|
+ "\n",
|
|
|
+ "import torch\n",
|
|
|
+ "import torch.nn as nn\n",
|
|
|
+ "import torch.nn.functional as F\n",
|
|
|
+ "import torch.nn.init as nn_init\n",
|
|
|
+ "import torch.nn.utils.prune as prune\n",
|
|
|
+ "\n",
|
|
|
+ "import random\n",
|
|
|
+ "import math\n",
|
|
|
+ "\n",
|
|
|
+ "from FTtransformer.ft_transformer import Tokenizer, MultiheadAttention, Transformer, FTtransformer\n",
|
|
|
+ "from FTtransformer import lib\n",
|
|
|
+ "import zero\n",
|
|
|
+ "import json\n",
|
|
|
+ "\n",
|
|
|
+ "from functools import partial\n",
|
|
|
+ "import pickle"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "5b9860e4",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Setup"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 2,
|
|
|
+ "id": "d575b960",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "File already exists.\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# Experiment settings\n",
|
|
|
+ "EPOCHS = 50\n",
|
|
|
+ "RERUNS = 5 # How many times to redo the same setting\n",
|
|
|
+ "\n",
|
|
|
+ "# Backdoor settings\n",
|
|
|
+ "target=[\"Covertype\"]\n",
|
|
|
+ "backdoorFeatures = [\"Elevation\"]\n",
|
|
|
+ "backdoorTriggerValues = [4057]\n",
|
|
|
+ "targetLabel = 4\n",
|
|
|
+ "poisoningRates = [0.0005]\n",
|
|
|
+ "\n",
|
|
|
+ "DEVICE = 'cuda:0'\n",
|
|
|
+ "DATAPATH = \"data/covtypeFTT-1F-OOB-finetune/\"\n",
|
|
|
+ "# FTtransformer config\n",
|
|
|
+ "config = {\n",
|
|
|
+ " 'data': {\n",
|
|
|
+ " 'normalization': 'standard',\n",
|
|
|
+ " 'path': DATAPATH\n",
|
|
|
+ " }, \n",
|
|
|
+ " 'model': {\n",
|
|
|
+ " 'activation': 'reglu', \n",
|
|
|
+ " 'attention_dropout': 0.03815883962184247, \n",
|
|
|
+ " 'd_ffn_factor': 1.333333333333333, \n",
|
|
|
+ " 'd_token': 424, \n",
|
|
|
+ " 'ffn_dropout': 0.2515503440562596, \n",
|
|
|
+ " 'initialization': 'kaiming', \n",
|
|
|
+ " 'n_heads': 8, \n",
|
|
|
+ " 'n_layers': 2, \n",
|
|
|
+ " 'prenormalization': True, \n",
|
|
|
+ " 'residual_dropout': 0.0, \n",
|
|
|
+ " 'token_bias': True, \n",
|
|
|
+ " 'kv_compression': None, \n",
|
|
|
+ " 'kv_compression_sharing': None\n",
|
|
|
+ " }, \n",
|
|
|
+ " 'seed': 0, \n",
|
|
|
+ " 'training': {\n",
|
|
|
+ " 'batch_size': 1024, \n",
|
|
|
+ " 'eval_batch_size': 1024, \n",
|
|
|
+ " 'lr': 3.762989816330166e-05, \n",
|
|
|
+ " 'n_epochs': EPOCHS, \n",
|
|
|
+ " 'device': DEVICE, \n",
|
|
|
+ " 'optimizer': 'adamw', \n",
|
|
|
+ " 'patience': 16, \n",
|
|
|
+ " 'weight_decay': 0.0001239780004929955\n",
|
|
|
+ " }\n",
|
|
|
+ "}\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "# Load dataset\n",
|
|
|
+ "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz\"\n",
|
|
|
+ "dataset_name = 'forestcover-type'\n",
|
|
|
+ "tmp_out = Path('./data/'+dataset_name+'.gz')\n",
|
|
|
+ "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')\n",
|
|
|
+ "out.parent.mkdir(parents=True, exist_ok=True)\n",
|
|
|
+ "if out.exists():\n",
|
|
|
+ " print(\"File already exists.\")\n",
|
|
|
+ "else:\n",
|
|
|
+ " print(\"Downloading file...\")\n",
|
|
|
+ " wget.download(url, tmp_out.as_posix())\n",
|
|
|
+ " with gzip.open(tmp_out, 'rb') as f_in:\n",
|
|
|
+ " with open(out, 'wb') as f_out:\n",
|
|
|
+ " shutil.copyfileobj(f_in, f_out)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "# Setup data\n",
|
|
|
+ "cat_cols = [\n",
|
|
|
+ " \"Wilderness_Area1\", \"Wilderness_Area2\", \"Wilderness_Area3\",\n",
|
|
|
+ " \"Wilderness_Area4\", \"Soil_Type1\", \"Soil_Type2\", \"Soil_Type3\", \"Soil_Type4\",\n",
|
|
|
+ " \"Soil_Type5\", \"Soil_Type6\", \"Soil_Type7\", \"Soil_Type8\", \"Soil_Type9\",\n",
|
|
|
+ " \"Soil_Type10\", \"Soil_Type11\", \"Soil_Type12\", \"Soil_Type13\", \"Soil_Type14\",\n",
|
|
|
+ " \"Soil_Type15\", \"Soil_Type16\", \"Soil_Type17\", \"Soil_Type18\", \"Soil_Type19\",\n",
|
|
|
+ " \"Soil_Type20\", \"Soil_Type21\", \"Soil_Type22\", \"Soil_Type23\", \"Soil_Type24\",\n",
|
|
|
+ " \"Soil_Type25\", \"Soil_Type26\", \"Soil_Type27\", \"Soil_Type28\", \"Soil_Type29\",\n",
|
|
|
+ " \"Soil_Type30\", \"Soil_Type31\", \"Soil_Type32\", \"Soil_Type33\", \"Soil_Type34\",\n",
|
|
|
+ " \"Soil_Type35\", \"Soil_Type36\", \"Soil_Type37\", \"Soil_Type38\", \"Soil_Type39\",\n",
|
|
|
+ " \"Soil_Type40\"\n",
|
|
|
+ "]\n",
|
|
|
+ "\n",
|
|
|
+ "num_cols = [\n",
|
|
|
+ " \"Elevation\", \"Aspect\", \"Slope\", \"Horizontal_Distance_To_Hydrology\",\n",
|
|
|
+ " \"Vertical_Distance_To_Hydrology\", \"Horizontal_Distance_To_Roadways\",\n",
|
|
|
+ " \"Hillshade_9am\", \"Hillshade_Noon\", \"Hillshade_3pm\",\n",
|
|
|
+ " \"Horizontal_Distance_To_Fire_Points\"\n",
|
|
|
+ "]\n",
|
|
|
+ "\n",
|
|
|
+ "feature_columns = (\n",
|
|
|
+ " num_cols + cat_cols + target)\n",
|
|
|
+ "\n",
|
|
|
+ "data = pd.read_csv(out, header=None, names=feature_columns)\n",
|
|
|
+ "data[\"Covertype\"] = data[\"Covertype\"] - 1 # Make sure output labels start at 0 instead of 1\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "# Converts train valid and test DFs to .npy files + info.json for FTtransformer\n",
|
|
|
+ "def convertDataForFTtransformer(train, valid, test, test_backdoor):\n",
|
|
|
+ " outPath = DATAPATH\n",
|
|
|
+ " \n",
|
|
|
+ " # train\n",
|
|
|
+ " np.save(outPath+\"N_train.npy\", train[num_cols].to_numpy(dtype='float32'))\n",
|
|
|
+ " np.save(outPath+\"C_train.npy\", train[cat_cols].applymap(str).to_numpy())\n",
|
|
|
+ " np.save(outPath+\"y_train.npy\", train[target].to_numpy(dtype=int).flatten())\n",
|
|
|
+ " \n",
|
|
|
+ " # val\n",
|
|
|
+ " np.save(outPath+\"N_val.npy\", valid[num_cols].to_numpy(dtype='float32'))\n",
|
|
|
+ " np.save(outPath+\"C_val.npy\", valid[cat_cols].applymap(str).to_numpy())\n",
|
|
|
+ " np.save(outPath+\"y_val.npy\", valid[target].to_numpy(dtype=int).flatten())\n",
|
|
|
+ " \n",
|
|
|
+ " # test\n",
|
|
|
+ " np.save(outPath+\"N_test.npy\", test[num_cols].to_numpy(dtype='float32'))\n",
|
|
|
+ " np.save(outPath+\"C_test.npy\", test[cat_cols].applymap(str).to_numpy())\n",
|
|
|
+ " np.save(outPath+\"y_test.npy\", test[target].to_numpy(dtype=int).flatten())\n",
|
|
|
+ " \n",
|
|
|
+ " # test_backdoor\n",
|
|
|
+ " np.save(outPath+\"N_test_backdoor.npy\", test_backdoor[num_cols].to_numpy(dtype='float32'))\n",
|
|
|
+ " np.save(outPath+\"C_test_backdoor.npy\", test_backdoor[cat_cols].applymap(str).to_numpy())\n",
|
|
|
+ " np.save(outPath+\"y_test_backdoor.npy\", test_backdoor[target].to_numpy(dtype=int).flatten())\n",
|
|
|
+ " \n",
|
|
|
+ " # info.json\n",
|
|
|
+ " info = {\n",
|
|
|
+ " \"name\": \"covtype___0\",\n",
|
|
|
+ " \"basename\": \"covtype\",\n",
|
|
|
+ " \"split\": 0,\n",
|
|
|
+ " \"task_type\": \"multiclass\",\n",
|
|
|
+ " \"n_num_features\": len(num_cols),\n",
|
|
|
+ " \"n_cat_features\": len(cat_cols),\n",
|
|
|
+ " \"train_size\": len(train),\n",
|
|
|
+ " \"val_size\": len(valid),\n",
|
|
|
+ " \"test_size\": len(test),\n",
|
|
|
+ " \"test_backdoor_size\": len(test_backdoor),\n",
|
|
|
+ " \"n_classes\": 7\n",
|
|
|
+ " }\n",
|
|
|
+ " \n",
|
|
|
+ " with open(outPath + 'info.json', 'w') as f:\n",
|
|
|
+ " json.dump(info, f, indent = 4)\n",
|
|
|
+ "\n",
|
|
|
+ "# Experiment setup\n",
|
|
|
+ "def GenerateTrigger(df, poisoningRate, backdoorTriggerValues, targetLabel):\n",
|
|
|
+ " rows_with_trigger = df.sample(frac=poisoningRate)\n",
|
|
|
+ " rows_with_trigger[backdoorFeatures] = backdoorTriggerValues\n",
|
|
|
+ " rows_with_trigger[target] = targetLabel\n",
|
|
|
+ " return rows_with_trigger\n",
|
|
|
+ "\n",
|
|
|
+ "def GenerateBackdoorTrigger(df, backdoorTriggerValues, targetLabel):\n",
|
|
|
+ " df[backdoorFeatures] = backdoorTriggerValues\n",
|
|
|
+ " df[target] = targetLabel\n",
|
|
|
+ " return df"
|
|
|
+ ]
|
|
|
+ },
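+ {
+ "cell_type": "markdown",
+ "id": "1a2b3c4d",
+ "metadata": {},
+ "source": [
+ "*Added illustration:* a minimal sketch of what `GenerateTrigger` above does: sample a fraction of rows, overwrite the trigger feature with the out-of-bounds value, and relabel them to the target class. The toy frame and its values are hypothetical and not part of the experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2b3c4d5e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "# Hypothetical toy frame standing in for the Forest Cover data\n",
+ "toy = pd.DataFrame({'Elevation': [1859, 2804, 3250, 3610], 'Covertype': [0, 1, 2, 3]})\n",
+ "poisoned = toy.sample(frac=0.5, random_state=0) # analogue of poisoningRate\n",
+ "poisoned['Elevation'] = 4057 # backdoorTriggerValues[0]\n",
+ "poisoned['Covertype'] = 4 # targetLabel\n",
+ "toy.update(poisoned) # same in-place update as used on train_and_valid below\n",
+ "toy['Covertype'] = toy['Covertype'].astype(int) # update() upcasts ints to float, hence the astype(np.int64) casts below\n",
+ "print(toy)"
+ ]
+ },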
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "d9a5a67a",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Prepare finetune data"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 3,
|
|
|
+ "id": "fa253ec3",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "92963\n",
|
|
|
+ "18592\n",
|
|
|
+ "4648\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "runIdx = 1\n",
|
|
|
+ "poisoningRate = poisoningRates[0]\n",
|
|
|
+ "\n",
|
|
|
+ "# Do same datageneration as during initial backdoor training so we get the same test set\n",
|
|
|
+ "\n",
|
|
|
+ "# Load dataset\n",
|
|
|
+ "# Changes to output df will not influence input df\n",
|
|
|
+ "train_and_valid, test = train_test_split(data, stratify=data[target[0]], test_size=0.2, random_state=runIdx)\n",
|
|
|
+ "\n",
|
|
|
+ "# Apply backdoor to train and valid data\n",
|
|
|
+ "random.seed(runIdx)\n",
|
|
|
+ "train_and_valid_poisoned = GenerateTrigger(train_and_valid, poisoningRate, backdoorTriggerValues, targetLabel)\n",
|
|
|
+ "train_and_valid.update(train_and_valid_poisoned)\n",
|
|
|
+ "train_and_valid[target[0]] = train_and_valid[target[0]].astype(np.int64)\n",
|
|
|
+ "train_and_valid[cat_cols] = train_and_valid[cat_cols].astype(np.int64)\n",
|
|
|
+ "\n",
|
|
|
+ "# Create backdoored test version\n",
|
|
|
+ "# Also copy to not disturb clean test data\n",
|
|
|
+ "test_backdoor = test.copy()\n",
|
|
|
+ "\n",
|
|
|
+ "# Drop rows that already have the target label\n",
|
|
|
+ "test_backdoor = test_backdoor[test_backdoor[target[0]] != targetLabel]\n",
|
|
|
+ "\n",
|
|
|
+ "# Add backdoor to all test_backdoor samples\n",
|
|
|
+ "test_backdoor = GenerateBackdoorTrigger(test_backdoor, backdoorTriggerValues, targetLabel)\n",
|
|
|
+ "test_backdoor[target[0]] = test_backdoor[target[0]].astype(np.int64)\n",
|
|
|
+ "test_backdoor[cat_cols] = test_backdoor[cat_cols].astype(np.int64)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "# Now split the test set into different parts: ~20k for finetuning (train+val) and 20k for defence evaluation\n",
|
|
|
+ "finetune_train_val, finetune_test = train_test_split(test, stratify=test[target[0]], test_size=0.8, random_state=runIdx)\n",
|
|
|
+ "# Train: ~16k, val: ~4k\n",
|
|
|
+ "finetune_train, finetune_val = train_test_split(finetune_train_val, stratify=finetune_train_val[target[0]], test_size=0.2, random_state=runIdx)\n",
|
|
|
+ "\n",
|
|
|
+ "print(len(finetune_test))\n",
|
|
|
+ "print(len(finetune_train))\n",
|
|
|
+ "print(len(finetune_val))\n",
|
|
|
+ "\n",
|
|
|
+ "convertDataForFTtransformer(finetune_train, finetune_val, finetune_test, test_backdoor)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "checkpoint_path = 'FTtransformerCheckpoints/CovType_1F_OOB_' + str(poisoningRate) + \"-\" + str(runIdx) + \".pt\"\n"
|
|
|
+ ]
|
|
|
+ },
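+ {
+ "cell_type": "markdown",
+ "id": "3c4d5e6f",
+ "metadata": {},
+ "source": [
+ "*Added sanity check:* the three finetune parts should exactly partition the clean test set (so the defence never touches training data), and every `test_backdoor` row should carry the target label. Only variables defined above are used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4d5e6f7a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The finetune splits must partition the clean test set\n",
+ "assert len(finetune_train) + len(finetune_val) + len(finetune_test) == len(test)\n",
+ "# Every backdoored test row was relabelled to the target class\n",
+ "assert (test_backdoor[target[0]] == targetLabel).all()\n",
+ "print('splits OK:', len(finetune_train), len(finetune_val), len(finetune_test))"
+ ]
+ },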
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "3bd019f0",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Setup model"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 4,
|
|
|
+ "id": "3955ebdc",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "DATAPATH = \"data/covtypeFTT-1F-OOB/\"\n",
|
|
|
+ "config = {\n",
|
|
|
+ " 'data': {\n",
|
|
|
+ " 'normalization': 'standard',\n",
|
|
|
+ " 'path': DATAPATH\n",
|
|
|
+ " }, \n",
|
|
|
+ " 'model': {\n",
|
|
|
+ " 'activation': 'reglu', \n",
|
|
|
+ " 'attention_dropout': 0.03815883962184247, \n",
|
|
|
+ " 'd_ffn_factor': 1.333333333333333, \n",
|
|
|
+ " 'd_token': 424, \n",
|
|
|
+ " 'ffn_dropout': 0.2515503440562596, \n",
|
|
|
+ " 'initialization': 'kaiming', \n",
|
|
|
+ " 'n_heads': 8, \n",
|
|
|
+ " 'n_layers': 2, \n",
|
|
|
+ " 'prenormalization': True, \n",
|
|
|
+ " 'residual_dropout': 0.0, \n",
|
|
|
+ " 'token_bias': True, \n",
|
|
|
+ " 'kv_compression': None, \n",
|
|
|
+ " 'kv_compression_sharing': None\n",
|
|
|
+ " }, \n",
|
|
|
+ " 'seed': 0, \n",
|
|
|
+ " 'training': {\n",
|
|
|
+ " 'batch_size': 1024, \n",
|
|
|
+ " 'eval_batch_size': 1024, \n",
|
|
|
+ " 'lr': 3.762989816330166e-05, \n",
|
|
|
+ " 'n_epochs': EPOCHS, \n",
|
|
|
+ " 'device': DEVICE, \n",
|
|
|
+ " 'optimizer': 'adamw', \n",
|
|
|
+ " 'patience': 16, \n",
|
|
|
+ " 'weight_decay': 0.0001239780004929955\n",
|
|
|
+ " }\n",
|
|
|
+ "}"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "id": "2f51f794",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Using device: cuda:0\n",
|
|
|
+ "self.category_embeddings.weight.shape=torch.Size([88, 424])\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "\n",
|
|
|
+ "zero.set_randomness(config['seed'])\n",
|
|
|
+ "dataset_dir = config['data']['path']\n",
|
|
|
+ "\n",
|
|
|
+ "D = lib.Dataset.from_dir(dataset_dir)\n",
|
|
|
+ "X = D.build_X(\n",
|
|
|
+ " normalization=config['data'].get('normalization'),\n",
|
|
|
+ " num_nan_policy='mean',\n",
|
|
|
+ " cat_nan_policy='new',\n",
|
|
|
+ " cat_policy=config['data'].get('cat_policy', 'indices'),\n",
|
|
|
+ " cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),\n",
|
|
|
+ " seed=config['seed'],\n",
|
|
|
+ ")\n",
|
|
|
+ "if not isinstance(X, tuple):\n",
|
|
|
+ " X = (X, None)\n",
|
|
|
+ "\n",
|
|
|
+ "Y, y_info = D.build_y(config['data'].get('y_policy'))\n",
|
|
|
+ "\n",
|
|
|
+ "X = tuple(None if x is None else lib.to_tensors(x) for x in X)\n",
|
|
|
+ "Y = lib.to_tensors(Y)\n",
|
|
|
+ "device = torch.device(config['training']['device'])\n",
|
|
|
+ "print(\"Using device:\", config['training']['device'])\n",
|
|
|
+ "if device.type != 'cpu':\n",
|
|
|
+ " X = tuple(\n",
|
|
|
+ " None if x is None else {k: v.to(device) for k, v in x.items()} for x in X\n",
|
|
|
+ " )\n",
|
|
|
+ " Y_device = {k: v.to(device) for k, v in Y.items()}\n",
|
|
|
+ "else:\n",
|
|
|
+ " Y_device = Y\n",
|
|
|
+ "X_num, X_cat = X\n",
|
|
|
+ "del X\n",
|
|
|
+ "if not D.is_multiclass:\n",
|
|
|
+ " Y_device = {k: v.float() for k, v in Y_device.items()}\n",
|
|
|
+ "\n",
|
|
|
+ "train_size = D.size(lib.TRAIN)\n",
|
|
|
+ "batch_size = config['training']['batch_size']\n",
|
|
|
+ "epoch_size = math.ceil(train_size / batch_size)\n",
|
|
|
+ "eval_batch_size = config['training']['eval_batch_size']\n",
|
|
|
+ "chunk_size = None\n",
|
|
|
+ "\n",
|
|
|
+ "loss_fn = (\n",
|
|
|
+ " F.binary_cross_entropy_with_logits\n",
|
|
|
+ " if D.is_binclass\n",
|
|
|
+ " else F.cross_entropy\n",
|
|
|
+ " if D.is_multiclass\n",
|
|
|
+ " else F.mse_loss\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "model = Transformer(\n",
|
|
|
+ " d_numerical=0 if X_num is None else X_num['train'].shape[1],\n",
|
|
|
+ " categories=lib.get_categories(X_cat),\n",
|
|
|
+ " d_out=D.info['n_classes'] if D.is_multiclass else 1,\n",
|
|
|
+ " **config['model'],\n",
|
|
|
+ ").to(device)\n",
|
|
|
+ "\n",
|
|
|
+ "def needs_wd(name):\n",
|
|
|
+ " return all(x not in name for x in ['tokenizer', '.norm', '.bias'])\n",
|
|
|
+ "\n",
|
|
|
+ "for x in ['tokenizer', '.norm', '.bias']:\n",
|
|
|
+ " assert any(x in a for a in (b[0] for b in model.named_parameters()))\n",
|
|
|
+ "parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]\n",
|
|
|
+ "parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]\n",
|
|
|
+ "optimizer = lib.make_optimizer(\n",
|
|
|
+ " config['training']['optimizer'],\n",
|
|
|
+ " (\n",
|
|
|
+ " [\n",
|
|
|
+ " {'params': parameters_with_wd},\n",
|
|
|
+ " {'params': parameters_without_wd, 'weight_decay': 0.0},\n",
|
|
|
+ " ]\n",
|
|
|
+ " ),\n",
|
|
|
+ " config['training']['lr'],\n",
|
|
|
+ " config['training']['weight_decay'],\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))\n",
|
|
|
+ "progress = zero.ProgressTracker(config['training']['patience'])\n",
|
|
|
+ "training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}\n",
|
|
|
+ "timer = zero.Timer()\n",
|
|
|
+ "output = \"Checkpoints\"\n",
|
|
|
+ "\n",
|
|
|
+ "def print_epoch_info():\n",
|
|
|
+ " print(f'\\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')\n",
|
|
|
+ " print(\n",
|
|
|
+ " ' | '.join(\n",
|
|
|
+ " f'{k} = {v}'\n",
|
|
|
+ " for k, v in {\n",
|
|
|
+ " 'lr': lib.get_lr(optimizer),\n",
|
|
|
+ " 'batch_size': batch_size,\n",
|
|
|
+ " 'chunk_size': chunk_size,\n",
|
|
|
+ " }.items()\n",
|
|
|
+ " )\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ "def apply_model(part, idx):\n",
|
|
|
+ " return model(\n",
|
|
|
+ " None if X_num is None else X_num[part][idx],\n",
|
|
|
+ " None if X_cat is None else X_cat[part][idx],\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ "@torch.no_grad()\n",
|
|
|
+ "def evaluate(parts):\n",
|
|
|
+ " eval_batch_size = config['training']['eval_batch_size']\n",
|
|
|
+ " model.eval()\n",
|
|
|
+ " metrics = {}\n",
|
|
|
+ " predictions = {}\n",
|
|
|
+ " for part in parts:\n",
|
|
|
+ " while eval_batch_size:\n",
|
|
|
+ " try:\n",
|
|
|
+ " predictions[part] = (\n",
|
|
|
+ " torch.cat(\n",
|
|
|
+ " [\n",
|
|
|
+ " apply_model(part, idx)\n",
|
|
|
+ " for idx in lib.IndexLoader(\n",
|
|
|
+ " D.size(part), eval_batch_size, False, device\n",
|
|
|
+ " )\n",
|
|
|
+ " ]\n",
|
|
|
+ " )\n",
|
|
|
+ " .cpu()\n",
|
|
|
+ " .numpy()\n",
|
|
|
+ " )\n",
|
|
|
+ " except RuntimeError as err:\n",
|
|
|
+ " if not lib.is_oom_exception(err):\n",
|
|
|
+ " raise\n",
|
|
|
+ " eval_batch_size //= 2\n",
|
|
|
+ " print('New eval batch size:', eval_batch_size)\n",
|
|
|
+ " else:\n",
|
|
|
+ " break\n",
|
|
|
+ " if not eval_batch_size:\n",
|
|
|
+ " RuntimeError('Not enough memory even for eval_batch_size=1')\n",
|
|
|
+ " metrics[part] = lib.calculate_metrics(\n",
|
|
|
+ " D.info['task_type'],\n",
|
|
|
+ " Y[part].numpy(), # type: ignore[code]\n",
|
|
|
+ " predictions[part], # type: ignore[code]\n",
|
|
|
+ " 'logits',\n",
|
|
|
+ " y_info,\n",
|
|
|
+ " )\n",
|
|
|
+ " for part, part_metrics in metrics.items():\n",
|
|
|
+ " print(f'[{part:<5}]', lib.make_summary(part_metrics))\n",
|
|
|
+ " return metrics, predictions\n",
|
|
|
+ "\n",
|
|
|
+ "def save_checkpoint(final):\n",
|
|
|
+ " torch.save(\n",
|
|
|
+ " {\n",
|
|
|
+ " 'model': model.state_dict(),\n",
|
|
|
+ " 'optimizer': optimizer.state_dict(),\n",
|
|
|
+ " 'stream': stream.state_dict(),\n",
|
|
|
+ " 'random_state': zero.get_random_state(),\n",
|
|
|
+ " },\n",
|
|
|
+ " checkpoint_path,\n",
|
|
|
+ " )"
|
|
|
+ ]
|
|
|
+ },
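+ {
+ "cell_type": "markdown",
+ "id": "5e6f7a8b",
+ "metadata": {},
+ "source": [
+ "*Added illustration:* the optimizer above puts `tokenizer`, `.norm`, and `.bias` parameters in a second group with `weight_decay=0.0`. A minimal sketch of the same two-group pattern with plain `torch.optim.AdamW` (toy module; names and sizes are hypothetical):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6f7a8b9c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from collections import OrderedDict\n",
+ "\n",
+ "net = nn.Sequential(OrderedDict(linear=nn.Linear(4, 4), norm=nn.LayerNorm(4)))\n",
+ "def no_wd(n): # mirrors needs_wd above, inverted\n",
+ " return 'norm' in n or n.endswith('bias')\n",
+ "decay = [p for n, p in net.named_parameters() if not no_wd(n)]\n",
+ "no_decay = [p for n, p in net.named_parameters() if no_wd(n)]\n",
+ "opt = torch.optim.AdamW(\n",
+ " [{'params': decay}, {'params': no_decay, 'weight_decay': 0.0}],\n",
+ " lr=3.76e-05, weight_decay=1.24e-04,\n",
+ ")\n",
+ "print([len(g['params']) for g in opt.param_groups]) # [1, 3]"
+ ]
+ },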
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "214a2935",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Load model"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "id": "3be456cc",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[test ] Accuracy = 0.954\n",
|
|
|
+ "[test_backdoor] Accuracy = 0.997\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "zero.set_randomness(config['seed'])\n",
|
|
|
+ "\n",
|
|
|
+ "# Load best checkpoint\n",
|
|
|
+ "model.load_state_dict(torch.load(checkpoint_path)['model'])\n",
|
|
|
+ "metrics, predictions = evaluate(['test', 'test_backdoor'])"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "c87fb163",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "# Save activations"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 7,
|
|
|
+ "id": "146c8957",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "registered: layers.0.attention.W_q : Linear(in_features=424, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.0.attention.W_k : Linear(in_features=424, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.0.attention.W_v : Linear(in_features=424, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.0.attention.W_out : Linear(in_features=424, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.0.linear0 : Linear(in_features=424, out_features=1130, bias=True)\n",
|
|
|
+ "registered: layers.0.linear1 : Linear(in_features=565, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.1.attention.W_q : Linear(in_features=424, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.1.attention.W_k : Linear(in_features=424, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.1.attention.W_v : Linear(in_features=424, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.1.attention.W_out : Linear(in_features=424, out_features=424, bias=True)\n",
|
|
|
+ "registered: layers.1.linear0 : Linear(in_features=424, out_features=1130, bias=True)\n",
|
|
|
+ "registered: layers.1.linear1 : Linear(in_features=565, out_features=424, bias=True)\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "activations_out = {}\n",
|
|
|
+ "count = 0\n",
|
|
|
+ "fails = 0\n",
|
|
|
+ "def save_activation(name, mod, inp, out):\n",
|
|
|
+ " if name not in activations_out:\n",
|
|
|
+ " activations_out[name] = out.cpu().detach().numpy()\n",
|
|
|
+ " \n",
|
|
|
+ " global fails\n",
|
|
|
+ " # Will fail if dataset not divisiable by batch size, try except to skip the last iteration\n",
|
|
|
+ " try:\n",
|
|
|
+ " # Save the activations for the input neurons\n",
|
|
|
+ " activations_out[name] += out.cpu().detach().numpy()\n",
|
|
|
+ " \n",
|
|
|
+ " if \"layers.0.linear0\" in name:\n",
|
|
|
+ " global count\n",
|
|
|
+ " count += 1\n",
|
|
|
+ " except:\n",
|
|
|
+ " fails+=1\n",
|
|
|
+ " \n",
|
|
|
+ "hooks = []\n",
|
|
|
+ "for name, m in model.named_modules():\n",
|
|
|
+ " #print(name) # -> tabnet.final_mapping is the layer we are interested in\n",
|
|
|
+ " if \"W_\" in name or \"linear\" in name:\n",
|
|
|
+ " print(\"registered:\", name, \":\", m)\n",
|
|
|
+ " hooks.append(m.register_forward_hook(partial(save_activation, name)))"
|
|
|
+ ]
|
|
|
+ },
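+ {
+ "cell_type": "markdown",
+ "id": "7a8b9c0d",
+ "metadata": {},
+ "source": [
+ "*Added illustration:* a self-contained sketch (toy layer, not part of the experiment) of the forward-hook pattern used above. A hook registered with `register_forward_hook` receives `(module, input, output)` on every forward pass, and `partial` binds the layer name as the first argument."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8b9c0d1e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from functools import partial\n",
+ "\n",
+ "toy_acts = {}\n",
+ "def toy_hook(name, mod, inp, out):\n",
+ " toy_acts[name] = out.detach()\n",
+ "\n",
+ "toy_layer = nn.Linear(4, 2)\n",
+ "handle = toy_layer.register_forward_hook(partial(toy_hook, 'toy'))\n",
+ "_ = toy_layer(torch.randn(3, 4)) # the forward pass triggers the hook\n",
+ "handle.remove() # detach the hook when done, as in the cells below\n",
+ "print(toy_acts['toy'].shape) # torch.Size([3, 2])"
+ ]
+ },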
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 8,
|
|
|
+ "id": "54234e0e",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "0\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "print(len(activations_out))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 9,
|
|
|
+ "id": "9351dbce",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[test ] Accuracy = 0.954\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "_ = evaluate(['test'])"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 10,
|
|
|
+ "id": "09857b48",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "for hook in hooks:\n",
|
|
|
+ " hook.remove()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 11,
|
|
|
+ "id": "6f6bf9ee",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "113\n",
|
|
|
+ "12\n",
|
|
|
+ "12\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "print(count)\n",
|
|
|
+ "\n",
|
|
|
+ "# fails should be equal to number of layers (12), or 0 if data is dividable by batch size\n",
|
|
|
+ "print(len(activations_out))\n",
|
|
|
+ "print(fails)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 12,
|
|
|
+ "id": "b796ee9a",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Calculate mean activation value (although not really needed for ranking)\n",
|
|
|
+ "for x in activations_out:\n",
|
|
|
+ " activations_out[x] = activations_out[x]/count"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 13,
|
|
|
+ "id": "a9dc87ce",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "layers.0.attention.W_q\n",
|
|
|
+ "(1024, 55, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.0.attention.W_k\n",
|
|
|
+ "(1024, 55, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.0.attention.W_v\n",
|
|
|
+ "(1024, 55, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.0.attention.W_out\n",
|
|
|
+ "(1024, 55, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.0.linear0\n",
|
|
|
+ "(1024, 55, 1130)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.0.linear1\n",
|
|
|
+ "(1024, 55, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.1.attention.W_q\n",
|
|
|
+ "(1024, 1, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.1.attention.W_k\n",
|
|
|
+ "(1024, 55, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.1.attention.W_v\n",
|
|
|
+ "(1024, 55, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.1.attention.W_out\n",
|
|
|
+ "(1024, 1, 424)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.1.linear0\n",
|
|
|
+ "(1024, 1, 1130)\n",
|
|
|
+ "\n",
|
|
|
+ "layers.1.linear1\n",
|
|
|
+ "(1024, 1, 424)\n",
|
|
|
+ "\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "for x in activations_out:\n",
|
|
|
+ " print(x)\n",
|
|
|
+ " print(activations_out[x].shape)\n",
|
|
|
+ " print()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 14,
|
|
|
+ "id": "ecee2260",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Average over batch and second dimension\n",
|
|
|
+ "for x in activations_out:\n",
|
|
|
+ " activations_out[x] = activations_out[x].mean(axis=0).mean(axis=0)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 15,
|
|
|
+ "id": "0ccc53f7",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "layers.0.attention.W_q\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.0.attention.W_k\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.0.attention.W_v\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.0.attention.W_out\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.0.linear0\n",
|
|
|
+ "(1130,)\n",
|
|
|
+ "layers.0.linear1\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.1.attention.W_q\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.1.attention.W_k\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.1.attention.W_v\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.1.attention.W_out\n",
|
|
|
+ "(424,)\n",
|
|
|
+ "layers.1.linear0\n",
|
|
|
+ "(1130,)\n",
|
|
|
+ "layers.1.linear1\n",
|
|
|
+ "(424,)\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "for x in activations_out:\n",
|
|
|
+ " print(x)\n",
|
|
|
+ " print(activations_out[x].shape)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 16,
|
|
|
+ "id": "1beca88e",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[test ] Accuracy = 0.954\n",
|
|
|
+ "[test_backdoor] Accuracy = 0.997\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "metrics = evaluate(['test', 'test_backdoor'])"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 17,
|
|
|
+ "id": "3e8f4a93",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "0.9974191629339306\n",
|
|
|
+ "0.9541836269287368\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "print(metrics[0]['test_backdoor']['accuracy'])\n",
|
|
|
+ "print(metrics[0]['test']['accuracy'])"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 18,
|
|
|
+ "id": "67f9462d",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Argsort activations for each layer\n",
|
|
|
+ "argsortActivations_out = {}\n",
|
|
|
+ "for n in activations_out:\n",
|
|
|
+ " argsortActivations_out[n] = np.argsort(activations_out[n])"
|
|
|
+ ]
|
|
|
+ },
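+ {
+ "cell_type": "markdown",
+ "id": "9c0d1e2f",
+ "metadata": {},
+ "source": [
+ "*Added reminder (illustrative values):* `np.argsort` returns the indices that would sort the array ascending, so position 0 holds the index of the least-activated neuron."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0d1e2f3a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "acts = np.array([0.9, 0.1, 0.5]) # hypothetical mean activations\n",
+ "print(np.argsort(acts)) # [1 2 0]: neuron 1 is least active, neuron 0 most active"
+ ]
+ },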
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 19,
|
|
|
+ "id": "890bbbda",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "layers.0.attention.W_q.weight torch.Size([424, 424])\n",
|
|
|
+ "layers.0.attention.W_k.weight torch.Size([424, 424])\n",
|
|
|
+ "layers.0.attention.W_v.weight torch.Size([424, 424])\n",
|
|
|
+ "layers.0.attention.W_out.weight torch.Size([424, 424])\n",
|
|
|
+ "layers.0.linear0.weight torch.Size([1130, 424])\n",
|
|
|
+ "layers.0.linear1.weight torch.Size([424, 565])\n",
|
|
|
+ "layers.1.attention.W_q.weight torch.Size([424, 424])\n",
|
|
|
+ "layers.1.attention.W_k.weight torch.Size([424, 424])\n",
|
|
|
+ "layers.1.attention.W_v.weight torch.Size([424, 424])\n",
|
|
|
+ "layers.1.attention.W_out.weight torch.Size([424, 424])\n",
|
|
|
+ "layers.1.linear0.weight torch.Size([1130, 424])\n",
|
|
|
+ "layers.1.linear1.weight torch.Size([424, 565])\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "for name, m in model.named_parameters():\n",
|
|
|
+ " if \"W_\" in name or \"linear\" in name:\n",
|
|
|
+ " if \"weight\" in name:\n",
|
|
|
+ " print(name, m.shape)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "76a2ac3b",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Prune"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 20,
|
|
|
+ "id": "f627749f",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def pruneWithTreshold(argsortActivations, name, th=1, transpose=False, dim2=1):\n",
|
|
|
+ " x = torch.tensor(argsortActivations[name].copy())\n",
|
|
|
+ " x[x>=th] = 99999\n",
|
|
|
+ " x[x<th] = 0\n",
|
|
|
+ " x[x==99999] = 1\n",
|
|
|
+ " \n",
|
|
|
+ " b = np.stack((x,) * dim2, axis=-1)\n",
|
|
|
+ " \n",
|
|
|
+ " if transpose:\n",
|
|
|
+ " b = torch.tensor(b.T)\n",
|
|
|
+ " else:\n",
|
|
|
+ " b = torch.tensor(b)\n",
|
|
|
+ " \n",
|
|
|
+ " #print(b.shape)\n",
|
|
|
+ " return b"
|
|
|
+ ]
|
|
|
+ },
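+ {
+ "cell_type": "markdown",
+ "id": "1e2f3a4b",
+ "metadata": {},
+ "source": [
+ "*Added walk-through (hypothetical 5-neuron example):* in `pruneWithThreshold`, argsort values below `th` become 0 (pruned) and the rest become 1 (kept), tiled `dim2` times along the last axis to match the weight shape."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2f3a4b5c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "argsort_demo = {'demo': np.array([3, 0, 4, 1, 2])} # hypothetical argsort output\n",
+ "demo_mask = pruneWithThreshold(argsort_demo, 'demo', th=2, dim2=3)\n",
+ "print(demo_mask.shape) # torch.Size([5, 3])\n",
+ "print(demo_mask[:, 0]) # tensor([1, 0, 1, 0, 1]): values < 2 were zeroed"
+ ]
+ },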
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 21,
|
|
|
+ "id": "1e059fd0",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": false
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[test ] Accuracy = 0.702\n",
|
|
|
+ "[test_backdoor] Accuracy = 0.017\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "i = 212 # obtained from \"Prune\" notebook\n",
|
|
|
+ "\n",
|
|
|
+ " \n",
|
|
|
+ "prune.custom_from_mask(\n",
|
|
|
+ " module = model.layers[0].linear0,\n",
|
|
|
+ " name = 'weight',\n",
|
|
|
+ " mask = pruneWithTreshold(argsortActivations_out, \"layers.0.linear0\", i, False, 424).to(\"cuda:0\")\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "prune.custom_from_mask(\n",
|
|
|
+ " module = model.layers[0].linear1,\n",
|
|
|
+ " name = 'weight',\n",
|
|
|
+ " mask = pruneWithTreshold(argsortActivations_out, \"layers.0.linear1\", i, False, 565).to(\"cuda:0\")\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "prune.custom_from_mask(\n",
|
|
|
+ " module = model.layers[1].linear0,\n",
|
|
|
+ " name = 'weight',\n",
|
|
|
+ " mask = pruneWithTreshold(argsortActivations_out, \"layers.1.linear0\", i, False, 424).to(\"cuda:0\")\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "prune.custom_from_mask(\n",
|
|
|
+ " module = model.layers[1].linear1,\n",
|
|
|
+ " name = 'weight',\n",
|
|
|
+ " mask = pruneWithTreshold(argsortActivations_out, \"layers.1.linear1\", i, False, 565).to(\"cuda:0\")\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "metrics = evaluate(['test', 'test_backdoor'])"
|
|
|
+ ]
|
|
|
+ },
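+ {
+ "cell_type": "markdown",
+ "id": "3a4b5c6d",
+ "metadata": {},
+ "source": [
+ "*Added illustration:* a self-contained sketch (toy layer, hypothetical mask) of `prune.custom_from_mask`. It re-parameterises the layer with a `weight_orig` parameter and a fixed `weight_mask` buffer; the effective `weight` is their elementwise product."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4b5c6d7e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.utils.prune as prune\n",
+ "\n",
+ "lin = nn.Linear(3, 2)\n",
+ "toy_mask = torch.tensor([[1., 1., 1.], [0., 0., 0.]]) # keep output neuron 0, prune neuron 1\n",
+ "prune.custom_from_mask(lin, name='weight', mask=toy_mask)\n",
+ "print(lin.weight[1]) # the pruned neuron's weights are all zero\n",
+ "print(hasattr(lin, 'weight_orig'), hasattr(lin, 'weight_mask')) # True True"
+ ]
+ },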
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "48071b83",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Finetune"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 22,
|
|
|
+ "id": "4a7505f5",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Using device: cuda:0\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "DATAPATH = \"data/covtypeFTT-1F-OOB-finetune/\"\n",
|
|
|
+ "# FTtransformer config\n",
|
|
|
+ "config = {\n",
|
|
|
+ " 'data': {\n",
|
|
|
+ " 'normalization': 'standard',\n",
|
|
|
+ " 'path': DATAPATH\n",
|
|
|
+ " }, \n",
|
|
|
+ " 'model': {\n",
|
|
|
+ " 'activation': 'reglu', \n",
|
|
|
+ " 'attention_dropout': 0.03815883962184247, \n",
|
|
|
+ " 'd_ffn_factor': 1.333333333333333, \n",
|
|
|
+ " 'd_token': 424, \n",
|
|
|
+ " 'ffn_dropout': 0.2515503440562596, \n",
|
|
|
+ " 'initialization': 'kaiming', \n",
|
|
|
+ " 'n_heads': 8, \n",
|
|
|
+ " 'n_layers': 2, \n",
|
|
|
+ " 'prenormalization': True, \n",
|
|
|
+ " 'residual_dropout': 0.0, \n",
|
|
|
+ " 'token_bias': True, \n",
|
|
|
+ " 'kv_compression': None, \n",
|
|
|
+ " 'kv_compression_sharing': None\n",
|
|
|
+ " }, \n",
|
|
|
+ " 'seed': 0, \n",
|
|
|
+ " 'training': {\n",
|
|
|
+ " 'batch_size': 1024, \n",
|
|
|
+ " 'eval_batch_size': 1024, \n",
|
|
|
+ " 'lr': 3.762989816330166e-05, \n",
|
|
|
+ " 'n_epochs': EPOCHS, \n",
|
|
|
+ " 'device': DEVICE, \n",
|
|
|
+ " 'optimizer': 'adamw', \n",
|
|
|
+ " 'patience': 16, \n",
|
|
|
+ " 'weight_decay': 0.0001239780004929955\n",
|
|
|
+ " }\n",
|
|
|
+ "}\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "zero.set_randomness(config['seed'])\n",
|
|
|
+ "dataset_dir = config['data']['path']\n",
|
|
|
+ "\n",
|
|
|
+ "D = lib.Dataset.from_dir(dataset_dir)\n",
|
|
|
+ "X = D.build_X(\n",
|
|
|
+ " normalization=config['data'].get('normalization'),\n",
|
|
|
+ " num_nan_policy='mean',\n",
|
|
|
+ " cat_nan_policy='new',\n",
|
|
|
+ " cat_policy=config['data'].get('cat_policy', 'indices'),\n",
|
|
|
+ " cat_min_frequency=config['data'].get('cat_min_frequency', 0.0),\n",
|
|
|
+ " seed=config['seed'],\n",
|
|
|
+ ")\n",
|
|
|
+ "if not isinstance(X, tuple):\n",
|
|
|
+ " X = (X, None)\n",
|
|
|
+ "\n",
|
|
|
+ "Y, y_info = D.build_y(config['data'].get('y_policy'))\n",
|
|
|
+ "\n",
|
|
|
+ "X = tuple(None if x is None else lib.to_tensors(x) for x in X)\n",
|
|
|
+ "Y = lib.to_tensors(Y)\n",
|
|
|
+ "device = torch.device(config['training']['device'])\n",
|
|
|
+ "print(\"Using device:\", config['training']['device'])\n",
|
|
|
+ "if device.type != 'cpu':\n",
|
|
|
+ " X = tuple(\n",
|
|
|
+ " None if x is None else {k: v.to(device) for k, v in x.items()} for x in X\n",
|
|
|
+ " )\n",
|
|
|
+ " Y_device = {k: v.to(device) for k, v in Y.items()}\n",
|
|
|
+ "else:\n",
|
|
|
+ " Y_device = Y\n",
|
|
|
+ "X_num, X_cat = X\n",
|
|
|
+ "del X\n",
|
|
|
+ "if not D.is_multiclass:\n",
|
|
|
+ " Y_device = {k: v.float() for k, v in Y_device.items()}\n",
|
|
|
+ "\n",
|
|
|
+ "train_size = D.size(lib.TRAIN)\n",
|
|
|
+ "batch_size = config['training']['batch_size']\n",
|
|
|
+ "epoch_size = math.ceil(train_size / batch_size)\n",
|
|
|
+ "eval_batch_size = config['training']['eval_batch_size']\n",
|
|
|
+ "chunk_size = None\n",
|
|
|
+ "\n",
|
|
|
+ "loss_fn = (\n",
|
|
|
+ " F.binary_cross_entropy_with_logits\n",
|
|
|
+ " if D.is_binclass\n",
|
|
|
+ " else F.cross_entropy\n",
|
|
|
+ " if D.is_multiclass\n",
|
|
|
+ " else F.mse_loss\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "# Do not define new model, instead use pruned model\n",
|
|
|
+ "#model = Transformer(\n",
|
|
|
+ "# d_numerical=0 if X_num is None else X_num['train'].shape[1],\n",
|
|
|
+ "# categories=lib.get_categories(X_cat),\n",
|
|
|
+ "# d_out=D.info['n_classes'] if D.is_multiclass else 1,\n",
|
|
|
+ "# **config['model'],\n",
|
|
|
+ "#).to(device)\n",
|
|
|
+ "\n",
|
|
|
+ "def needs_wd(name):\n",
|
|
|
+ " return all(x not in name for x in ['tokenizer', '.norm', '.bias'])\n",
|
|
|
+ "\n",
|
|
|
+ "for x in ['tokenizer', '.norm', '.bias']:\n",
|
|
|
+ " assert any(x in a for a in (b[0] for b in model.named_parameters()))\n",
|
|
|
+ "parameters_with_wd = [v for k, v in model.named_parameters() if needs_wd(k)]\n",
|
|
|
+ "parameters_without_wd = [v for k, v in model.named_parameters() if not needs_wd(k)]\n",
|
|
|
+ "optimizer = lib.make_optimizer(\n",
|
|
|
+ " config['training']['optimizer'],\n",
|
|
|
+ " (\n",
|
|
|
+ " [\n",
|
|
|
+ " {'params': parameters_with_wd},\n",
|
|
|
+ " {'params': parameters_without_wd, 'weight_decay': 0.0},\n",
|
|
|
+ " ]\n",
|
|
|
+ " ),\n",
|
|
|
+ " config['training']['lr'],\n",
|
|
|
+ " config['training']['weight_decay'],\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "stream = zero.Stream(lib.IndexLoader(train_size, batch_size, True, device))\n",
|
|
|
+ "progress = zero.ProgressTracker(config['training']['patience'])\n",
|
|
|
+ "training_log = {lib.TRAIN: [], lib.VAL: [], lib.TEST: []}\n",
|
|
|
+ "timer = zero.Timer()\n",
|
|
|
+ "output = \"Checkpoints\"\n",
|
|
|
+ "\n",
|
|
|
+ "def print_epoch_info():\n",
|
|
|
+ " print(f'\\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())} | {output}')\n",
|
|
|
+ " print(\n",
|
|
|
+ " ' | '.join(\n",
|
|
|
+ " f'{k} = {v}'\n",
|
|
|
+ " for k, v in {\n",
|
|
|
+ " 'lr': lib.get_lr(optimizer),\n",
|
|
|
+ " 'batch_size': batch_size,\n",
|
|
|
+ " 'chunk_size': chunk_size,\n",
|
|
|
+ " }.items()\n",
|
|
|
+ " )\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ "def apply_model(part, idx):\n",
|
|
|
+ " return model(\n",
|
|
|
+ " None if X_num is None else X_num[part][idx],\n",
|
|
|
+ " None if X_cat is None else X_cat[part][idx],\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ "@torch.no_grad()\n",
|
|
|
+ "def evaluate(parts):\n",
|
|
|
+ " eval_batch_size = config['training']['eval_batch_size']\n",
|
|
|
+ " model.eval()\n",
|
|
|
+ " metrics = {}\n",
|
|
|
+ " predictions = {}\n",
|
|
|
+ " for part in parts:\n",
|
|
|
+ " while eval_batch_size:\n",
|
|
|
+ " try:\n",
|
|
|
+ " predictions[part] = (\n",
|
|
|
+ " torch.cat(\n",
|
|
|
+ " [\n",
|
|
|
+ " apply_model(part, idx)\n",
|
|
|
+ " for idx in lib.IndexLoader(\n",
|
|
|
+ " D.size(part), eval_batch_size, False, device\n",
|
|
|
+ " )\n",
|
|
|
+ " ]\n",
|
|
|
+ " )\n",
|
|
|
+ " .cpu()\n",
|
|
|
+ " .numpy()\n",
|
|
|
+ " )\n",
|
|
|
+ " except RuntimeError as err:\n",
|
|
|
+ " if not lib.is_oom_exception(err):\n",
|
|
|
+ " raise\n",
|
|
|
+ " eval_batch_size //= 2\n",
|
|
|
+ " print('New eval batch size:', eval_batch_size)\n",
|
|
|
+ " else:\n",
|
|
|
+ " break\n",
|
|
|
+ " if not eval_batch_size:\n",
|
|
|
+ " RuntimeError('Not enough memory even for eval_batch_size=1')\n",
|
|
|
+ " metrics[part] = lib.calculate_metrics(\n",
|
|
|
+ " D.info['task_type'],\n",
|
|
|
+ " Y[part].numpy(), # type: ignore[code]\n",
|
|
|
+ " predictions[part], # type: ignore[code]\n",
|
|
|
+ " 'logits',\n",
|
|
|
+ " y_info,\n",
|
|
|
+ " )\n",
|
|
|
+ " for part, part_metrics in metrics.items():\n",
|
|
|
+ " print(f'[{part:<5}]', lib.make_summary(part_metrics))\n",
|
|
|
+ " return metrics, predictions\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 23,
|
|
|
+ "id": "a0d986d3",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": false
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 0%| | 0/285 [00:00<?, ?it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 1 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 7%|██▊ | 19/285 [00:03<00:57, 4.60it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.519\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 7%|██▉ | 20/285 [00:10<09:13, 2.09s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.877\n",
|
|
|
+ "[test ] Accuracy = 0.868\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 2 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 13%|█████▌ | 38/285 [00:14<00:54, 4.52it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.35\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 14%|█████▋ | 39/285 [00:20<08:33, 2.09s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.898\n",
|
|
|
+ "[test ] Accuracy = 0.892\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 3 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 20%|████████▍ | 57/285 [00:24<00:50, 4.54it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.31\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 20%|████████▌ | 58/285 [00:31<07:53, 2.09s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.910\n",
|
|
|
+ "[test ] Accuracy = 0.904\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 4 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 27%|███████████▏ | 76/285 [00:34<00:46, 4.54it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.283\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 27%|███████████▎ | 77/285 [00:41<07:22, 2.13s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.914\n",
|
|
|
+ "[test ] Accuracy = 0.908\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 5 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 33%|██████████████ | 95/285 [00:45<00:42, 4.48it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.276\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 34%|██████████████▏ | 96/285 [00:51<06:35, 2.09s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.919\n",
|
|
|
+ "[test ] Accuracy = 0.913\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 6 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 40%|████████████████▍ | 114/285 [00:55<00:39, 4.37it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.266\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 40%|████████████████▌ | 115/285 [01:02<06:07, 2.16s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.922\n",
|
|
|
+ "[test ] Accuracy = 0.916\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 7 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 47%|███████████████████▏ | 133/285 [01:06<00:33, 4.48it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.248\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 47%|███████████████████▎ | 134/285 [01:13<05:22, 2.14s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.924\n",
|
|
|
+ "[test ] Accuracy = 0.916\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 8 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 53%|█████████████████████▊ | 152/285 [01:17<00:29, 4.51it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.242\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 54%|██████████████████████ | 153/285 [01:23<04:38, 2.11s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.926\n",
|
|
|
+ "[test ] Accuracy = 0.918\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 9 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 60%|████████████████████████▌ | 171/285 [01:27<00:25, 4.46it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.233\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 60%|████████████████████████▋ | 172/285 [01:34<04:02, 2.15s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.924\n",
|
|
|
+ "[test ] Accuracy = 0.918\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 10 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 67%|███████████████████████████▎ | 190/285 [01:38<00:21, 4.50it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.23\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 67%|███████████████████████████▍ | 191/285 [01:44<03:22, 2.15s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.927\n",
|
|
|
+ "[test ] Accuracy = 0.920\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 11 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 73%|██████████████████████████████ | 209/285 [01:48<00:17, 4.45it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.221\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 74%|██████████████████████████████▏ | 210/285 [01:55<02:37, 2.11s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.928\n",
|
|
|
+ "[test ] Accuracy = 0.920\n",
|
|
|
+ "New best epoch!\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 12 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 80%|████████████████████████████████▊ | 228/285 [01:59<00:12, 4.42it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.217\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 80%|████████████████████████████████▉ | 229/285 [02:05<01:59, 2.13s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.927\n",
|
|
|
+ "[test ] Accuracy = 0.921\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 13 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 87%|███████████████████████████████████▌ | 247/285 [02:09<00:08, 4.51it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.216\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 87%|███████████████████████████████████▋ | 248/285 [02:16<01:18, 2.13s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.926\n",
|
|
|
+ "[test ] Accuracy = 0.920\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 14 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " 93%|██████████████████████████████████████▎ | 266/285 [02:20<00:04, 4.42it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.209\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\r",
|
|
|
+ " 94%|██████████████████████████████████████▍ | 267/285 [02:27<00:38, 2.13s/it]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[val ] Accuracy = 0.927\n",
|
|
|
+ "[test ] Accuracy = 0.921\n",
|
|
|
+ "\n",
|
|
|
+ ">>> Epoch 15 | 0:00:00\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "100%|█████████████████████████████████████████| 285/285 [02:31<00:00, 4.55it/s]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[train] loss = 0.206\n",
|
|
|
+ "[val ] Accuracy = 0.928\n",
|
|
|
+ "[test ] Accuracy = 0.921\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "finetuneEpochs = 15\n",
|
|
|
+ "for epoch in stream.epochs(finetuneEpochs):\n",
|
|
|
+ " print(f'\\n>>> Epoch {stream.epoch} | {lib.format_seconds(timer())}')\n",
|
|
|
+ " model.train()\n",
|
|
|
+ " epoch_losses = []\n",
|
|
|
+ " for batch_idx in epoch:\n",
|
|
|
+ " loss, new_chunk_size = lib.train_with_auto_virtual_batch(\n",
|
|
|
+ " optimizer,\n",
|
|
|
+ " loss_fn,\n",
|
|
|
+ " lambda x: (apply_model(lib.TRAIN, x), Y_device[lib.TRAIN][x]),\n",
|
|
|
+ " batch_idx,\n",
|
|
|
+ " chunk_size or batch_size,\n",
|
|
|
+ " )\n",
|
|
|
+ " epoch_losses.append(loss.detach())\n",
|
|
|
+ " if new_chunk_size and new_chunk_size < (chunk_size or batch_size):\n",
|
|
|
+ " print('New chunk size:', chunk_size)\n",
|
|
|
+ " epoch_losses = torch.stack(epoch_losses).tolist()\n",
|
|
|
+ " print(f'[{lib.TRAIN}] loss = {round(sum(epoch_losses) / len(epoch_losses), 3)}')\n",
|
|
|
+ "\n",
|
|
|
+ " metrics, predictions = evaluate([lib.VAL, lib.TEST])\n",
|
|
|
+ " for k, v in metrics.items():\n",
|
|
|
+ " training_log[k].append(v)\n",
|
|
|
+ " progress.update(metrics[lib.VAL]['score'])\n",
|
|
|
+ "\n",
|
|
|
+ " if progress.success:\n",
|
|
|
+ " print('New best epoch!')\n",
|
|
|
+ " #save_checkpoint(False)\n",
|
|
|
+ "\n",
|
|
|
+ " elif progress.fail:\n",
|
|
|
+ " break"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "d835a6b9",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Final result on finetuned model"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 24,
|
|
|
+ "id": "9419de73",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[test ] Accuracy = 0.921\n",
|
|
|
+ "[test_backdoor] Accuracy = 0.042\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "metrics = evaluate(['test', 'test_backdoor'])"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3 (ipykernel)",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.10.6"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|