use resnet50 to train a cifar10 classifier
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,2 +1,4 @@
|
||||
data
|
||||
*.pt
|
||||
.DS_Store
|
||||
__pycache__
|
@ -1,6 +1,6 @@
|
||||
# Abstract
|
||||
|
||||
Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdav🎁
|
||||
Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track image classification model training with PyTorch, including data, trained model file, and used parameters. The data will be recorded and pushed to my private DVC remote via webdav🎁
|
||||
|
||||
# Requirements
|
||||
|
||||
@ -11,6 +11,8 @@ Attempt to use [DVC](https://dvc.ai/), a data versioning tool, to track model tr
|
||||
* **env**
|
||||
* **pt.yaml**
|
||||
* conda env yaml to run this repo
|
||||
* **utils**
|
||||
* house pre-built functions
|
||||
|
||||
# Files
|
||||
|
||||
|
48
dvc.lock
Normal file
48
dvc.lock
Normal file
@ -0,0 +1,48 @@
|
||||
schema: '2.0'
|
||||
stages:
|
||||
prepare:
|
||||
cmd: python prepare.py
|
||||
deps:
|
||||
- path: prepare.py
|
||||
hash: md5
|
||||
md5: a1c07d1d5caf6e5288560a189415785c
|
||||
size: 2979
|
||||
params:
|
||||
params.yaml:
|
||||
prepare:
|
||||
data_dir: data/raw
|
||||
save_dir: data/processed
|
||||
train_valid_split:
|
||||
- 0.7
|
||||
- 0.3
|
||||
random_seed: 0
|
||||
outs:
|
||||
- path: data/processed
|
||||
hash: md5
|
||||
md5: f4bf62ffa725ca9144b7852a283dc1da.dir
|
||||
size: 295118798
|
||||
nfiles: 60000
|
||||
train:
|
||||
cmd: python train.py
|
||||
deps:
|
||||
- path: data/processed
|
||||
hash: md5
|
||||
md5: f4bf62ffa725ca9144b7852a283dc1da.dir
|
||||
size: 295118798
|
||||
nfiles: 60000
|
||||
- path: train.py
|
||||
hash: md5
|
||||
md5: b797ccf2fe61952bbf6d83fa51b0b11f
|
||||
size: 3407
|
||||
params:
|
||||
params.yaml:
|
||||
train:
|
||||
data_dir: data/processed
|
||||
epochs: 5
|
||||
batch_size: 256
|
||||
learning_rate: 5e-05
|
||||
outs:
|
||||
- path: model.pt
|
||||
hash: md5
|
||||
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
|
||||
size: 102592994
|
39
dvc.yaml
39
dvc.yaml
@ -2,27 +2,42 @@ stages:
|
||||
prepare:
|
||||
cmd: python prepare.py
|
||||
deps:
|
||||
- prepare.py
|
||||
- prepare.py
|
||||
params:
|
||||
- prepare
|
||||
- prepare
|
||||
outs:
|
||||
- data/processed
|
||||
- data/processed
|
||||
train:
|
||||
cmd: python train.py
|
||||
deps:
|
||||
- data/processed
|
||||
- train.py
|
||||
- data/processed
|
||||
- train.py
|
||||
params:
|
||||
- train
|
||||
- train
|
||||
outs:
|
||||
- model.pt
|
||||
- model.pt
|
||||
evaluate:
|
||||
cmd: python evaluate.py
|
||||
deps:
|
||||
- data/processed
|
||||
- evaluate.py
|
||||
- model.pt
|
||||
- data/processed
|
||||
- evaluate.py
|
||||
- model.pt
|
||||
params:
|
||||
- evaluate
|
||||
- evaluate
|
||||
outs:
|
||||
- eval
|
||||
- eval
|
||||
params:
|
||||
- dvclive/train/params.yaml
|
||||
- dvclive/eval/params.yaml
|
||||
artifacts:
|
||||
resnet50:
|
||||
path: model.pt
|
||||
type: model
|
||||
metrics:
|
||||
- dvclive/train/metrics.json
|
||||
- dvclive/eval/metrics.json
|
||||
plots:
|
||||
- dvclive/train/plots/metrics:
|
||||
x: step
|
||||
- dvclive/eval/plots/metrics:
|
||||
x: step
|
||||
|
3
dvclive/eval/metrics.json
Normal file
3
dvclive/eval/metrics.json
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"test_acc": 0.7336928844451904
|
||||
}
|
1
dvclive/eval/params.yaml
Normal file
1
dvclive/eval/params.yaml
Normal file
@ -0,0 +1 @@
|
||||
data_dir: data/processed
|
2
dvclive/eval/plots/metrics/test_acc.tsv
Normal file
2
dvclive/eval/plots/metrics/test_acc.tsv
Normal file
@ -0,0 +1,2 @@
|
||||
step test_acc
|
||||
0 0.7336928844451904
|
|
15
dvclive/eval/report.md
Normal file
15
dvclive/eval/report.md
Normal file
@ -0,0 +1,15 @@
|
||||
# DVC Report
|
||||
|
||||
params.yaml
|
||||
|
||||
| data_dir |
|
||||
|----------------|
|
||||
| data/processed |
|
||||
|
||||
metrics.json
|
||||
|
||||
| test_acc |
|
||||
|------------|
|
||||
| 0.733693 |
|
||||
|
||||

