implement a sample data preparation
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
# Add patterns of files dvc should ignore, which could improve
|
||||
# the performance. Learn more at
|
||||
# https://dvc.org/doc/user-guide/dvcignore
|
||||
|
||||
quickdraw_bot/data/processed
|
||||
@ -35,4 +35,8 @@ target-version = "py313"
|
||||
select = ["E", "F", "I"]
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "single"
|
||||
quote-style = 'single'
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
|
||||
13
quickdraw_bot/assets/config.yaml
Normal file
13
quickdraw_bot/assets/config.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
prepare:
|
||||
data_dir: ./data
|
||||
num_of_img_per_class: 10000
|
||||
data_split:
|
||||
train: 0.8
|
||||
valid: 0.1
|
||||
test: 0.1
|
||||
random_seed: 1
|
||||
train:
|
||||
random_seed: 1
|
||||
eval:
|
||||
random_seed: 1
|
||||
deploy:
|
||||
@ -0,0 +1,67 @@
|
||||
# prepare.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20260616
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from quickdraw_bot.utils.utils import load_config
|
||||
|
||||
|
||||
class Prepare:
|
||||
def __init__(self, config_path: str = './assets/config.yaml'):
|
||||
self.config = load_config(config_path)['prepare']
|
||||
self.set_random_seed()
|
||||
|
||||
def set_random_seed(self):
|
||||
random.seed(self.config['random_seed'])
|
||||
np.random.seed(self.config['random_seed'])
|
||||
|
||||
def load_dataset(self) -> dict[str, np.ndarray]:
|
||||
data: dict[str, np.ndarray] = {}
|
||||
raw_data_dir = Path(self.config['data_dir']) / 'raw'
|
||||
for npy_file in raw_data_dir.glob('*.npy'):
|
||||
class_name = npy_file.stem
|
||||
images = np.load(npy_file) # shape: (N, 784)
|
||||
images = images.reshape(-1, 1, 28, 28) # shape: (N, 1, 28, 28)
|
||||
images = images.astype(np.int8)
|
||||
if images.shape[0] < self.config['num_of_img_per_class']:
|
||||
print(f'Class {class_name} has less than {self.config["num_of_img_per_class"]} samples, keep all')
|
||||
data[class_name] = images
|
||||
else:
|
||||
random_indice = np.random.choice(images.shape[0], self.config['num_of_img_per_class'], replace=False)
|
||||
data[class_name] = images[random_indice]
|
||||
return data
|
||||
|
||||
def split_data(self, data: dict[str, np.ndarray]) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
|
||||
sets: dict[str, dict[str, list[np.ndarray]]] = {}
|
||||
weights = {name: weight for name, weight in self.config['data_split'].items()}
|
||||
if sum(weights.values()) != 1.0:
|
||||
raise ValueError('Sum of data_split weights must be 1.0')
|
||||
for class_name, images in data.items():
|
||||
for image in images:
|
||||
selection = np.random.choice(list(weights.keys()), p=list(weights.values()))
|
||||
if selection not in sets:
|
||||
sets[selection] = {}
|
||||
if class_name not in sets[selection]:
|
||||
sets[selection][class_name] = []
|
||||
sets[selection][class_name].append(image)
|
||||
return sets
|
||||
|
||||
def save_npz(self, sets: dict[str, dict[str, list[np.ndarray]]]) -> None:
|
||||
save_dir = Path(self.config['data_dir']) / 'processed'
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
for usage, data_dict in sets.items():
|
||||
np.savez(f'{save_dir}/{usage}.npz', **data_dict)
|
||||
|
||||
def run(self):
|
||||
data = self.load_dataset()
|
||||
sets = self.split_data(data)
|
||||
self.save_npz(sets)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Prepare().run()
|
||||
|
||||
0
quickdraw_bot/utils/__init__.py
Normal file
0
quickdraw_bot/utils/__init__.py
Normal file
8
quickdraw_bot/utils/utils.py
Normal file
8
quickdraw_bot/utils/utils.py
Normal file
@ -0,0 +1,8 @@
|
||||
# utils.py
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
|
||||
with open(config_path, 'r') as f:
|
||||
return yaml.safe_load(f)
|
||||
Reference in New Issue
Block a user