API

 _modules/torch.utils.data.sampler


Source code for torch.utils.data.sampler

import torch
from torch._six import int_classes as _int_classes
from torch import Tensor

from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized

T_co = TypeVar('T_co', covariant=True)

class Sampler(Generic[T_co]):
    r"""Common base class of every sampler.

    A subclass is expected to implement :meth:`__iter__`, which yields the
    indices (or batches of indices) of dataset elements, and :meth:`__len__`,
    which reports how many items an iterator produced by :meth:`__iter__`
    will yield.

    .. note:: :class:`~torch.utils.data.DataLoader` does not strictly need
              :meth:`__len__`, but any computation involving the length of a
              :class:`~torch.utils.data.DataLoader` expects it to exist.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        # The base class keeps no state; ``data_source`` is accepted and ignored.
        pass

    def __iter__(self) -> Iterator[T_co]:
        # Subclasses must override this; the base class cannot produce indices.
        raise NotImplementedError
# NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
#
# Many times we have an abstract class representing a collection/iterable of
# data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
# implementing a `__len__` method. In such cases, we must make sure to not
# provide a default implementation, because both straightforward default
# implementations have their issues:
#
#   + `return NotImplemented`:
#     Calling `len(subclass_instance)` raises:
#       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
#
#   + `raise NotImplementedError()`:
#     This prevents triggering some fallback behavior. E.g., the built-in
#     `list(X)` tries to call `len(X)` first, and executes a different code
#     path if the method is not found or `NotImplemented` is returned, while
#     raising a `NotImplementedError` will propagate and make the call
#     fail where it could have used `__iter__` to complete the call.
#
# Thus, the only two sensible things to do are
#
#   + **not** provide a default `__len__`.
#
#   + raise a `TypeError` instead, which is what Python uses when users call
#     a method that is not defined on an object.
#     (@ssnl verifies that this works on at least Python 3.7.)
class SequentialSampler(Sampler[int]):
    r"""Yields the indices ``0 .. len(data_source) - 1`` in ascending order,
    identically on every pass.

    Arguments:
        data_source (Dataset): dataset whose indices are produced
    """
    data_source: Sized

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # The length is re-read on every call, so a dataset that changes size
        # between passes is reflected in the next iterator.
        size = len(self.data_source)
        return iter(range(size))

    def __len__(self) -> int:
        return len(self.data_source)
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a
    shuffled dataset. If with replacement, then user can specify
    :attr:`num_samples` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if
            ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
            This argument is supposed to be specified only when `replacement`
            is ``True``.
        generator (Generator): Generator used in sampling. When ``None``, every
            call to :meth:`__iter__` creates a fresh :class:`torch.Generator`
            seeded with a random 64-bit integer, and that generator drives
            both sampling modes.

    Raises:
        TypeError: if ``replacement`` is not a ``bool``.
        ValueError: if ``num_samples`` is given while ``replacement`` is
            ``False``, or if the effective number of samples is not a
            positive integer.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.generator is None:
            # No generator supplied: build a private one with a random seed.
            generator = torch.Generator()
            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
        else:
            generator = self.generator
        if self.replacement:
            # Indices are drawn in chunks of 32, followed by one final draw
            # for the remainder.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64,
                                         generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64,
                                     generator=generator).tolist()
        else:
            # Use the resolved local ``generator`` (not ``self.generator``) so
            # the generator prepared above is actually used for the
            # permutation, exactly as in the replacement branch.
            yield from torch.randperm(n, generator=generator).tolist()

    def __len__(self):
        return self.num_samples
class SubsetRandomSampler(Sampler[int]):
    r"""Yields the entries of a given index sequence in random order, each
    exactly once (i.e. without replacement).

    Arguments:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """
    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.generator = generator
        self.indices = indices

    def __iter__(self):
        # Shuffle positions into ``indices`` first, then map each position
        # back to the stored index lazily.
        order = torch.randperm(len(self.indices), generator=self.generator)
        return (self.indices[pos] for pos in order)

    def __len__(self):
        return len(self.indices)
class WeightedRandomSampler(Sampler[int]):
    r"""Draws indices from ``[0,..,len(weights)-1]`` according to the given
    probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """
    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        # ``bool`` is a subclass of ``int``, so it has to be ruled out explicitly.
        is_positive_int = (
            isinstance(num_samples, _int_classes)
            and not isinstance(num_samples, bool)
            and num_samples > 0
        )
        if not is_positive_int:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))

        # Weights are stored as a double-precision tensor.
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self):
        drawn = torch.multinomial(
            self.weights, self.num_samples, self.replacement, generator=self.generator
        )
        return iter(drawn.tolist())

    def __len__(self):
        return self.num_samples
class BatchSampler(Sampler[List[int]]):
    r"""Groups the indices produced by another sampler into mini-batches.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        # ``sampler`` is deliberately not type-checked: collections.abc.Iterable
        # ignores `__getitem__`-based iterables, so an `isinstance` test would
        # wrongly reject some valid inputs.
        size_ok = (
            isinstance(batch_size, _int_classes)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not size_ok:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) == self.batch_size:
                yield pending
                # Start a new list so the batch just yielded is never mutated.
                pending = []
        # Emit the trailing partial batch unless the caller asked to drop it.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        total = len(self.sampler)  # type: ignore
        if self.drop_last:
            # Only full batches count.
            return total // self.batch_size
        # Ceiling division: a trailing partial batch counts as one batch.
        return (total + self.batch_size - 1) // self.batch_size

此页内容是否对您有帮助