feat: add XOR
dataset
This commit is contained in:
parent
c1d7cfee8f
commit
bf09ff8f7f
@ -10,3 +10,4 @@ from .sklearn import (
|
||||
)
|
||||
from .spiral import Spiral
|
||||
from .tecator import Tecator
|
||||
from .xor import XOR
|
||||
|
18
prototorch/datasets/xor.py
Normal file
18
prototorch/datasets/xor.py
Normal file
@ -0,0 +1,18 @@
|
||||
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def make_xor(num_samples=500):
|
||||
x = torch.rand(num_samples, 2)
|
||||
y = torch.zeros(num_samples)
|
||||
y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1
|
||||
y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1
|
||||
return x, y
|
||||
|
||||
|
||||
class XOR(torch.utils.data.TensorDataset):
|
||||
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||
def __init__(self, num_samples: int = 500):
|
||||
x, y = make_xor(num_samples)
|
||||
super().__init__(x, y)
|
Loading…
Reference in New Issue
Block a user