|
BIN
dvclive/eval/static/test_acc.png
Normal file
BIN
dvclive/eval/static/test_acc.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 14 KiB |
7
dvclive/train/metrics.json
Normal file
7
dvclive/train/metrics.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"train_loss": 2.2422571182250977,
|
||||
"train_acc": 0.7347080707550049,
|
||||
"valid_loss": 2.3184266090393066,
|
||||
"valid_acc": 0.7381500005722046,
|
||||
"step": 4
|
||||
}
|
5
dvclive/train/params.yaml
Normal file
5
dvclive/train/params.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
data_dir: data/processed
|
||||
epochs: 5
|
||||
batch_size: 256
|
||||
learning_rate: 5e-05
|
||||
epoch: 4
|
6
dvclive/train/plots/metrics/train_acc.tsv
Normal file
6
dvclive/train/plots/metrics/train_acc.tsv
Normal file
@ -0,0 +1,6 @@
|
||||
step train_acc
|
||||
0 0.6712241768836975
|
||||
1 0.6976224184036255
|
||||
2 0.7157850861549377
|
||||
3 0.7277812957763672
|
||||
4 0.7347080707550049
|
|
6
dvclive/train/plots/metrics/train_loss.tsv
Normal file
6
dvclive/train/plots/metrics/train_loss.tsv
Normal file
@ -0,0 +1,6 @@
|
||||
step train_loss
|
||||
0 3.0726168155670166
|
||||
1 2.7409346103668213
|
||||
2 2.5224294662475586
|
||||
3 2.364570140838623
|
||||
4 2.2422571182250977
|
|
6
dvclive/train/plots/metrics/valid_acc.tsv
Normal file
6
dvclive/train/plots/metrics/valid_acc.tsv
Normal file
@ -0,0 +1,6 @@
|
||||
step valid_acc
|
||||
0 0.6918894052505493
|
||||
1 0.7131190896034241
|
||||
2 0.7261338233947754
|
||||
3 0.7339118123054504
|
||||
4 0.7381500005722046
|
|
6
dvclive/train/plots/metrics/valid_loss.tsv
Normal file
6
dvclive/train/plots/metrics/valid_loss.tsv
Normal file
@ -0,0 +1,6 @@
|
||||
step valid_loss
|
||||
0 2.890321969985962
|
||||
1 2.669679880142212
|
||||
2 2.5183584690093994
|
||||
3 2.4061686992645264
|
||||
4 2.3184266090393066
|
|
21
dvclive/train/report.md
Normal file
21
dvclive/train/report.md
Normal file
@ -0,0 +1,21 @@
|
||||
# DVC Report
|
||||
|
||||
params.yaml
|
||||
|
||||
| data_dir | epochs | batch_size | learning_rate | epoch |
|
||||
|----------------|----------|--------------|-----------------|---------|
|
||||
| data/processed | 5 | 256 | 5e-05 | 4 |
|
||||
|
||||
metrics.json
|
||||
|
||||
| train_loss | train_acc | valid_loss | valid_acc | step |
|
||||
|--------------|-------------|--------------|-------------|--------|
|
||||
| 2.24226 | 0.734708 | 2.31843 | 0.73815 | 4 |
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
BIN
dvclive/train/static/train_acc.png
Normal file
BIN
dvclive/train/static/train_acc.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 21 KiB |
BIN
dvclive/train/static/train_loss.png
Normal file
BIN
dvclive/train/static/train_loss.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
BIN
dvclive/train/static/valid_acc.png
Normal file
BIN
dvclive/train/static/valid_acc.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 21 KiB |
BIN
dvclive/train/static/valid_loss.png
Normal file
BIN
dvclive/train/static/valid_loss.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
18
env/pt.yaml
vendored
18
env/pt.yaml
vendored
@ -21,9 +21,11 @@ dependencies:
|
||||
- billiard=4.1.0
|
||||
- boto3=1.34.9
|
||||
- botocore=1.34.9
|
||||
- brotli=1.1.0
|
||||
- brotli-bin=1.1.0
|
||||
- brotli-python=1.1.0
|
||||
- bzip2=1.0.8
|
||||
- ca-certificates=2023.08.22
|
||||
- ca-certificates=2023.11.17
|
||||
- cairo=1.18.0
|
||||
- celery=5.3.4
|
||||
- certifi=2023.11.17
|
||||
@ -35,7 +37,9 @@ dependencies:
|
||||
- click-repl=0.3.0
|
||||
- colorama=0.4.6
|
||||
- configobj=5.0.8
|
||||
- contourpy=1.2.0
|
||||
- cryptography=41.0.7
|
||||
- cycler=0.12.1
|
||||
- dav1d=1.2.1
|
||||
- decorator=5.1.1
|
||||
- dictdiffer=0.9.0
|
||||
@ -64,6 +68,7 @@ dependencies:
|
||||
- fontconfig=2.14.2
|
||||
- fonts-conda-ecosystem=1
|
||||
- fonts-conda-forge=1
|
||||
- fonttools=4.47.0
|
||||
- freetype=2.12.1
|
||||
- fribidi=1.0.10
|
||||
- frozenlist=1.4.1
|
||||
@ -94,6 +99,7 @@ dependencies:
|
||||
- jinja2=3.1.2
|
||||
- jmespath=1.0.1
|
||||
- joblib=1.2.0
|
||||
- kiwisolver=1.4.5
|
||||
- kombu=5.3.4
|
||||
- krb5=1.21.2
|
||||
- lame=3.100
|
||||
@ -101,6 +107,9 @@ dependencies:
|
||||
- lerc=4.0.0
|
||||
- libass=0.17.1
|
||||
- libblas=3.9.0
|
||||
- libbrotlicommon=1.1.0
|
||||
- libbrotlidec=1.1.0
|
||||
- libbrotlienc=1.1.0
|
||||
- libcblas=3.9.0
|
||||
- libcxx=16.0.6
|
||||
- libdeflate=1.19
|
||||
@ -131,14 +140,18 @@ dependencies:
|
||||
- libxcb=1.15
|
||||
- libxml2=2.12.3
|
||||
- libzlib=1.2.13
|
||||
- lightning-utilities=0.10.0
|
||||
- llvm-openmp=17.0.6
|
||||
- markdown-it-py=3.0.0
|
||||
- markupsafe=2.1.3
|
||||
- matplotlib=3.8.2
|
||||
- matplotlib-base=3.8.2
|
||||
- mdurl=0.1.0
|
||||
- mpc=1.3.1
|
||||
- mpfr=4.2.1
|
||||
- mpmath=1.3.0
|
||||
- multidict=6.0.4
|
||||
- munkres=1.1.4
|
||||
- nanotime=0.5.2
|
||||
- ncurses=6.4
|
||||
- nettle=3.9.1
|
||||
@ -206,12 +219,15 @@ dependencies:
|
||||
- tk=8.6.13
|
||||
- tomlkit=0.12.3
|
||||
- torchaudio=2.1.2
|
||||
- torchmetrics=1.2.1
|
||||
- torchvision=0.16.2
|
||||
- tornado=6.3.3
|
||||
- tqdm=4.66.1
|
||||
- typer=0.9.0
|
||||
- typing-extensions=4.9.0
|
||||
- typing_extensions=4.9.0
|
||||
- tzdata=2023d
|
||||
- unicodedata2=15.1.0
|
||||
- urllib3=1.26.18
|
||||
- vine=5.0.0
|
||||
- voluptuous=0.14.1
|
||||
|
30
evaluate.py
30
evaluate.py
@ -3,9 +3,15 @@
|
||||
# author: deng
|
||||
# date : 20231228
|
||||
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from dvclive import Live
|
||||
|
||||
from utils.dataset import ProcessedDataset
|
||||
|
||||
|
||||
def evaluate(params_path: str = 'params.yaml') -> None:
|
||||
@ -17,6 +23,30 @@ def evaluate(params_path: str = 'params.yaml') -> None:
|
||||
|
||||
with open(params_path, 'r') as f:
|
||||
params = yaml.safe_load(f)
|
||||
data_dir = Path(params['evaluate']['data_dir'])
|
||||
|
||||
test_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('test'))
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=64)
|
||||
|
||||
device = torch.device('mps')
|
||||
net = torch.load('model.pt')
|
||||
net.to(device)
|
||||
net.eval()
|
||||
|
||||
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
|
||||
metric.to(device)
|
||||
|
||||
with Live(dir='dvclive/eval', report='md') as live:
|
||||
live.log_params(params['evaluate'])
|
||||
|
||||
for data in test_dataloader:
|
||||
inputs, labels = data[0].to(device), data[1].to(device)
|
||||
outputs = net(inputs)
|
||||
_ = metric(outputs, labels)
|
||||
test_acc = metric.compute()
|
||||
|
||||
print(f'test_acc:{test_acc}')
|
||||
live.log_metric('test_acc', float(test_acc.cpu()))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -1,3 +1,12 @@
|
||||
prepare:
|
||||
data_dir: data/raw
|
||||
save_dir: data/processed
|
||||
train_valid_split: [0.7, 0.3]
|
||||
random_seed: 0
|
||||
train:
|
||||
data_dir: data/processed
|
||||
epochs: 5
|
||||
batch_size: 256
|
||||
learning_rate: 0.00005
|
||||
evaluate:
|
||||
data_dir: data/processed
|
76
prepare.py
76
prepare.py
@ -3,9 +3,13 @@
|
||||
# author: deng
|
||||
# date : 20231228
|
||||
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
import random
|
||||
import pickle
|
||||
import yaml
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def prepare(params_path: str = 'params.yaml') -> None:
|
||||
@ -17,7 +21,75 @@ def prepare(params_path: str = 'params.yaml') -> None:
|
||||
|
||||
with open(params_path, 'r') as f:
|
||||
params = yaml.safe_load(f)
|
||||
data_dir = Path(params['prepare']['data_dir'])
|
||||
save_dir = Path(params['prepare']['save_dir'])
|
||||
train_valid_split = params['prepare']['train_valid_split']
|
||||
random_seed = params['prepare']['random_seed']
|
||||
|
||||
train_dir = save_dir.joinpath('train')
|
||||
valid_dir = save_dir.joinpath('valid')
|
||||
test_dir = save_dir.joinpath('test')
|
||||
|
||||
if train_dir.is_dir():
|
||||
rmtree(train_dir)
|
||||
train_dir.mkdir()
|
||||
if valid_dir.is_dir():
|
||||
rmtree(valid_dir)
|
||||
valid_dir.mkdir()
|
||||
if test_dir.is_dir():
|
||||
rmtree(test_dir)
|
||||
test_dir.mkdir()
|
||||
|
||||
# Process training data
|
||||
ids = list(range(50000))
|
||||
random.Random(random_seed).shuffle(ids)
|
||||
train_ids = ids[:int(50000 * train_valid_split[0])]
|
||||
valid_ids = ids[int(50000 * train_valid_split[0]):]
|
||||
|
||||
current_id, train_count, valid_count = 0, 0, 0
|
||||
cifar_10_dir = data_dir.joinpath('cifar-10-batches-py')
|
||||
for data_path in cifar_10_dir.glob('data_batch_*'):
|
||||
with open(data_path, 'rb') as f:
|
||||
data = pickle.load(f, encoding='bytes')
|
||||
for i, label in enumerate(data[b'labels']):
|
||||
x = data[b'data'][i]
|
||||
x = x.reshape(3, 32, 32)
|
||||
x = x / 255
|
||||
x = x.astype(np.float32) # mps does not support float64
|
||||
y = np.zeros(10, dtype=np.float32)
|
||||
y[label] = 1.
|
||||
if current_id in train_ids:
|
||||
npz_path = train_dir.joinpath(f'{train_count}.npz')
|
||||
train_ids.remove(current_id)
|
||||
train_count += 1
|
||||
else:
|
||||
npz_path = valid_dir.joinpath(f'{valid_count}.npz')
|
||||
valid_ids.remove(current_id)
|
||||
valid_count += 1
|
||||
np.savez_compressed(
|
||||
npz_path,
|
||||
x=x,
|
||||
y=y
|
||||
)
|
||||
current_id += 1
|
||||
|
||||
# Process testing data
|
||||
data_path = cifar_10_dir.joinpath('test_batch')
|
||||
with open(data_path, 'rb') as f:
|
||||
data = pickle.load(f, encoding='bytes')
|
||||
for i, label in enumerate(data[b'labels']):
|
||||
x = data[b'data'][i]
|
||||
x = x.reshape(3, 32, 32)
|
||||
x = x / 255
|
||||
x = x.astype(np.float32)
|
||||
y = np.zeros(10, dtype=np.float32)
|
||||
y[label] = 1.
|
||||
np.savez_compressed(
|
||||
save_dir.joinpath('test', f'{i}.npz'),
|
||||
x=x,
|
||||
y=y
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prepare()
|
||||
prepare('params.yaml')
|
||||
|
78
train.py
78
train.py
@ -3,21 +3,97 @@
|
||||
# author: deng
|
||||
# date : 20231228
|
||||
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
import torch
|
||||
from rich.progress import track
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.models import resnet50
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from dvclive import Live
|
||||
|
||||
from utils.dataset import ProcessedDataset
|
||||
|
||||
|
||||
def train(params_path: str = 'params.yaml') -> None:
|
||||
"""Train a simple model using Pytorch
|
||||
|
||||
Args:
|
||||
params_path (str, optional): path of config yaml. Defaults to 'params.yaml'.
|
||||
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
|
||||
"""
|
||||
|
||||
with open(params_path, 'r') as f:
|
||||
params = yaml.safe_load(f)
|
||||
data_dir = Path(params['train']['data_dir'])
|
||||
epochs = params['train']['epochs']
|
||||
batch_size = params['train']['batch_size']
|
||||
learning_rate = params['train']['learning_rate']
|
||||
|
||||
train_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('train'))
|
||||
valid_dataset = ProcessedDataset(dataset_dir=data_dir.joinpath('valid'))
|
||||
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2)
|
||||
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2)
|
||||
|
||||
device = torch.device('mps')
|
||||
net = torch.nn.Sequential(
|
||||
resnet50(weights='IMAGENET1K_V1'),
|
||||
torch.nn.Linear(1000, 10)
|
||||
)
|
||||
net.to(device)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
|
||||
metric = MulticlassAccuracy(num_classes=10, top_k=1, average='weighted')
|
||||
metric.to(device)
|
||||
|
||||
with Live(dir='dvclive/train', report='md') as live:
|
||||
live.log_params(params['train'])
|
||||
|
||||
for epoch in track(range(epochs), 'Epochs'):
|
||||
|
||||
net.train()
|
||||
train_loss = 0.
|
||||
for data in train_dataloader:
|
||||
inputs, labels = data[0].to(device), data[1].to(device)
|
||||
optimizer.zero_grad()
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
train_loss += loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
_ = metric(outputs, labels)
|
||||
train_loss /= len(train_dataloader)
|
||||
train_acc = metric.compute()
|
||||
metric.reset()
|
||||
|
||||
net.eval()
|
||||
with torch.no_grad():
|
||||
valid_loss = 0.
|
||||
for data in valid_dataloader:
|
||||
inputs, labels = data[0].to(device), data[1].to(device)
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
valid_loss += loss
|
||||
_ = metric(outputs, labels)
|
||||
valid_loss /= len(valid_dataloader)
|
||||
valid_acc = metric.compute()
|
||||
metric.reset()
|
||||
|
||||
print(
|
||||
f'Epoch {epoch} - train_loss:{train_loss} train_acc:{train_acc} '
|
||||
f'valid_loss:{valid_loss} valid_acc:{valid_acc}'
|
||||
)
|
||||
live.log_param('epoch', epoch)
|
||||
live.log_metric('train_loss', float(train_loss.cpu()))
|
||||
live.log_metric('train_acc', float(train_acc.cpu()))
|
||||
live.log_metric('valid_loss', float(valid_loss.cpu()))
|
||||
live.log_metric('valid_acc', float(valid_acc.cpu()))
|
||||
live.next_step()
|
||||
|
||||
torch.save(net, 'model.pt')
|
||||
live.log_artifact('model.pt', type='model', name='resnet50')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
26
utils/dataset.py
Normal file
26
utils/dataset.py
Normal file
@ -0,0 +1,26 @@
|
||||
# prepare.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20231229
|
||||
|
||||
from pathlib import PosixPath
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class ProcessedDataset(Dataset):
|
||||
""""Load processed data"""
|
||||
def __init__(self, dataset_dir: PosixPath):
|
||||
self.dataset_dir = dataset_dir
|
||||
self.file_paths = list(self.dataset_dir.glob('*.npz'))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
npz = np.load(self.file_paths[idx])
|
||||
x = torch.from_numpy(npz['x'])
|
||||
y = torch.from_numpy(npz['y'])
|
||||
return x, y
|
Reference in New Issue
Block a user