96 lines
2.9 KiB
Python
96 lines
2.9 KiB
Python
# prepare.py
|
|
#
|
|
# author: deng
|
|
# date : 20231228
|
|
|
|
from pathlib import Path
|
|
from shutil import rmtree
|
|
import random
|
|
import pickle
|
|
import yaml
|
|
|
|
import numpy as np
|
|
|
|
|
|
def _reset_dir(path: Path) -> None:
    """Remove ``path`` if it is an existing directory, then recreate it empty.

    Missing parent directories are created as well, so a ``save_dir`` that
    does not exist yet no longer makes ``mkdir`` fail.
    """
    if path.is_dir():
        rmtree(path)
    path.mkdir(parents=True)


def _iter_samples(batch_path: Path):
    """Yield model-ready ``(x, y)`` pairs from one pickled CIFAR-10 batch file.

    ``x`` is a float32 array of shape (3, 32, 32) holding the pixel values
    divided by 255; ``y`` is a float32 one-hot vector of length 10. Samples
    are yielded in the order the batch stores them.

    Args:
        batch_path (Path): path of a pickled batch dict with ``b'data'`` and
            ``b'labels'`` entries.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only point
    # data_dir at a trusted copy of the dataset.
    with open(batch_path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    images = batch[b'data']
    for i, label in enumerate(batch[b'labels']):
        x = images[i].reshape(3, 32, 32) / 255
        x = x.astype(np.float32)  # mps does not support float64
        y = np.zeros(10, dtype=np.float32)
        y[label] = 1.
        yield x, y


def prepare(params_path: str = 'params.yaml') -> None:
    """Preprocess CIFAR-10 pickled batches and save every image as an npz file.

    Reads the ``prepare`` section of the YAML file at ``params_path``:

    * ``data_dir``: directory containing ``cifar-10-batches-py``.
    * ``save_dir``: output root; its ``train``, ``valid`` and ``test``
      subdirectories are deleted and recreated on every run.
    * ``train_valid_split``: sequence whose first element is the fraction of
      the training images written to ``train`` (the rest go to ``valid``).
    * ``random_seed``: seed for the train/valid shuffle.

    Each output file ``<n>.npz`` stores ``x`` (float32, shape (3, 32, 32),
    pixel values divided by 255) and ``y`` (float32 one-hot, length 10).

    Args:
        params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.

    Raises:
        ValueError: if the ``data_batch_*`` files hold more than 50000 images.
    """
    num_train_images = 50000  # expected size of the CIFAR-10 training set

    with open(params_path, encoding='utf-8') as f:
        params = yaml.safe_load(f)['prepare']
    data_dir = Path(params['data_dir'])
    save_dir = Path(params['save_dir'])
    train_valid_split = params['train_valid_split']
    random_seed = params['random_seed']

    train_dir = save_dir / 'train'
    valid_dir = save_dir / 'valid'
    test_dir = save_dir / 'test'
    for out_dir in (train_dir, valid_dir, test_dir):
        _reset_dir(out_dir)

    # Process training data: shuffle ids 0..N-1 with a fixed seed and send the
    # first int(N * split) of them to train; every other id goes to valid.
    ids = list(range(num_train_images))
    random.Random(random_seed).shuffle(ids)
    # A set gives O(1) membership tests instead of scanning a 40k-item list.
    train_ids = set(ids[:int(num_train_images * train_valid_split[0])])

    cifar_10_dir = data_dir / 'cifar-10-batches-py'
    image_id = train_count = valid_count = 0
    # Sorted because glob order is filesystem-dependent; without it the same
    # seed could assign different images to train/valid on different machines.
    for batch_path in sorted(cifar_10_dir.glob('data_batch_*')):
        for x, y in _iter_samples(batch_path):
            if image_id >= num_train_images:
                raise ValueError(
                    f'found more than {num_train_images} training images '
                    f'in {cifar_10_dir}'
                )
            if image_id in train_ids:
                npz_path = train_dir / f'{train_count}.npz'
                train_count += 1
            else:
                npz_path = valid_dir / f'{valid_count}.npz'
                valid_count += 1
            np.savez_compressed(npz_path, x=x, y=y)
            image_id += 1

    # Process testing data, numbered in the order stored in the test batch.
    for i, (x, y) in enumerate(_iter_samples(cifar_10_dir / 'test_batch')):
        np.savez_compressed(test_dir / f'{i}.npz', x=x, y=y)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run preprocessing with params.yaml resolved relative
    # to the current working directory.
    prepare(params_path='params.yaml')
|