diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 1fd485f..a79c52d 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -10,3 +10,4 @@ from .sklearn import ( ) from .spiral import Spiral from .tecator import Tecator +from .xor import XOR diff --git a/prototorch/datasets/xor.py b/prototorch/datasets/xor.py new file mode 100644 index 0000000..57925aa --- /dev/null +++ b/prototorch/datasets/xor.py @@ -0,0 +1,18 @@ +"""Exclusive-or (XOR) dataset for binary classification.""" + +import torch + + +def make_xor(num_samples=500): + x = torch.rand(num_samples, 2) + y = torch.zeros(num_samples) + y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1 + y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1 + return x, y + + +class XOR(torch.utils.data.TensorDataset): + """Exclusive-or (XOR) dataset for binary classification.""" + def __init__(self, num_samples: int = 500): + x, y = make_xor(num_samples) + super().__init__(x, y)