From 42d12922ee7ea27eb9f95c8eaf50f070bbca5460 Mon Sep 17 00:00:00 2001 From: deng Date: Tue, 16 Jun 2026 21:26:26 +0800 Subject: [PATCH] implement a sample data preparation --- .dvcignore | 2 + pyproject.toml | 6 ++- quickdraw_bot/assets/config.yaml | 13 +++++++ quickdraw_bot/prepare.py | 67 ++++++++++++++++++++++++++++++++ quickdraw_bot/utils/__init__.py | 0 quickdraw_bot/utils/utils.py | 8 ++++ 6 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 quickdraw_bot/assets/config.yaml create mode 100644 quickdraw_bot/utils/__init__.py create mode 100644 quickdraw_bot/utils/utils.py diff --git a/.dvcignore b/.dvcignore index 5197305..85d1cd7 100644 --- a/.dvcignore +++ b/.dvcignore @@ -1,3 +1,5 @@ # Add patterns of files dvc should ignore, which could improve # the performance. Learn more at # https://dvc.org/doc/user-guide/dvcignore + +quickdraw_bot/data/processed \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 5e8f449..fb60067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,4 +35,8 @@ target-version = "py313" select = ["E", "F", "I"] [tool.ruff.format] -quote-style = "single" +quote-style = 'single' + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["."] diff --git a/quickdraw_bot/assets/config.yaml b/quickdraw_bot/assets/config.yaml new file mode 100644 index 0000000..d97616f --- /dev/null +++ b/quickdraw_bot/assets/config.yaml @@ -0,0 +1,13 @@ +prepare: + data_dir: ./data + num_of_img_per_class: 10000 + data_split: + train: 0.8 + valid: 0.1 + test: 0.1 + random_seed: 1 +train: + random_seed: 1 +eval: + random_seed: 1 +deploy: \ No newline at end of file diff --git a/quickdraw_bot/prepare.py b/quickdraw_bot/prepare.py index e69de29..c0f1303 100644 --- a/quickdraw_bot/prepare.py +++ b/quickdraw_bot/prepare.py @@ -0,0 +1,67 @@ +# prepare.py +# +# author: deng +# date : 20260616 + +import random +from pathlib import Path + +import numpy as np + +from quickdraw_bot.utils.utils import load_config + + 
class Prepare:
    """Sample the raw per-class ``.npy`` files, split them and store the result.

    The instance reads the ``prepare`` section of the YAML config and uses
    these keys from it: ``random_seed``, ``data_dir``,
    ``num_of_img_per_class`` and ``data_split``.
    """

    def __init__(self, config_path: str = './assets/config.yaml'):
        """Load the ``prepare`` config section and seed the random generators.

        Args:
            config_path: Path of the YAML config file. A relative path is
                resolved against the current working directory.
        """
        self.config = load_config(config_path)['prepare']
        self.set_random_seed()

    def set_random_seed(self):
        """Seed both ``random`` and ``numpy.random`` with ``config['random_seed']``."""
        seed = self.config['random_seed']
        random.seed(seed)
        np.random.seed(seed)

    def load_dataset(self) -> dict[str, np.ndarray]:
        """Load every ``<data_dir>/raw/*.npy`` file and sub-sample each class.

        Returns:
            Mapping of class name (the file stem) to an array of shape
            ``(N, 1, 28, 28)`` where ``N`` is at most
            ``config['num_of_img_per_class']``.
        """
        data: dict[str, np.ndarray] = {}
        raw_data_dir = Path(self.config['data_dir']) / 'raw'
        limit = self.config['num_of_img_per_class']
        # glob() order depends on the filesystem; sort so that the seeded
        # sampling below is reproducible from one machine to another.
        for npy_file in sorted(raw_data_dir.glob('*.npy')):
            class_name = npy_file.stem
            images = np.load(npy_file)  # shape: (N, 784)
            images = images.reshape(-1, 1, 28, 28)  # shape: (N, 1, 28, 28)
            # NOTE(review): the raw bitmaps are presumably 8-bit grayscale
            # (0-255). int8 would wrap every value above 127 to a negative
            # number, so the unsigned type is used.
            images = images.astype(np.uint8)
            if images.shape[0] < limit:
                print(f'Class {class_name} has less than {limit} samples, keep all')
                data[class_name] = images
            else:
                indices = np.random.choice(images.shape[0], limit, replace=False)
                data[class_name] = images[indices]
        return data

    def split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, list[np.ndarray]]]:
        """Randomly assign every image to one of the configured splits.

        Each image is drawn independently with the probabilities given in
        ``config['data_split']``, so split sizes are approximate.

        Args:
            data: Mapping of class name to an array of images.

        Returns:
            ``{split_name: {class_name: [image, ...]}}``. A split or class
            that received no image is absent from the result.

        Raises:
            ValueError: If the split weights do not sum to 1.
        """
        weights = dict(self.config['data_split'])
        # Compare with a tolerance: float rounding can make valid weights
        # miss an exact 1.0.
        if not np.isclose(sum(weights.values()), 1.0):
            raise ValueError('Sum of data_split weights must be 1.0')
        split_names = list(weights)
        probabilities = list(weights.values())

        sets: dict[str, dict[str, list[np.ndarray]]] = {}
        for class_name, images in data.items():
            # One vectorised draw per class instead of one call per image.
            assignments = np.random.choice(split_names, size=len(images), p=probabilities)
            for split_name in split_names:
                mask = assignments == split_name
                if mask.any():
                    sets.setdefault(split_name, {})[class_name] = list(images[mask])
        return sets

    def save_npz(self, sets: dict[str, dict[str, list[np.ndarray]]]) -> None:
        """Write one ``<data_dir>/processed/<split>.npz`` archive per split.

        Every class of a split is stored under its class name as the array key.
        """
        save_dir = Path(self.config['data_dir']) / 'processed'
        # parents=True: data_dir itself may not exist yet.
        save_dir.mkdir(parents=True, exist_ok=True)
        for usage, data_dict in sets.items():
            np.savez(save_dir / f'{usage}.npz', **data_dict)

    def run(self):
        """Run the whole preparation: load, split, save."""
        data = self.load_dataset()
        sets = self.split_data(data)
        self.save_npz(sets)


if __name__ == '__main__':
    Prepare().run()


# ---------------------------------------------------------------------------
# quickdraw_bot/utils/__init__.py  (new, empty file)
# quickdraw_bot/utils/utils.py     (new file)
# ---------------------------------------------------------------------------
# utils.py


def load_config(config_path: str) -> dict:
    """Parse the YAML file at ``config_path`` and return its content.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.
    """
    import yaml  # PyYAML, imported at call time

    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)