# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor


def categorical(probs: Tensor) -> Tensor:
    r"""Categorical sampler according to weights in the last dimension of
    ``probs`` using :func:`torch.multinomial`.

    All leading dimensions are treated as independent batch dimensions: one
    category index is drawn for every distribution ``probs[..., :]``. A 1-D
    ``probs`` is treated as a single distribution and yields a 0-d tensor.

    Args:
        probs (Tensor): probabilities of shape ``(*batch, num_categories)``.
            They need not be normalized (they are used as weights). As
            required by :func:`torch.multinomial`, entries must be
            non-negative and finite, and every distribution must have a
            non-zero sum.

    Returns:
        Tensor: Samples of shape ``(*batch,)`` holding category indices in
        ``[0, num_categories)`` (dtype ``torch.long``).
    """
    batch_shape = probs.shape[:-1]
    # Collapse all batch dims into one so multinomial gets a 2-D input;
    # unlike ``flatten(0, -2)`` this also works when ``probs`` is 1-D.
    flat_probs = probs.reshape(-1, probs.shape[-1])
    samples = torch.multinomial(flat_probs, 1, replacement=True)
    # Pass the torch.Size itself (not unpacked) so an empty batch shape
    # produces a 0-d tensor instead of an invalid argument-less view().
    return samples.view(batch_shape)