Add spiral dataset
This commit is contained in:
parent
b7d53aa5f1
commit
e686adbea1
33
prototorch/datasets/spiral.py
Normal file
33
prototorch/datasets/spiral.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
"""Spiral dataset for binary classification."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_spiral(n_samples=500, noise=0.3):
|
||||||
|
def get_samples(n, delta_t):
|
||||||
|
points = []
|
||||||
|
for i in range(n):
|
||||||
|
r = i / n_samples * 5
|
||||||
|
t = 1.75 * i / n * 2 * np.pi + delta_t
|
||||||
|
x = r * np.sin(t) + np.random.rand(1) * noise
|
||||||
|
y = r * np.cos(t) + np.random.rand(1) * noise
|
||||||
|
points.append([x, y])
|
||||||
|
return points
|
||||||
|
|
||||||
|
n = n_samples // 2
|
||||||
|
positive = get_samples(n=n, delta_t=0)
|
||||||
|
negative = get_samples(n=n, delta_t=np.pi)
|
||||||
|
x = np.concatenate(
|
||||||
|
[np.array(positive).reshape(n, -1),
|
||||||
|
np.array(negative).reshape(n, -1)],
|
||||||
|
axis=0)
|
||||||
|
y = np.concatenate([np.zeros(n), np.ones(n)])
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class Spiral(torch.utils.data.TensorDataset):
|
||||||
|
"""Spiral dataset for binary classification."""
|
||||||
|
def __init__(self, n_samples=500, noise=0.3):
|
||||||
|
x, y = make_spiral(n_samples, noise)
|
||||||
|
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
Loading…
Reference in New Issue
Block a user