test_dvc/prepare.py

96 lines
2.9 KiB
Python

# prepare.py
#
# author: deng
# date : 20231228
from pathlib import Path
from shutil import rmtree
import random
import pickle
import yaml
import numpy as np
def _reset_dir(path: Path) -> None:
    """Delete ``path`` if it is an existing directory, then recreate it empty.

    ``parents=True`` lets the stage run even when the parent ``save_dir``
    does not exist yet (e.g. on a fresh checkout).
    """
    if path.is_dir():
        rmtree(path)
    path.mkdir(parents=True)


def _load_batch(path: Path) -> dict:
    """Unpickle one CIFAR-10 python batch file with ``encoding='bytes'``.

    The returned dict is keyed by ``bytes`` (e.g. ``b'data'``, ``b'labels'``).

    NOTE(security): ``pickle.load`` can execute arbitrary code; only point
    ``data_dir`` at archives obtained from a trusted source.
    """
    with open(path, 'rb') as f:
        return pickle.load(f, encoding='bytes')


def _to_sample(row: np.ndarray, label: int):
    """Convert one raw image row and its label into a training sample.

    Args:
        row (np.ndarray): flat pixel row of one image; reshaped here to (3, 32, 32).
        label (int): class index; sets the hot position of a length-10 vector.

    Returns:
        tuple: ``(x, y)`` where ``x`` is the (3, 32, 32) image divided by 255
        (i.e. scaled to [0, 1], assuming uint8 pixels — TODO confirm) as
        float32, and ``y`` is a float32 one-hot vector of length 10.
    """
    x = row.reshape(3, 32, 32) / 255
    x = x.astype(np.float32)  # mps does not support float64
    y = np.zeros(10, dtype=np.float32)
    y[label] = 1.
    return x, y


def prepare(params_path: str = 'params.yaml') -> None:
    """Preprocess CIFAR-10 and save every sample as an ``.npz`` file for model training.

    Reads the ``prepare`` section of the parameter YAML, splits the training
    images into train/valid subsets with a seeded shuffle, and writes each
    sample (``x``: float32 array of shape (3, 32, 32), ``y``: float32 one-hot
    vector of length 10) to ``<save_dir>/{train,valid,test}/<n>.npz``.
    Existing ``train``/``valid``/``test`` directories are deleted first.

    Args:
        params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.

    Raises:
        ValueError: if the training batches contain more samples than the
            50000 ids the split was computed for.
    """
    with open(params_path, encoding='utf-8') as f:
        params = yaml.safe_load(f)
    prepare_params = params['prepare']
    data_dir = Path(prepare_params['data_dir'])
    save_dir = Path(prepare_params['save_dir'])
    train_valid_split = prepare_params['train_valid_split']
    random_seed = prepare_params['random_seed']

    train_dir = save_dir.joinpath('train')
    valid_dir = save_dir.joinpath('valid')
    test_dir = save_dir.joinpath('test')
    for split_dir in (train_dir, valid_dir, test_dir):
        _reset_dir(split_dir)

    # Process training data
    num_train_samples = 50000  # size of the training set the split is computed for
    ids = list(range(num_train_samples))
    random.Random(random_seed).shuffle(ids)
    # A set gives O(1) membership; a list lookup + remove per sample was O(n^2).
    train_ids = set(ids[:int(num_train_samples * train_valid_split[0])])

    current_id, train_count, valid_count = 0, 0, 0
    cifar_10_dir = data_dir.joinpath('cifar-10-batches-py')
    # glob order is filesystem-dependent; sorting keeps the id -> image mapping,
    # and therefore the seeded split, identical on every machine.
    for data_path in sorted(cifar_10_dir.glob('data_batch_*')):
        data = _load_batch(data_path)
        images = data[b'data']
        for i, label in enumerate(data[b'labels']):
            if current_id >= num_train_samples:
                raise ValueError(
                    f'found more than {num_train_samples} training samples '
                    f'in {cifar_10_dir}'
                )
            x, y = _to_sample(images[i], label)
            if current_id in train_ids:
                npz_path = train_dir.joinpath(f'{train_count}.npz')
                train_count += 1
            else:
                npz_path = valid_dir.joinpath(f'{valid_count}.npz')
                valid_count += 1
            np.savez_compressed(npz_path, x=x, y=y)
            current_id += 1

    # Process testing data
    data = _load_batch(cifar_10_dir.joinpath('test_batch'))
    images = data[b'data']
    for i, label in enumerate(data[b'labels']):
        x, y = _to_sample(images[i], label)
        np.savez_compressed(test_dir.joinpath(f'{i}.npz'), x=x, y=y)
if __name__ == '__main__':
    # Script entry point: run the preparation stage with the default parameter file.
    prepare(params_path='params.yaml')