# prepare.py
#
# author: deng
# date  : 20231228

from pathlib import Path
from shutil import rmtree
import random
import pickle

import yaml
import numpy as np


# Number of images across the CIFAR-10 `data_batch_*` files; the train/valid
# split shuffles the index range [0, NUM_TRAIN_IMAGES).
NUM_TRAIN_IMAGES = 50000


def _reset_dir(path: Path) -> None:
    """Delete ``path`` if it is a directory, then recreate it empty (with parents)."""
    if path.is_dir():
        rmtree(path)
    # parents=True: the original crashed when save_dir itself did not exist yet.
    path.mkdir(parents=True)


def _load_batch(path: Path) -> dict:
    """Unpickle one CIFAR-10 batch file; keys are bytes (e.g. b'data', b'labels')."""
    # NOTE: pickle.load can execute arbitrary code -- only point this at
    # trusted dataset files.
    with open(path, 'rb') as f:
        return pickle.load(f, encoding='bytes')


def _to_sample(raw: np.ndarray, label: int):
    """Convert one raw image row and its label into an (x, y) pair.

    Args:
        raw: flat image row; reshaped to (3, 32, 32) and divided by 255.
        label: class index used to set the one-hot target.

    Returns:
        x: float32 array of shape (3, 32, 32) (mps does not support float64).
        y: float32 one-hot vector of length 10.
    """
    x = (raw.reshape(3, 32, 32) / 255).astype(np.float32)
    y = np.zeros(10, dtype=np.float32)
    y[label] = 1.
    return x, y


def prepare(params_path: str = 'params.yaml') -> None:
    """Preprocess CIFAR-10 batches and save each sample as an npz for model training.

    Reads the ``prepare`` section of the YAML file at ``params_path``:
    ``data_dir`` (directory containing ``cifar-10-batches-py``), ``save_dir``
    (output root), ``train_valid_split`` (only element [0] is used, as the
    training fraction) and ``random_seed`` (seed for the train/valid shuffle).

    ``save_dir/train``, ``save_dir/valid`` and ``save_dir/test`` are wiped and
    recreated, then filled with ``<n>.npz`` files holding arrays ``x`` and ``y``.

    Args:
        params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.

    Raises:
        FileNotFoundError: if the params file or a dataset file is missing.
        KeyError: if a required key is absent from the params file.
    """
    with open(params_path, encoding='utf-8') as f:
        params = yaml.safe_load(f)
    prep_params = params['prepare']
    data_dir = Path(prep_params['data_dir'])
    save_dir = Path(prep_params['save_dir'])
    train_valid_split = prep_params['train_valid_split']
    random_seed = prep_params['random_seed']

    train_dir = save_dir.joinpath('train')
    valid_dir = save_dir.joinpath('valid')
    test_dir = save_dir.joinpath('test')
    for out_dir in (train_dir, valid_dir, test_dir):
        _reset_dir(out_dir)

    # Process training data: shuffle indices with a fixed seed and send the
    # first fraction to train, the rest to valid.
    ids = list(range(NUM_TRAIN_IMAGES))
    random.Random(random_seed).shuffle(ids)
    num_train = int(NUM_TRAIN_IMAGES * train_valid_split[0])
    # set: O(1) membership per image instead of an O(n) list scan + remove.
    train_ids = set(ids[:num_train])

    cifar_10_dir = data_dir.joinpath('cifar-10-batches-py')
    current_id, train_count, valid_count = 0, 0, 0
    # sorted(): glob order is filesystem-dependent; a fixed order keeps the
    # seeded split mapping to the same images on every machine.
    for data_path in sorted(cifar_10_dir.glob('data_batch_*')):
        data = _load_batch(data_path)
        for i, label in enumerate(data[b'labels']):
            x, y = _to_sample(data[b'data'][i], label)
            if current_id in train_ids:
                npz_path = train_dir.joinpath(f'{train_count}.npz')
                train_count += 1
            else:
                npz_path = valid_dir.joinpath(f'{valid_count}.npz')
                valid_count += 1
            np.savez_compressed(npz_path, x=x, y=y)
            current_id += 1

    # Process testing data
    data = _load_batch(cifar_10_dir.joinpath('test_batch'))
    for i, label in enumerate(data[b'labels']):
        x, y = _to_sample(data[b'data'][i], label)
        np.savez_compressed(test_dir.joinpath(f'{i}.npz'), x=x, y=y)


if __name__ == '__main__':
    prepare('params.yaml')