from .data import Data
from .. import config
from ..backend import tf
class Constraint(Data):
    """General constraints.

    Wraps a user-supplied constraint function whose value (the residual) is
    driven towards zero by the loss. Every call to :meth:`train_next_batch`
    returns the same full training set, and :meth:`test` returns the full
    test set. No target values are provided; targets are always ``None``.

    Args:
        constraint: Callable invoked as ``constraint(inputs, outputs, X)``.
            ``X`` is ``train_x`` when the model is training and ``test_x``
            otherwise. Its result is the residual penalized by ``loss_fn``.
        train_x: Training inputs. This is presumably an array of input
            points; confirm against callers.
        test_x: Test inputs, of the same kind as ``train_x``.
    """

    def __init__(self, constraint, train_x, test_x):
        self.constraint = constraint
        self.train_x = train_x
        self.test_x = test_x

    def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
        """Return the loss of the constraint residual measured against zero.

        Args:
            targets: Unused here. This dataset has no targets.
            outputs: Network outputs, forwarded to ``self.constraint``.
            loss_fn: Callable invoked as ``loss_fn(zeros, residual)``. The
                first argument is an all-zero tensor shaped like the residual.
            inputs: Network inputs, forwarded to ``self.constraint``.
            model: Model whose ``net.training`` flag selects whether the
                constraint is evaluated with ``train_x`` or ``test_x``.
            aux: Unused here. It is presumably part of the ``Data``
                interface; verify against the base class.

        Returns:
            The value returned by ``loss_fn`` when comparing zeros to the
            constraint residual.
        """
        # The branch is chosen with tf.cond rather than a Python ``if`` so
        # that it is decided from ``model.net.training`` at run time. That
        # flag is presumably a boolean placeholder/tensor under the backend's
        # graph mode; confirm against the backend.
        f = tf.cond(
            model.net.training,
            lambda: self.constraint(inputs, outputs, self.train_x),
            lambda: self.constraint(inputs, outputs, self.test_x),
        )
        # The constraint holds when its residual is zero. Compare the
        # residual against a zero tensor of the same shape. The zeros use
        # the configured floating-point dtype, not the residual's own dtype.
        return loss_fn(tf.zeros(tf.shape(f), dtype=config.real(tf)), f)

    def train_next_batch(self, batch_size=None):
        """Return the full training set. ``batch_size`` is ignored.

        Returns:
            Tuple ``(train_x, None)``. There are no training targets.
        """
        return self.train_x, None

    def test(self):
        """Return the full test set.

        Returns:
            Tuple ``(test_x, None)``. There are no test targets.
        """
        return self.test_x, None