Compare commits

..

2 Commits

Author SHA1 Message Date
38f81be722 log data 2026-06-16 21:29:53 +08:00
42d12922ee implement a sample data preparation 2026-06-16 21:26:26 +08:00
7 changed files with 98 additions and 4 deletions

View File

@ -1,3 +1,5 @@
# Add patterns of files dvc should ignore, which could improve # Add patterns of files dvc should ignore, which could improve
# the performance. Learn more at # the performance. Learn more at
# https://dvc.org/doc/user-guide/dvcignore # https://dvc.org/doc/user-guide/dvcignore
quickdraw_bot/data/processed

View File

@ -35,4 +35,8 @@ target-version = "py313"
select = ["E", "F", "I"] select = ["E", "F", "I"]
[tool.ruff.format] [tool.ruff.format]
quote-style = "single" quote-style = 'single'
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]

View File

@ -0,0 +1,13 @@
prepare:
data_dir: ./data
num_of_img_per_class: 10000
data_split:
train: 0.8
valid: 0.1
test: 0.1
random_seed: 1
train:
random_seed: 1
eval:
random_seed: 1
deploy:

View File

@ -1,6 +1,6 @@
outs: outs:
- md5: f46c877020a059df22fc383cd2f0a0bb.dir - md5: dd93e82604be92816a10bfb1e709edf2.dir
size: 3 size: 2306020784
nfiles: 1 nfiles: 20
hash: md5 hash: md5
path: data path: data

View File

@ -0,0 +1,67 @@
# prepare.py
#
# author: deng
# date : 20260616
import random
from pathlib import Path
import numpy as np
from quickdraw_bot.utils.utils import load_config
class Prepare:
def __init__(self, config_path: str = './assets/config.yaml'):
self.config = load_config(config_path)['prepare']
self.set_random_seed()
def set_random_seed(self):
random.seed(self.config['random_seed'])
np.random.seed(self.config['random_seed'])
def load_dataset(self) -> dict[str, np.ndarray]:
data: dict[str, np.ndarray] = {}
raw_data_dir = Path(self.config['data_dir']) / 'raw'
for npy_file in raw_data_dir.glob('*.npy'):
class_name = npy_file.stem
images = np.load(npy_file) # shape: (N, 784)
images = images.reshape(-1, 1, 28, 28) # shape: (N, 1, 28, 28)
images = images.astype(np.int8)
if images.shape[0] < self.config['num_of_img_per_class']:
print(f'Class {class_name} has less than {self.config["num_of_img_per_class"]} samples, keep all')
data[class_name] = images
else:
random_indice = np.random.choice(images.shape[0], self.config['num_of_img_per_class'], replace=False)
data[class_name] = images[random_indice]
return data
def split_data(self, data: dict[str, np.ndarray]) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
sets: dict[str, dict[str, list[np.ndarray]]] = {}
weights = {name: weight for name, weight in self.config['data_split'].items()}
if sum(weights.values()) != 1.0:
raise ValueError('Sum of data_split weights must be 1.0')
for class_name, images in data.items():
for image in images:
selection = np.random.choice(list(weights.keys()), p=list(weights.values()))
if selection not in sets:
sets[selection] = {}
if class_name not in sets[selection]:
sets[selection][class_name] = []
sets[selection][class_name].append(image)
return sets
def save_npz(self, sets: dict[str, dict[str, list[np.ndarray]]]) -> None:
save_dir = Path(self.config['data_dir']) / 'processed'
save_dir.mkdir(exist_ok=True)
for usage, data_dict in sets.items():
np.savez(f'{save_dir}/{usage}.npz', **data_dict)
def run(self):
data = self.load_dataset()
sets = self.split_data(data)
self.save_npz(sets)
if __name__ == '__main__':
Prepare().run()

View File

View File

@ -0,0 +1,8 @@
# utils.py
import yaml
def load_config(config_path: str) -> dict:
with open(config_path, 'r') as f:
return yaml.safe_load(f)