feat: add XOR
dataset
This commit is contained in:
parent
c1d7cfee8f
commit
bf09ff8f7f
@ -10,3 +10,4 @@ from .sklearn import (
|
|||||||
)
|
)
|
||||||
from .spiral import Spiral
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
from .xor import XOR
|
||||||
|
18
prototorch/datasets/xor.py
Normal file
18
prototorch/datasets/xor.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_xor(num_samples=500):
|
||||||
|
x = torch.rand(num_samples, 2)
|
||||||
|
y = torch.zeros(num_samples)
|
||||||
|
y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1
|
||||||
|
y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class XOR(torch.utils.data.TensorDataset):
|
||||||
|
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||||
|
def __init__(self, num_samples: int = 500):
|
||||||
|
x, y = make_xor(num_samples)
|
||||||
|
super().__init__(x, y)
|
Loading…
Reference in New Issue
Block a user