@@ -0,0 +1,838 @@
+from __future__ import absolute_import, division, print_function
+
+import math
+import os
+import typing as ty
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import zero
+from torch import Tensor
+
+
+class IndexLoader:
+    def __init__(
+        self, train_size: int, batch_size: int, shuffle: bool, device: torch.device
+    ) -> None:
+        self._train_size = train_size
+        self._batch_size = batch_size
+        self._shuffle = shuffle
+        self._device = device
+
+    def __len__(self) -> int:
+        return math.ceil(self._train_size / self._batch_size)
+
+    def __iter__(self):
+        indices = list(
+            zero.iloader(self._train_size, self._batch_size, shuffle=self._shuffle)
+        )
+        return iter(torch.cat(indices).to(self._device).split(self._batch_size))
+
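+# Minimal usage sketch for IndexLoader (illustrative; `X_train` is a
+# hypothetical feature tensor, not defined in this module):
+#
+#     loader = IndexLoader(len(X_train), 64, shuffle=True,
+#                          device=torch.device('cpu'))
+#     for idx in loader:          # idx: 1D tensor of row indices
+#         batch = X_train[idx]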
+
+class Lambda(nn.Module):
+    def __init__(self, f: ty.Callable) -> None:
+        super().__init__()
+        self.f = f
+
+    def forward(self, x):
+        return self.f(x)
+
+
+# Source: https://github.com/bzhangGo/rmsnorm
+# NOTE: eps is changed to 1e-5
+class RMSNorm(nn.Module):
+    def __init__(self, d, p=-1.0, eps=1e-5, bias=False):
+        """Root Mean Square Layer Normalization
+
+        :param d: model size
+        :param p: partial RMSNorm; valid values are in [0, 1], default -1.0 (disabled)
+        :param eps: epsilon value, default 1e-5
+        :param bias: whether to use a bias term for RMSNorm, disabled by
+            default because RMSNorm doesn't enforce re-centering invariance.
+        """
+        super(RMSNorm, self).__init__()
+
+        self.eps = eps
+        self.d = d
+        self.p = p
+        self.bias = bias
+
+        self.scale = nn.Parameter(torch.ones(d))
+        self.register_parameter("scale", self.scale)
+
+        if self.bias:
+            self.offset = nn.Parameter(torch.zeros(d))
+            self.register_parameter("offset", self.offset)
+
+    def forward(self, x):
+        if self.p < 0.0 or self.p > 1.0:
+            norm_x = x.norm(2, dim=-1, keepdim=True)
+            d_x = self.d
+        else:
+            partial_size = int(self.d * self.p)
+            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
+
+            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
+            d_x = partial_size
+
+        rms_x = norm_x * d_x ** (-1.0 / 2)
+        x_normed = x / (rms_x + self.eps)
+
+        if self.bias:
+            return self.scale * x_normed + self.offset
+
+        return self.scale * x_normed
+
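+# Shape/behavior sketch (illustrative): RMSNorm(8)(torch.randn(2, 8)) returns
+# a (2, 8) tensor; each row is divided by its root-mean-square and rescaled,
+# with no mean subtraction (unlike LayerNorm).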
+
+class ScaleNorm(nn.Module):
+    """
+    Sources:
+    - https://github.com/tnq177/transformers_without_tears/blob/25026061979916afb193274438f7097945acf9bc/layers.py#L132
+    - https://github.com/tnq177/transformers_without_tears/blob/6b2726cd9e6e642d976ae73b9f696d9d7ff4b395/layers.py#L157
+    """
+
+    def __init__(self, d: int, eps: float = 1e-5, clamp: bool = False) -> None:
+        super(ScaleNorm, self).__init__()
+        self.scale = nn.Parameter(torch.tensor(d ** 0.5))
+        self.eps = eps
+        self.clamp = clamp
+
+    def forward(self, x):
+        norms = torch.norm(x, dim=-1, keepdim=True)
+        norms = norms.clamp(min=self.eps) if self.clamp else norms + self.eps
+        return self.scale * x / norms
+
+
+def reglu(x: Tensor) -> Tensor:
+    a, b = x.chunk(2, dim=-1)
+    return a * F.relu(b)
+
+
+def geglu(x: Tensor) -> Tensor:
+    a, b = x.chunk(2, dim=-1)
+    return a * F.gelu(b)
+
+
+class ReGLU(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return reglu(x)
+
+
+class GEGLU(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return geglu(x)
+
+
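+# Note: the GLU-style activations above halve the last dimension, so the layer
+# feeding them must output twice the desired width. Quick check (illustrative):
+# reglu(torch.randn(3, 8)).shape == (3, 4).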
+def make_optimizer(
+    optimizer: str,
+    parameter_groups,
+    lr: float,
+    weight_decay: float,
+) -> optim.Optimizer:
+    Optimizer = {
+        'adabelief': AdaBelief,
+        'adam': optim.Adam,
+        'adamw': optim.AdamW,
+        'radam': RAdam,
+        'sgd': optim.SGD,
+    }[optimizer]
+    momentum = (0.9,) if Optimizer is optim.SGD else ()
+    return Optimizer(parameter_groups, lr, *momentum, weight_decay=weight_decay)
+
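+# Usage sketch (illustrative): make_optimizer('adamw', model.parameters(),
+# lr=1e-3, weight_decay=1e-5) builds torch.optim.AdamW; for 'sgd' a momentum
+# of 0.9 is inserted positionally. Referencing AdaBelief and RAdam here is
+# safe even though they are defined later in this module, because the dict
+# lookup happens at call time.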
+
+def make_lr_schedule(
+    optimizer: optim.Optimizer,
+    lr: float,
+    epoch_size: int,
+    lr_schedule: ty.Optional[ty.Dict[str, ty.Any]],
+) -> ty.Tuple[
+    ty.Optional[optim.lr_scheduler._LRScheduler],
+    ty.Dict[str, ty.Any],
+    ty.Optional[int],
+]:
+    if lr_schedule is None:
+        lr_schedule = {'type': 'constant'}
+    lr_scheduler = None
+    n_warmup_steps = None
+    if lr_schedule['type'] in ['transformer', 'linear_warmup']:
+        n_warmup_steps = (
+            lr_schedule['n_warmup_steps']
+            if 'n_warmup_steps' in lr_schedule
+            else lr_schedule['n_warmup_epochs'] * epoch_size
+        )
+    elif lr_schedule['type'] == 'cyclic':
+        lr_scheduler = optim.lr_scheduler.CyclicLR(
+            optimizer,
+            base_lr=lr,
+            max_lr=lr_schedule['max_lr'],
+            step_size_up=lr_schedule['n_epochs_up'] * epoch_size,
+            step_size_down=lr_schedule['n_epochs_down'] * epoch_size,
+            mode=lr_schedule['mode'],
+            gamma=lr_schedule.get('gamma', 1.0),
+            cycle_momentum=False,
+        )
+    return lr_scheduler, lr_schedule, n_warmup_steps
+
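+# lr_schedule dicts are config-driven; example shapes (illustrative values):
+#     {'type': 'constant'}
+#     {'type': 'linear_warmup', 'n_warmup_steps': 1000}
+#     {'type': 'cyclic', 'max_lr': 1e-2, 'n_epochs_up': 10,
+#      'n_epochs_down': 10, 'mode': 'triangular'}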
+
+def get_activation_fn(name: str) -> ty.Callable[[Tensor], Tensor]:
+    return (
+        reglu
+        if name == 'reglu'
+        else geglu
+        if name == 'geglu'
+        else torch.sigmoid
+        if name == 'sigmoid'
+        else getattr(F, name)
+    )
+
+
+def get_nonglu_activation_fn(name: str) -> ty.Callable[[Tensor], Tensor]:
+    return (
+        F.relu
+        if name == 'reglu'
+        else F.gelu
+        if name == 'geglu'
+        else get_activation_fn(name)
+    )
+
+
+def load_swa_state_dict(model: nn.Module, swa_model: optim.swa_utils.AveragedModel):
+    # AveragedModel stores the wrapped model's weights under a 'module.'
+    # prefix and adds an 'n_averaged' counter; strip both before loading.
+    state_dict = deepcopy(swa_model.state_dict())
+    del state_dict['n_averaged']
+    model.load_state_dict({k[len('module.') :]: v for k, v in state_dict.items()})
+
+
+def get_epoch_parameters(
+    train_size: int, batch_size: ty.Union[int, str]
+) -> ty.Tuple[int, int]:
+    if isinstance(batch_size, str):
+        if batch_size == 'v3':
+            batch_size = (
+                256 if train_size < 50000 else 512 if train_size < 100000 else 1024
+            )
+        elif batch_size == 'v1':
+            batch_size = (
+                16
+                if train_size < 1000
+                else 32
+                if train_size < 10000
+                else 64
+                if train_size < 50000
+                else 128
+                if train_size < 100000
+                else 256
+                if train_size < 200000
+                else 512
+                if train_size < 500000
+                else 1024
+            )
+        elif batch_size == 'v2':
+            batch_size = (
+                512 if train_size < 100000 else 1024 if train_size < 500000 else 2048
+            )
+    return batch_size, math.ceil(train_size / batch_size)  # type: ignore[code]
+
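+# Example (illustrative): get_epoch_parameters(30000, 'v3') -> (256, 118),
+# i.e. batch size 256 and ceil(30000 / 256) = 118 steps per epoch.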
+
+def get_linear_warmup_lr(lr: float, n_warmup_steps: int, step: int) -> float:
+    assert step > 0, "1-based enumeration of steps is expected"
+    return min(lr, step / (n_warmup_steps + 1) * lr)
+
+
+def get_manual_lr(schedule: ty.List[float], epoch: int) -> float:
+    assert epoch > 0, "1-based enumeration of epochs is expected"
+    return schedule[min(epoch, len(schedule)) - 1]
+
+
+def get_transformer_lr(scale: float, d: int, n_warmup_steps: int, step: int) -> float:
+    return scale * d ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)
+
+
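+# get_transformer_lr follows the warmup schedule from 'Attention Is All You
+# Need': lr = scale * d**-0.5 * min(step**-0.5, step * n_warmup_steps**-1.5),
+# i.e. linear warmup followed by inverse-square-root decay.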
+def learn(model, optimizer, loss_fn, step, batch, star) -> ty.Tuple[Tensor, ty.Any]:
+    model.train()
+    optimizer.zero_grad()
+    out = step(batch)
+    # `star` controls whether the output of `step` is unpacked into `loss_fn`.
+    loss = loss_fn(*out) if star else loss_fn(out)
+    loss.backward()
+    optimizer.step()
+    return loss, out
+
+
+def _learn_with_virtual_batch(
+    model, optimizer, loss_fn, step, batch, chunk_size
+) -> Tensor:
+    batch_size = len(batch)
+    if chunk_size >= batch_size:
+        return learn(model, optimizer, loss_fn, step, batch, True)[0]
+    model.train()
+    optimizer.zero_grad()
+    total_loss = None
+    for chunk in zero.iter_batches(batch, chunk_size):
+        loss = loss_fn(*step(chunk))
+        loss = loss * len(chunk)
+        loss.backward()
+        if total_loss is None:
+            total_loss = loss.detach()
+        else:
+            total_loss += loss.detach()
+    for x in model.parameters():
+        if x.grad is not None:
+            x.grad /= batch_size
+    optimizer.step()
+    return total_loss / batch_size
+
+
+def learn_with_auto_virtual_batch(
+    model,
+    optimizer,
+    loss_fn,
+    step,
+    batch,
+    batch_size_hint: int,
+    chunk_size: ty.Optional[int],
+) -> ty.Tuple[Tensor, ty.Optional[int]]:
+    """This is just an overcomplicated version of `train_with_auto_virtual_batch`."""
+    random_state = zero.get_random_state()
+    while chunk_size != 0:
+        try:
+            zero.set_random_state(random_state)
+            return (
+                _learn_with_virtual_batch(
+                    model,
+                    optimizer,
+                    loss_fn,
+                    step,
+                    batch,
+                    chunk_size or batch_size_hint,
+                ),
+                chunk_size,
+            )
+        except RuntimeError as err:
+            if not is_oom_exception(err):
+                raise
+            if chunk_size is None:
+                chunk_size = batch_size_hint
+            chunk_size //= 2
+    raise RuntimeError('Not enough memory even for batch_size=1')
+
+
+def train_with_auto_virtual_batch(
+    optimizer,
+    loss_fn,
+    step,
+    batch,
+    chunk_size: int,
+) -> ty.Tuple[Tensor, int]:
+    batch_size = len(batch)
+    random_state = zero.get_random_state()
+    while chunk_size != 0:
+        try:
+            zero.set_random_state(random_state)
+            optimizer.zero_grad()
+            if batch_size <= chunk_size:
+                loss = loss_fn(*step(batch))
+                loss.backward()
+            else:
+                loss = None
+                for chunk in zero.iter_batches(batch, chunk_size):
+                    chunk_loss = loss_fn(*step(chunk))
+                    chunk_loss = chunk_loss * (len(chunk) / batch_size)
+                    chunk_loss.backward()
+                    if loss is None:
+                        loss = chunk_loss.detach()
+                    else:
+                        loss += chunk_loss.detach()
+        except RuntimeError as err:
+            if not is_oom_exception(err):
+                raise
+            chunk_size //= 2
+        else:
+            break
+    if not chunk_size:
+        raise RuntimeError('Not enough memory even for batch_size=1')
+    optimizer.step()
+    return loss, chunk_size  # type: ignore[code]
+
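+# Both *_virtual_batch helpers implement gradient accumulation with automatic
+# OOM backoff: on a CUDA out-of-memory error the chunk size is halved and the
+# step is retried after restoring the saved RNG state, so the retry sees the
+# same randomness (e.g. dropout masks) as the failed attempt.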
+
+def tensor(x) -> torch.Tensor:
+    assert isinstance(x, torch.Tensor)
+    return ty.cast(torch.Tensor, x)
+
+
+def get_n_parameters(m: nn.Module):
+    return sum(x.numel() for x in m.parameters() if x.requires_grad)
+
+
+def get_mlp_n_parameters(units: ty.List[int]):
+    x = 0
+    for a, b in zip(units, units[1:]):
+        # weight matrix (a * b) plus bias vector (b) for each linear layer
+        x += a * b + b
+    return x
+
+
+def get_lr(optimizer: optim.Optimizer) -> float:
+    return next(iter(optimizer.param_groups))['lr']
+
+
+def set_lr(optimizer: optim.Optimizer, lr: float) -> None:
+    for x in optimizer.param_groups:
+        x['lr'] = lr
+
+
+def get_device() -> torch.device:
+    return torch.device('cuda:0' if os.environ.get('CUDA_VISIBLE_DEVICES') else 'cpu')
+
+
+@torch.no_grad()
+def get_gradient_norm_ratios(m: nn.Module):
+    return {
+        k: v.grad.norm() / v.norm()
+        for k, v in m.named_parameters()
+        if v.grad is not None
+    }
+
+
+def is_oom_exception(err: RuntimeError) -> bool:
+    return any(
+        x in str(err)
+        for x in [
+            'CUDA out of memory',
+            'CUBLAS_STATUS_ALLOC_FAILED',
+            'CUDA error: out of memory',
+        ]
+    )
+
+
+# Source: https://github.com/LiyuanLucasLiu/RAdam
+class RAdam(optim.Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        degenerated_to_sgd=True,
+    ):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        self.degenerated_to_sgd = degenerated_to_sgd
+        if (
+            isinstance(params, (list, tuple))
+            and len(params) > 0
+            and isinstance(params[0], dict)
+        ):
+            for param in params:
+                if 'betas' in param and (
+                    param['betas'][0] != betas[0] or param['betas'][1] != betas[1]
+                ):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            buffer=[[None, None, None] for _ in range(10)],
+        )
+        super(RAdam, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(RAdam, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('RAdam does not support sparse gradients')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                # NOTE: the deprecated positional forms add_(scalar, tensor) and
+                # addcmul_(scalar, tensor, tensor) are replaced with the keyword
+                # forms expected by modern PyTorch.
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+                state['step'] += 1
+                buffered = group['buffer'][int(state['step'] % 10)]
+                if state['step'] == buffered[0]:
+                    N_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    N_sma_max = 2 / (1 - beta2) - 1
+                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    buffered[1] = N_sma
+
+                    # more conservative since it's an approximated value
+                    if N_sma >= 5:
+                        step_size = math.sqrt(
+                            (1 - beta2_t)
+                            * (N_sma - 4)
+                            / (N_sma_max - 4)
+                            * (N_sma - 2)
+                            / N_sma
+                            * N_sma_max
+                            / (N_sma_max - 2)
+                        ) / (1 - beta1 ** state['step'])
+                    elif self.degenerated_to_sgd:
+                        step_size = 1.0 / (1 - beta1 ** state['step'])
+                    else:
+                        step_size = -1
+                    buffered[2] = step_size
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(
+                            p_data_fp32, alpha=-group['weight_decay'] * group['lr']
+                        )
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                    p.data.copy_(p_data_fp32)
+                elif step_size > 0:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(
+                            p_data_fp32, alpha=-group['weight_decay'] * group['lr']
+                        )
+                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+                    p.data.copy_(p_data_fp32)
+
+        return loss
+
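+# RAdam rectifies Adam's variance estimate early in training: while the
+# approximated SMA length N_sma is below 5 the adaptive step is considered
+# unreliable, so the update degenerates to a bias-corrected SGD step when
+# degenerated_to_sgd=True, and is skipped otherwise.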
+
+# NOTE: comparing version strings lexicographically misorders versions such
+# as "1.10.0", so parse (major, minor) before comparing.
+version_higher = tuple(
+    int(x) for x in torch.__version__.split('+')[0].split('.')[:2]
+) >= (1, 5)
+
+
+# Source: https://github.com/juntang-zhuang/Adabelief-Optimizer
+class AdaBelief(optim.Optimizer):
+    r"""Implements the AdaBelief algorithm. Modified from Adam in PyTorch.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-16)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+        weight_decouple (boolean, optional): (default: True) If set as True, the
+            optimizer uses decoupled weight decay as in AdamW
+        fixed_decay (boolean, optional): (default: False) This is used when
+            weight_decouple is set as True.
+            When fixed_decay == True, the weight decay is performed as
+            $W_{new} = W_{old} - W_{old} \times decay$.
+            When fixed_decay == False, the weight decay is performed as
+            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in
+            this case, the weight decay ratio decreases with the learning rate (lr).
+        rectify (boolean, optional): (default: True) If set as True, perform the
+            rectified update similar to RAdam
+        degenerated_to_sgd (boolean, optional): (default: True) If set as True,
+            perform the SGD update when the variance of the gradient is high
+        print_change_log (boolean, optional): (default: True) If set as True,
+            print the modifications to the default hyper-parameters
+    reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed
+        gradients, NeurIPS 2020
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-16,
+        weight_decay=0,
+        amsgrad=False,
+        weight_decouple=True,
+        fixed_decay=False,
+        rectify=True,
+        degenerated_to_sgd=True,
+        print_change_log=True,
+    ):
+        # ------------------------------------------------------------------------------
+        # Print modifications to default arguments
+        if print_change_log:
+            print(
+                'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.'
+            )
+            print('Modifications to default arguments:')
+            default_table = [
+                ['version', 'eps', 'weight_decouple', 'rectify'],
+                ['adabelief-pytorch=0.0.5', '1e-8', 'False', 'False'],
+                ['>=0.1.0 (Current 0.2.0)', '1e-16', 'True', 'True'],
+            ]
+            print(default_table)
+
+            recommend_table = [
+                [
+                    'SGD better than Adam (e.g. CNN for Image Classification)',
+                    'Adam better than SGD (e.g. Transformer, GAN)',
+                ],
+                ['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
+            ]
+            print(recommend_table)
+
+            print('For a complete table of recommended hyperparameters, see')
+            print('https://github.com/juntang-zhuang/Adabelief-Optimizer')
+
+            print(
+                'You can disable the log message by setting "print_change_log = False", though it is recommended to keep it as a reminder.'
+            )
+        # ------------------------------------------------------------------------------
+
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        self.degenerated_to_sgd = degenerated_to_sgd
+        if (
+            isinstance(params, (list, tuple))
+            and len(params) > 0
+            and isinstance(params[0], dict)
+        ):
+            for param in params:
+                if 'betas' in param and (
+                    param['betas'][0] != betas[0] or param['betas'][1] != betas[1]
+                ):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+            buffer=[[None, None, None] for _ in range(10)],
+        )
+        super(AdaBelief, self).__init__(params, defaults)
+
+        self.weight_decouple = weight_decouple
+        self.rectify = rectify
+        self.fixed_decay = fixed_decay
+        if self.weight_decouple:
+            print('Weight decoupling enabled in AdaBelief')
+            if self.fixed_decay:
+                print('Weight decay fixed')
+        if self.rectify:
+            print('Rectification enabled in AdaBelief')
+        if amsgrad:
+            print('AMSGrad enabled in AdaBelief')
+
+    def __setstate__(self, state):
+        super(AdaBelief, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                amsgrad = group['amsgrad']
+
+                # State initialization
+                state['step'] = 0
+                # Exponential moving average of gradient values
+                state['exp_avg'] = (
+                    torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    if version_higher
+                    else torch.zeros_like(p.data)
+                )
+
+                # Exponential moving average of squared gradient values
+                state['exp_avg_var'] = (
+                    torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    if version_higher
+                    else torch.zeros_like(p.data)
+                )
+
+                if amsgrad:
+                    # Maintains max of all exp. moving avg. of sq. grad. values
+                    state['max_exp_avg_var'] = (
+                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                        if version_higher
+                        else torch.zeros_like(p.data)
+                    )
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                # cast data type
+                half_precision = False
+                if p.data.dtype == torch.float16:
+                    half_precision = True
+                    p.data = p.data.float()
+                    p.grad = p.grad.float()
+
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        'AdaBelief does not support sparse gradients, please consider SparseAdam instead'
+                    )
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]
+
+                beta1, beta2 = group['betas']
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = (
+                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                        if version_higher
+                        else torch.zeros_like(p.data)
+                    )
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_var'] = (
+                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                        if version_higher
+                        else torch.zeros_like(p.data)
+                    )
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_var'] = (
+                            torch.zeros_like(
+                                p.data, memory_format=torch.preserve_format
+                            )
+                            if version_higher
+                            else torch.zeros_like(p.data)
+                        )
+
+                # perform weight decay, check if decoupled weight decay
+                if self.weight_decouple:
+                    if not self.fixed_decay:
+                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                    else:
+                        p.data.mul_(1.0 - group['weight_decay'])
+                else:
+                    if group['weight_decay'] != 0:
+                        grad.add_(p.data, alpha=group['weight_decay'])
+
+                # get current state variable
+                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
+
+                state['step'] += 1
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+
+                # Update first and second moment running average.
+                # The key AdaBelief difference from Adam: the second moment
+                # tracks (grad - exp_avg)^2, the deviation of the gradient
+                # from its EMA prediction, rather than grad^2.
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                grad_residual = grad - exp_avg
+                exp_avg_var.mul_(beta2).addcmul_(
+                    grad_residual, grad_residual, value=1 - beta2
+                )
+
+                if amsgrad:
+                    max_exp_avg_var = state['max_exp_avg_var']
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(
+                        max_exp_avg_var,
+                        exp_avg_var.add_(group['eps']),
+                        out=max_exp_avg_var,
+                    )
+
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(
+                        group['eps']
+                    )
+                else:
+                    denom = (
+                        exp_avg_var.add_(group['eps']).sqrt()
+                        / math.sqrt(bias_correction2)
+                    ).add_(group['eps'])
+
+                # update
+                if not self.rectify:
+                    # Default update
+                    step_size = group['lr'] / bias_correction1
+                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+                else:  # Rectified update, forked from RAdam
+                    buffered = group['buffer'][int(state['step'] % 10)]
+                    if state['step'] == buffered[0]:
+                        N_sma, step_size = buffered[1], buffered[2]
+                    else:
+                        buffered[0] = state['step']
+                        beta2_t = beta2 ** state['step']
+                        N_sma_max = 2 / (1 - beta2) - 1
+                        N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                        buffered[1] = N_sma
+
+                        # more conservative since it's an approximated value
+                        if N_sma >= 5:
+                            step_size = math.sqrt(
+                                (1 - beta2_t)
+                                * (N_sma - 4)
+                                / (N_sma_max - 4)
+                                * (N_sma - 2)
+                                / N_sma
+                                * N_sma_max
+                                / (N_sma_max - 2)
+                            ) / (1 - beta1 ** state['step'])
+                        elif self.degenerated_to_sgd:
+                            step_size = 1.0 / (1 - beta1 ** state['step'])
+                        else:
+                            step_size = -1
+                        buffered[2] = step_size
+
+                    if N_sma >= 5:
+                        denom = exp_avg_var.sqrt().add_(group['eps'])
+                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                    elif step_size > 0:
+                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])
+
+                if half_precision:
+                    p.data = p.data.half()
+                    p.grad = p.grad.half()
+
+        return loss
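+
+
+# Minimal AdaBelief usage sketch (illustrative):
+#     opt = AdaBelief(model.parameters(), lr=1e-3, print_change_log=False)
+#     loss_fn(model(x), y).backward()
+#     opt.step()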