From bf09ff8f7ff73c5e7ce1134afd6d89c5ae230db1 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 15 Jul 2021 18:14:38 +0200 Subject: [PATCH] feat: add `XOR` dataset --- prototorch/datasets/__init__.py | 1 + prototorch/datasets/xor.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 prototorch/datasets/xor.py diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 1fd485f..a79c52d 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -10,3 +10,4 @@ from .sklearn import ( ) from .spiral import Spiral from .tecator import Tecator +from .xor import XOR diff --git a/prototorch/datasets/xor.py b/prototorch/datasets/xor.py new file mode 100644 index 0000000..57925aa --- /dev/null +++ b/prototorch/datasets/xor.py @@ -0,0 +1,18 @@ +"""Exclusive-or (XOR) dataset for binary classification.""" + +import torch + + +def make_xor(num_samples=500): + x = torch.rand(num_samples, 2) + y = torch.zeros(num_samples) + y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1 + y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1 + return x, y + + +class XOR(torch.utils.data.TensorDataset): + """Exclusive-or (XOR) dataset for binary classification.""" + def __init__(self, num_samples: int = 500): + x, y = make_xor(num_samples) + super().__init__(x, y)