Source code for deepxde.data.sampler

import numpy as np


class BatchSampler:
    """Samples a mini-batch of indices.

    The indices are repeated indefinitely. Has the same effect as:

    .. code-block:: python

        indices = tf.data.Dataset.range(num_samples)
        indices = indices.repeat().shuffle(num_samples).batch(batch_size)
        iterator = iter(indices)
        batch_indices = iterator.get_next()

    However, ``tf.data.Dataset.__iter__()`` is only supported inside of
    ``tf.function`` or when eager execution is enabled.
    ``tf.data.Dataset.make_one_shot_iterator()`` supports graph mode, but is too
    slow.

    This class is not implemented as a Python Iterator, so that it can support
    dynamic batch size.

    Args:
        num_samples (int): The number of samples.
        shuffle (bool): Set to ``True`` to have the indices reshuffled at every
            epoch.
    """

    def __init__(self, num_samples, shuffle=True):
        self.num_samples = num_samples
        self.shuffle = shuffle

        # Current permutation of the sample indices; reshuffled in place at
        # every epoch boundary when ``shuffle`` is enabled.
        self._indices = np.arange(self.num_samples)
        # Number of full passes over the indices finished so far.
        self._epochs_completed = 0
        # Position in ``self._indices`` where the next batch starts.
        self._index_in_epoch = 0

        # Shuffle for the first epoch
        if shuffle:
            np.random.shuffle(self._indices)

    @property
    def epochs_completed(self):
        """int: Number of epochs finished so far.

        The counter is advanced by ``get_next`` when a requested batch does not
        fit in the remainder of the current epoch.
        """
        return self._epochs_completed

    def get_next(self, batch_size):
        """Returns the indices of the next batch.

        The returned array is always a fresh copy: it never aliases the
        sampler's internal index buffer, so it stays valid after later calls
        reshuffle the indices, and the caller may modify it freely.

        Args:
            batch_size (int): The number of elements to combine in a single
                batch.

        Returns:
            numpy.ndarray: 1-D array of ``batch_size`` sample indices.

        Raises:
            ValueError: If ``batch_size`` is larger than ``num_samples`` or is
                negative.
        """
        if batch_size > self.num_samples:
            raise ValueError(
                "batch_size={} is larger than num_samples={}.".format(
                    batch_size, self.num_samples
                )
            )
        if batch_size < 0:
            # A negative size would move the epoch cursor backwards and
            # corrupt the sampler state.
            raise ValueError(
                "batch_size={} must be non-negative.".format(batch_size)
            )

        start = self._index_in_epoch
        if start + batch_size <= self.num_samples:
            self._index_in_epoch += batch_size
            end = self._index_in_epoch
            # Copy instead of returning a view: ``self._indices`` is shuffled
            # in place at the next epoch boundary, which would otherwise
            # silently rewrite batches already handed to the caller.
            return np.copy(self._indices[start:end])

        # Finished epoch
        self._epochs_completed += 1
        # Get the rest examples in this epoch. The copy is required because
        # self._indices may be shuffled below.
        rest_num_samples = self.num_samples - start
        indices_rest_part = np.copy(self._indices[start : self.num_samples])
        # Shuffle the indices
        if self.shuffle:
            np.random.shuffle(self._indices)
        # Start next epoch: top the batch up from the head of the new epoch.
        start = 0
        self._index_in_epoch = batch_size - rest_num_samples
        end = self._index_in_epoch
        indices_new_part = self._indices[start:end]
        # np.hstack allocates a new array, so the result does not alias
        # self._indices either.
        return np.hstack((indices_rest_part, indices_new_part))