implement a sample data preparation
This commit is contained in:
@ -1,3 +1,5 @@
|
|||||||
# Add patterns of files dvc should ignore, which could improve
|
# Add patterns of files dvc should ignore, which could improve
|
||||||
# the performance. Learn more at
|
# the performance. Learn more at
|
||||||
# https://dvc.org/doc/user-guide/dvcignore
|
# https://dvc.org/doc/user-guide/dvcignore
|
||||||
|
|
||||||
|
quickdraw_bot/data/processed
|
||||||
@ -35,4 +35,8 @@ target-version = "py313"
|
|||||||
select = ["E", "F", "I"]
|
select = ["E", "F", "I"]
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
quote-style = "single"
|
quote-style = 'single'
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
pythonpath = ["."]
|
||||||
|
|||||||
13
quickdraw_bot/assets/config.yaml
Normal file
13
quickdraw_bot/assets/config.yaml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
prepare:
|
||||||
|
data_dir: ./data
|
||||||
|
num_of_img_per_class: 10000
|
||||||
|
data_split:
|
||||||
|
train: 0.8
|
||||||
|
valid: 0.1
|
||||||
|
test: 0.1
|
||||||
|
random_seed: 1
|
||||||
|
train:
|
||||||
|
random_seed: 1
|
||||||
|
eval:
|
||||||
|
random_seed: 1
|
||||||
|
deploy:
|
||||||
@ -0,0 +1,67 @@
|
|||||||
|
# prepare.py
|
||||||
|
#
|
||||||
|
# author: deng
|
||||||
|
# date : 20260616
|
||||||
|
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from quickdraw_bot.utils.utils import load_config
|
||||||
|
|
||||||
|
|
||||||
|
class Prepare:
|
||||||
|
def __init__(self, config_path: str = './assets/config.yaml'):
|
||||||
|
self.config = load_config(config_path)['prepare']
|
||||||
|
self.set_random_seed()
|
||||||
|
|
||||||
|
def set_random_seed(self):
|
||||||
|
random.seed(self.config['random_seed'])
|
||||||
|
np.random.seed(self.config['random_seed'])
|
||||||
|
|
||||||
|
def load_dataset(self) -> dict[str, np.ndarray]:
|
||||||
|
data: dict[str, np.ndarray] = {}
|
||||||
|
raw_data_dir = Path(self.config['data_dir']) / 'raw'
|
||||||
|
for npy_file in raw_data_dir.glob('*.npy'):
|
||||||
|
class_name = npy_file.stem
|
||||||
|
images = np.load(npy_file) # shape: (N, 784)
|
||||||
|
images = images.reshape(-1, 1, 28, 28) # shape: (N, 1, 28, 28)
|
||||||
|
images = images.astype(np.int8)
|
||||||
|
if images.shape[0] < self.config['num_of_img_per_class']:
|
||||||
|
print(f'Class {class_name} has less than {self.config["num_of_img_per_class"]} samples, keep all')
|
||||||
|
data[class_name] = images
|
||||||
|
else:
|
||||||
|
random_indice = np.random.choice(images.shape[0], self.config['num_of_img_per_class'], replace=False)
|
||||||
|
data[class_name] = images[random_indice]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def split_data(self, data: dict[str, np.ndarray]) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
|
||||||
|
sets: dict[str, dict[str, list[np.ndarray]]] = {}
|
||||||
|
weights = {name: weight for name, weight in self.config['data_split'].items()}
|
||||||
|
if sum(weights.values()) != 1.0:
|
||||||
|
raise ValueError('Sum of data_split weights must be 1.0')
|
||||||
|
for class_name, images in data.items():
|
||||||
|
for image in images:
|
||||||
|
selection = np.random.choice(list(weights.keys()), p=list(weights.values()))
|
||||||
|
if selection not in sets:
|
||||||
|
sets[selection] = {}
|
||||||
|
if class_name not in sets[selection]:
|
||||||
|
sets[selection][class_name] = []
|
||||||
|
sets[selection][class_name].append(image)
|
||||||
|
return sets
|
||||||
|
|
||||||
|
def save_npz(self, sets: dict[str, dict[str, list[np.ndarray]]]) -> None:
|
||||||
|
save_dir = Path(self.config['data_dir']) / 'processed'
|
||||||
|
save_dir.mkdir(exist_ok=True)
|
||||||
|
for usage, data_dict in sets.items():
|
||||||
|
np.savez(f'{save_dir}/{usage}.npz', **data_dict)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
data = self.load_dataset()
|
||||||
|
sets = self.split_data(data)
|
||||||
|
self.save_npz(sets)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
Prepare().run()
|
||||||
|
|||||||
0
quickdraw_bot/utils/__init__.py
Normal file
0
quickdraw_bot/utils/__init__.py
Normal file
8
quickdraw_bot/utils/utils.py
Normal file
8
quickdraw_bot/utils/utils.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# utils.py
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_path: str) -> dict:
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
Reference in New Issue
Block a user