        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
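To see how the `None` buffers behave, here is a small sketch of calling the functional API directly; the input shape and hyperparameters are illustrative, not taken from the code above. When no running statistics are supplied, batch statistics are always used, which is exactly the `track_running_stats=False` branch.

import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 32, 32)
weight = torch.ones(3)
bias = torch.zeros(3)

# Passing None for running_mean/running_var forces the use of batch
# statistics, mirroring the track_running_stats=False case above.
out = F.batch_norm(x, None, None, weight, bias,
                   training=True, momentum=0.1, eps=1e-5)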
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)
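If that error message fires, the fix it asks for can be sketched as follows; the checkpoint path and module here are hypothetical and only illustrate dropping the stale buffers before loading.

import torch
import torch.nn as nn

state_dict = torch.load('old_checkpoint.pth')  # hypothetical pre-0.4.0 checkpoint

# Remove the running-stats buffers that the module no longer expects.
for key in list(state_dict.keys()):
    if key.endswith('running_mean') or key.endswith('running_var'):
        state_dict.pop(key)

m = nn.InstanceNorm2d(64, track_running_stats=False)
m.load_state_dict(state_dict)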
  auto input_reshaped = input.contiguous().view(shape);
  auto out = at::batch_norm(input_reshaped, weight_, bias_, running_mean_, running_var_,
                            use_input_stats, momentum, eps, cudnn_enabled);
  // we alias running_mean and running_var because they are const but we want to modify their data
  if (running_mean.defined()) {
    at::alias(running_mean).copy_(running_mean_.view({ b, c }).mean(0, false));
  }
  if (running_var.defined()) {
    at::alias(running_var).copy_(running_var_.view({ b, c }).mean(0, false));
  }
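The same reshaping trick can be sketched in Python to make the two steps above concrete. Shapes are illustrative; this is a reference sketch, not the library code.

import torch
import torch.nn.functional as F

b, c, h, w = 4, 3, 8, 8
x = torch.randn(b, c, h, w)

# Fold the batch dimension into the channel dimension so batch norm
# computes one mean/var per (sample, channel) pair, i.e. instance norm.
x_reshaped = x.contiguous().view(1, b * c, h, w)
running_mean = torch.zeros(b * c)
running_var = torch.ones(b * c)

out = F.batch_norm(x_reshaped, running_mean, running_var,
                   training=True, momentum=0.1, eps=1e-5).view(b, c, h, w)

# As in the C++ above, per-instance statistics are averaged over the batch
# to produce the per-channel running buffers.
channel_running_mean = running_mean.view(b, c).mean(0)
channel_running_var = running_var.view(b, c).mean(0)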
    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim
    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)
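For orientation, `_weight_norm(v, g, dim)` computes g * v / ||v||, where the norm is taken over every dimension except `dim`. A plain-PyTorch reference sketch of that computation (not the fused primitive itself):

import torch

def weight_norm_reference(v: torch.Tensor, g: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Norm of v over every dimension except `dim`, kept as singleton
    # dimensions so it broadcasts against v (like norm_except_dim).
    reduce_dims = [d for d in range(v.dim()) if d != dim]
    norm = v.pow(2).sum(dim=reduce_dims, keepdim=True).sqrt()
    return g * v / norm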
    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))
        if dim is None:
            dim = -1
        fn = WeightNorm(name, dim)
        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization')
        # remove w from parameter list
        del module._parameters[name]
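The uninitialized-parameter check above matters for lazy modules. A small sketch, assuming nn.LazyLinear as the example of a module whose weight starts out uninitialized:

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

m = nn.LazyLinear(40)        # weight is an UninitializedParameter here
m(torch.randn(1, 20))        # dummy forward materializes weight as (40, 20)
m = weight_norm(m)           # now safe to apply weight normalization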
        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))
        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn
def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the
    magnitude (e.g. ``'weight_g'``) and one specifying the direction
    (e.g. ``'weight_v'``). Weight normalization is implemented via a hook that
    recomputes the weight tensor from the magnitude and direction before every
    :meth:`~Module.forward` call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module
def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module
    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
  TORCH_CHECK(
      v_in.device() == g_in.device(),
      "weight_norm: expected v_in and g_in to be on the same device, but v_in is "
      "on ", v_in.device(), " and g_in is on ", g_in.device());
  auto v = v_in.contiguous();
  auto g = g_in.contiguous();
  if (can_use_fused) {
    // weight_norm does not have a derivative defined for it, so this will route back through
    // VariableType.cpp, and construct a WeightNormFusedBackward object in the autograd graph.
    return std::get<0>(at::_weight_norm_cuda_interface(v, g, dim));
  } else {
    // Double-differentiable primitive ops
    // at::native::norm_except_dim would probably be fine as well.
    return v * (g / at::norm_except_dim(v, 2, dim));
  }
}
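The fallback expression is the same reparameterization written with double-differentiable primitives. A quick Python sketch of the equivalence, assuming the private bindings torch._weight_norm and torch.norm_except_dim (the same ones weight_norm.py imports) are available:

import torch
from torch import _weight_norm, norm_except_dim

v = torch.randn(40, 20)
g = norm_except_dim(v, 2, 0)   # start with g = ||v|| so the check is easy to read

fused = _weight_norm(v, g, 0)                     # fused/primitive path
decomposed = v * (g / norm_except_dim(v, 2, 0))   # fallback expression above
print(torch.allclose(fused, decomposed))          # True (up to float rounding)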
// Differentiable backward path, an alternative to weight_norm_cuda_backward, to be used
// when backward is itself creating a graph.
// The GradMode::is_enabled() check must be performed within Functions.cpp; that's why we
// define a separate function here, instead of inlining it in weight_norm_cuda_backward.
std::tuple<Tensor, Tensor> _weight_norm_differentiable_backward
  (const Tensor & grad_w,
   const Tensor & saved_v,
   const Tensor & saved_g,
   const Tensor & saved_norms,
   int64_t dim)
{
  // In Functions.cpp, the HardshrinkBackward object supplies "grad.contiguous()"
  // as the first argument, so grad_w should be contiguous here.
  // All these checks should succeed:
  TORCH_CHECK(grad_w.is_contiguous(), "grad_w must be contiguous");
  TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
  TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
  TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");
  // Like weight_norm_fused_backward, weight_norm_differentiable_backward should only ever be called
  // through a WeightNormFusedBackward object, so we expect that dim == 0 || dim == saved_v.dim() - 1
  TORCH_CHECK(dim == 0 || dim == last_dim, "Expected dim to be the first or last dimension");
// saved_g and saved_norms are already shaped to broadcast over the correct dimensions
  // ...but saved_norms might be Float when saved_g and saved_v are half.
  // To consider: saved_norms.to(..., True /*non_blocking*/);
  auto norms = saved_norms.to(saved_g.scalar_type());
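For orientation, the gradients this differentiable path must produce follow from w = g * v / ||v||. A plain-PyTorch reference sketch of those formulas (an illustration under the per-`dim` norm convention used above, not the library kernel):

import torch

def weight_norm_backward_reference(grad_w, v, g, norms, dim=0):
    # Per-slice dot products <grad_w, v>, shaped to broadcast like `norms`.
    reduce_dims = [d for d in range(v.dim()) if d != dim]
    per_dim_sums = (grad_w * v).sum(dim=reduce_dims, keepdim=True)
    # d/dv of g * v / ||v||: scale grad_w and subtract the component along v.
    grad_v = (g / norms) * (grad_w - v * per_dim_sums / (norms * norms))
    # d/dg of g * v / ||v||: projection of grad_w onto the direction v / ||v||.
    grad_g = per_dim_sums / norms
    return grad_v, grad_g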