From 2f2db72db1a29a2448dc6725d285a215d10b1146 Mon Sep 17 00:00:00 2001 From: deng Date: Thu, 18 Jun 2026 09:16:42 +0800 Subject: [PATCH] [exp] init train --- .gitignore | 8 +- quickdraw_bot/assets/.gitignore | 1 + quickdraw_bot/assets/config.yaml | 15 +- quickdraw_bot/assets/dvc.yaml | 5 + quickdraw_bot/assets/model.pth.dvc | 5 + quickdraw_bot/doc/exp/train/metrics.json | 17 ++ .../train/plots/metrics/train/accuracy.tsv | 31 ++++ .../doc/exp/train/plots/metrics/train/f1.tsv | 31 ++++ .../exp/train/plots/metrics/train/loss.tsv | 31 ++++ .../train/plots/metrics/train/precision.tsv | 31 ++++ .../exp/train/plots/metrics/train/recall.tsv | 31 ++++ .../train/plots/metrics/valid/accuracy.tsv | 31 ++++ .../doc/exp/train/plots/metrics/valid/f1.tsv | 31 ++++ .../exp/train/plots/metrics/valid/loss.tsv | 31 ++++ .../train/plots/metrics/valid/precision.tsv | 31 ++++ .../exp/train/plots/metrics/valid/recall.tsv | 31 ++++ quickdraw_bot/doc/exp/train/report.html | 114 ++++++++++++ quickdraw_bot/prepare.py | 71 ++++--- quickdraw_bot/train.py | 174 ++++++++++++++++++ quickdraw_bot/utils/dataset.py | 88 +++++++++ quickdraw_bot/utils/model.py | 53 ++++++ 21 files changed, 833 insertions(+), 28 deletions(-) create mode 100644 quickdraw_bot/assets/.gitignore create mode 100644 quickdraw_bot/assets/dvc.yaml create mode 100644 quickdraw_bot/assets/model.pth.dvc create mode 100644 quickdraw_bot/doc/exp/train/metrics.json create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/train/precision.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/valid/precision.tsv create mode 100644 quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv create mode 100644 quickdraw_bot/doc/exp/train/report.html create mode 100644 quickdraw_bot/utils/dataset.py create mode 100644 quickdraw_bot/utils/model.py diff --git a/.gitignore b/.gitignore index b2f3222..44b89e5 100644 --- a/.gitignore +++ b/.gitignore @@ -12,5 +12,11 @@ wheels/ # Dataset quickdraw_bot/data +# Temp files +quickdraw_bot/tmp + # DVC -dvc/config.local \ No newline at end of file +dvc/config.local + +# .DS_Store +.DS_Store \ No newline at end of file diff --git a/quickdraw_bot/assets/.gitignore b/quickdraw_bot/assets/.gitignore new file mode 100644 index 0000000..c225f42 --- /dev/null +++ b/quickdraw_bot/assets/.gitignore @@ -0,0 +1 @@ +/model.pth diff --git a/quickdraw_bot/assets/config.yaml b/quickdraw_bot/assets/config.yaml index d97616f..9d85a6b 100644 --- a/quickdraw_bot/assets/config.yaml +++ b/quickdraw_bot/assets/config.yaml @@ -7,7 +7,20 @@ prepare: test: 0.1 random_seed: 1 train: + device_type: mps + train_npz: ./data/processed/train.npz + valid_npz: ./data/processed/valid.npz + batch_size: 256 + num_of_class: 20 + optimizer_name: sgd # sgd, adam + learning_rate: 0.001 + warmup_epochs: 5 + num_of_epochs: 30 + file_lazy_load: false random_seed: 1 + exp_msg: init train eval: + test_npz: ./data/processed/test.npz random_seed: 1 -deploy: \ No newline at end of file +deploy: + random_seed: 1 \ No newline at end of file diff --git a/quickdraw_bot/assets/dvc.yaml b/quickdraw_bot/assets/dvc.yaml new file mode 100644 index 0000000..6ef9e39 --- /dev/null +++ b/quickdraw_bot/assets/dvc.yaml @@ -0,0 +1,5 @@ +metrics: +- ../doc/exp/train/metrics.json +plots: +- ../doc/exp/train/plots/metrics: + x: step diff --git a/quickdraw_bot/assets/model.pth.dvc b/quickdraw_bot/assets/model.pth.dvc new file mode 100644 index 0000000..8c3eea2 --- /dev/null +++ b/quickdraw_bot/assets/model.pth.dvc @@ -0,0 +1,5 @@ +outs: +- md5: 263f47ef298fee74aed6acc3a316e7ad + size: 1701245 + hash: md5 + path: model.pth diff --git a/quickdraw_bot/doc/exp/train/metrics.json b/quickdraw_bot/doc/exp/train/metrics.json new file mode 100644 index 0000000..9580a0b --- /dev/null +++ b/quickdraw_bot/doc/exp/train/metrics.json @@ -0,0 +1,17 @@ +{ + "train": { + "loss": 1.992799331665039, + "accuracy": 0.4883750081062317, + "precision": 0.48363304138183594, + "recall": 0.4885570704936981, + "f1": 0.4849509298801422 + }, + "valid": { + "loss": 1.4377805100211614, + "accuracy": 0.7260000109672546, + "precision": 0.723875880241394, + "recall": 0.7256932258605957, + "f1": 0.7193899154663086 + }, + "step": 29 +} diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv new file mode 100644 index 0000000..864aa9f --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv @@ -0,0 +1,31 @@ +step accuracy +0 0.04570624977350235 +1 0.08645624667406082 +2 0.15463125705718994 +3 0.20809374749660492 +4 0.2542562484741211 +5 0.29279375076293945 +6 0.3235749900341034 +7 0.346756249666214 +8 0.363993763923645 +9 0.3763374984264374 +10 0.3868750035762787 +11 0.39285001158714294 +12 0.400112509727478 +13 0.4029250144958496 +14 0.403425008058548 +15 0.4108937382698059 +16 0.42086875438690186 +17 0.43145623803138733 +18 0.44205623865127563 +19 0.44743749499320984 +20 0.45317500829696655 +21 0.4583125114440918 +22 0.45945000648498535 +23 0.4590874910354614 +24 0.462799996137619 +25 0.46361875534057617 +26 0.47360000014305115 +27 0.47944375872612 +28 0.4831624925136566 +29 0.4883750081062317 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv new file mode 100644 index 0000000..981e0a5 --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv @@ -0,0 +1,31 @@ +step f1 +0 0.03369621932506561 +1 0.08265631645917892 +2 0.14480535686016083 +3 0.1954474151134491 +4 0.24400727450847626 +5 0.2852896749973297 +6 0.3165817856788635 +7 0.34051138162612915 +8 0.3579234480857849 +9 0.3706662654876709 +10 0.3813643753528595 +11 0.38740092515945435 +12 0.3947397768497467 +13 0.3978673219680786 +14 0.3983455300331116 +15 0.4061855971813202 +16 0.41634228825569153 +17 0.4270275831222534 +18 0.4378680884838104 +19 0.4433649480342865 +20 0.44917452335357666 +21 0.45438480377197266 +22 0.45568108558654785 +23 0.45533287525177 +24 0.45924824476242065 +25 0.45992523431777954 +26 0.4698502719402313 +27 0.47596246004104614 +28 0.4798404574394226 +29 0.4849509298801422 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv new file mode 100644 index 0000000..14c58fc --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv @@ -0,0 +1,31 @@ +step loss +0 4.0223345439910885 +1 3.1446908485412597 +2 2.845902690887451 +3 2.7091702819824217 +4 2.5876311504364016 +5 2.4898329914093016 +6 2.4078801975250244 +7 2.3469735752105714 +8 2.3044276081085204 +9 2.2692713054656983 +10 2.2433441226959228 +11 2.2251542106628417 +12 2.2075239395141604 +13 2.199652731704712 +14 2.2003354278564453 +15 2.1802532527923586 +16 2.155798588562012 +17 2.127256035041809 +18 2.1059615146636963 +19 2.0936844367980956 +20 2.077068444442749 +21 2.064746246147156 +22 2.0619221328735353 +23 2.058998599433899 +24 2.0547973834991455 +25 2.0490142166137697 +26 2.0298474113464358 +27 2.0157182762145998 +28 2.006533228492737 +29 1.992799331665039 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/precision.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/precision.tsv new file mode 100644 index 0000000..c862d20 --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/precision.tsv @@ -0,0 +1,31 @@ +step precision +0 0.055038902908563614 +1 0.08688722550868988 +2 0.14544154703617096 +3 0.19598175585269928 +4 0.24326808750629425 +5 0.28354209661483765 +6 0.3143269121646881 +7 0.33853644132614136 +8 0.3557613492012024 +9 0.3688707649707794 +10 0.37962812185287476 +11 0.3855380415916443 +12 0.3929717540740967 +13 0.3963885009288788 +14 0.39665019512176514 +15 0.40454378724098206 +16 0.41489875316619873 +17 0.42552798986434937 +18 0.43634524941444397 +19 0.44186848402023315 +20 0.4476196765899658 +21 0.4529317021369934 +22 0.4541561007499695 +23 0.45390117168426514 +24 0.45794767141342163 +25 0.45853471755981445 +26 0.46830785274505615 +27 0.4746767580509186 +28 0.47852134704589844 +29 0.48363304138183594 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv new file mode 100644 index 0000000..e3ea5b0 --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv @@ -0,0 +1,31 @@ +step recall +0 0.04566790908575058 +1 0.08645831048488617 +2 0.1547394096851349 +3 0.2082621306180954 +4 0.2544386386871338 +5 0.29297617077827454 +6 0.3237552046775818 +7 0.346926212310791 +8 0.364177942276001 +9 0.3765082359313965 +10 0.3870483636856079 +11 0.39303654432296753 +12 0.4002862572669983 +13 0.403104305267334 +14 0.4036043882369995 +15 0.4110710322856903 +16 0.42104363441467285 +17 0.4316273629665375 +18 0.44224199652671814 +19 0.44761422276496887 +20 0.4533519148826599 +21 0.4584938883781433 +22 0.45962250232696533 +23 0.45926111936569214 +24 0.46298152208328247 +25 0.4637938141822815 +26 0.47378331422805786 +27 0.4796329736709595 +28 0.4833483397960663 +29 0.4885570704936981 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv new file mode 100644 index 0000000..a185a3e --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv @@ -0,0 +1,31 @@ +step accuracy +0 0.04039999842643738 +1 0.1996999979019165 +2 0.3260500133037567 +3 0.4186500012874603 +4 0.49950000643730164 +5 0.5444499850273132 +6 0.5720000267028809 +7 0.5917500257492065 +8 0.6093500256538391 +9 0.6218000054359436 +10 0.6304000020027161 +11 0.6385499835014343 +12 0.6404500007629395 +13 0.6446499824523926 +14 0.6463500261306763 +15 0.6567999720573425 +16 0.6672000288963318 +17 0.6779500246047974 +18 0.6841999888420105 +19 0.6896499991416931 +20 0.6960499882698059 +21 0.6970999836921692 +22 0.7006000280380249 +23 0.70169997215271 +24 0.7010999917984009 +25 0.7071499824523926 +26 0.7135000228881836 +27 0.7184500098228455 +28 0.7211499810218811 +29 0.7260000109672546 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv new file mode 100644 index 0000000..efb4ae1 --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv @@ -0,0 +1,31 @@ +step f1 +0 0.026379385963082314 +1 0.1792948693037033 +2 0.2925271987915039 +3 0.3919405937194824 +4 0.47940874099731445 +5 0.5274306535720825 +6 0.556631863117218 +7 0.576585054397583 +8 0.5947144031524658 +9 0.6082352995872498 +10 0.6178301572799683 +11 0.6262512803077698 +12 0.6278425455093384 +13 0.6331325769424438 +14 0.6344878673553467 +15 0.6453059315681458 +16 0.6570700407028198 +17 0.6685804128646851 +18 0.6745381355285645 +19 0.6804145574569702 +20 0.6874032020568848 +21 0.6882768273353577 +22 0.6920892596244812 +23 0.693403959274292 +24 0.6925287246704102 +25 0.6985740661621094 +26 0.7053770422935486 +27 0.7106728553771973 +28 0.7137280106544495 +29 0.7193899154663086 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv new file mode 100644 index 0000000..a9c6850 --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv @@ -0,0 +1,31 @@ +step loss +0 3.590050941781153 +1 2.7546746489367906 +2 2.5154529040372826 +3 2.2843858411040485 +4 2.0942421291447895 +5 1.9563061029096194 +6 1.8598378127134298 +7 1.7950012200995336 +8 1.7452865039245993 +9 1.7100401875338977 +10 1.6846234753162046 +11 1.6670298757432382 +12 1.6579805748372138 +13 1.6494467530069472 +14 1.6449626789817327 +15 1.6121938424774362 +16 1.5864867349214191 +17 1.5629751878448679 +18 1.5430825224405602 +19 1.5310050943229772 +20 1.5175003311302089 +21 1.5117049654827843 +22 1.5065134658089168 +23 1.5018350715878643 +24 1.5016868627524074 +25 1.4859640613386902 +26 1.4698024339313749 +27 1.4584274548518508 +28 1.4490519852577886 +29 1.4377805100211614 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/precision.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/precision.tsv new file mode 100644 index 0000000..1065909 --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/precision.tsv @@ -0,0 +1,31 @@ +step precision +0 0.07338898628950119 +1 0.18911437690258026 +2 0.32449209690093994 +3 0.4227263331413269 +4 0.49964380264282227 +5 0.5427083373069763 +6 0.5689945220947266 +7 0.5907210111618042 +8 0.6088747978210449 +9 0.620964527130127 +10 0.6287857294082642 +11 0.6372801065444946 +12 0.6391931772232056 +13 0.643118143081665 +14 0.6445475816726685 +15 0.655581533908844 +16 0.6666973829269409 +17 0.6765724420547485 +18 0.6822240352630615 +19 0.6881765127182007 +20 0.6948975920677185 +21 0.695229172706604 +22 0.6998213529586792 +23 0.7005149126052856 +24 0.6999821662902832 +25 0.7058628797531128 +26 0.7112910151481628 +27 0.7165747880935669 +28 0.7200020551681519 +29 0.723875880241394 diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv new file mode 100644 index 0000000..0c5d6f7 --- /dev/null +++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv @@ -0,0 +1,31 @@ +step recall +0 0.04072221741080284 +1 0.19861623644828796 +2 0.32381999492645264 +3 0.41670864820480347 +4 0.4983140230178833 +5 0.5435963273048401 +6 0.5713088512420654 +7 0.5910568237304688 +8 0.608674168586731 +9 0.6212138533592224 +10 0.6297356486320496 +11 0.6379297971725464 +12 0.6398895382881165 +13 0.6441172361373901 +14 0.6458127498626709 +15 0.6563705205917358 +16 0.6666883230209351 +17 0.6775047779083252 +18 0.6838763952255249 +19 0.6894745826721191 +20 0.6957231163978577 +21 0.6967899203300476 +22 0.7002253532409668 +23 0.7013353109359741 +24 0.7007752656936646 +25 0.706775963306427 +26 0.7130882740020752 +27 0.7181775569915771 +28 0.7208318114280701 +29 0.7256932258605957 diff --git a/quickdraw_bot/doc/exp/train/report.html b/quickdraw_bot/doc/exp/train/report.html new file mode 100644 index 0000000..ef43dc1 --- /dev/null +++ b/quickdraw_bot/doc/exp/train/report.html @@ -0,0 +1,114 @@ + + + + + DVC Plot + + + + + + + + + + +
+

metrics_json

+
+ + + + + + + +
train.loss train.accuracy train.precision train.recall train.f1 valid.loss valid.accuracy valid.precision valid.recall valid.f1 step
1.9928 0.488375 0.483633 0.488557 0.484951 1.43778 0.726 0.723876 0.725693 0.71939 29
+
+
+ +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + + \ No newline at end of file diff --git a/quickdraw_bot/prepare.py b/quickdraw_bot/prepare.py index c0f1303..1c17f82 100644 --- a/quickdraw_bot/prepare.py +++ b/quickdraw_bot/prepare.py @@ -14,53 +14,72 @@ from quickdraw_bot.utils.utils import load_config class Prepare: def __init__(self, config_path: str = './assets/config.yaml'): self.config = load_config(config_path)['prepare'] - self.set_random_seed() + self._set_random_seed() - def set_random_seed(self): + def _set_random_seed(self): random.seed(self.config['random_seed']) np.random.seed(self.config['random_seed']) - def load_dataset(self) -> dict[str, np.ndarray]: - data: dict[str, np.ndarray] = {} + def _load_dataset(self) -> dict[str, np.ndarray]: + data: dict[str, list] = { + 'images': [], + 'cate_names': [], + 'cate_ids': [] + } + cls_id_map: dict[str, int] = {} raw_data_dir = Path(self.config['data_dir']) / 'raw' - for npy_file in raw_data_dir.glob('*.npy'): - class_name = npy_file.stem + for npy_file in sorted(raw_data_dir.glob('*.npy')): + class_name = npy_file.stem.removeprefix('full_numpy_bitmap_') + if class_name not in cls_id_map: + cls_id_map[class_name] = len(cls_id_map) images = np.load(npy_file) # shape: (N, 784) images = images.reshape(-1, 1, 28, 28) # shape: (N, 1, 28, 28) images = images.astype(np.int8) if images.shape[0] < self.config['num_of_img_per_class']: print(f'Class {class_name} has less than {self.config["num_of_img_per_class"]} samples, keep all') - data[class_name] = images + data['images'].extend(images) + data['cate_names'].extend([class_name] * images.shape[0]) + data['cate_ids'].extend([cls_id_map[class_name]] * images.shape[0]) else: random_indice = np.random.choice(images.shape[0], self.config['num_of_img_per_class'], replace=False) - data[class_name] = images[random_indice] + data['images'].extend(images[random_indice]) + data['cate_names'].extend([class_name] * self.config['num_of_img_per_class']) + data['cate_ids'].extend([cls_id_map[class_name]] * self.config['num_of_img_per_class']) + data['images'] = np.array(data['images']).astype(np.uint8) + data['cate_names'] = np.array(data['cate_names']).astype('S30') + data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16) return data - def split_data(self, data: dict[str, np.ndarray]) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: - sets: dict[str, dict[str, list[np.ndarray]]] = {} - weights = {name: weight for name, weight in self.config['data_split'].items()} - if sum(weights.values()) != 1.0: + def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]: + weights = self.config['data_split'] + if abs(sum(weights.values()) - 1.0) > 1e-6: raise ValueError('Sum of data_split weights must be 1.0') - for class_name, images in data.items(): - for image in images: - selection = np.random.choice(list(weights.keys()), p=list(weights.values())) - if selection not in sets: - sets[selection] = {} - if class_name not in sets[selection]: - sets[selection][class_name] = [] - sets[selection][class_name].append(image) + + element_count = len(next(iter(data.values()))) + shuffled_indices = np.random.permutation(element_count) + + sets: dict[str, dict[str, np.ndarray]] = {} + start = 0 + for i, (name, weight) in enumerate(weights.items()): + if i == len(weights) - 1: + idx = shuffled_indices[start:] + else: + end = start + round(weight * element_count) + idx = shuffled_indices[start:end] + start = end + sets[name] = {key: value[idx] for key, value in data.items()} return sets - def save_npz(self, sets: dict[str, dict[str, list[np.ndarray]]]) -> None: + def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None: save_dir = Path(self.config['data_dir']) / 'processed' save_dir.mkdir(exist_ok=True) - for usage, data_dict in sets.items(): - np.savez(f'{save_dir}/{usage}.npz', **data_dict) + for usage, data in sets.items(): + np.savez(f'{save_dir}/{usage}.npz', **data) def run(self): - data = self.load_dataset() - sets = self.split_data(data) - self.save_npz(sets) + data = self._load_dataset() + sets = self._split_data(data) + self._save_npz(sets) if __name__ == '__main__': diff --git a/quickdraw_bot/train.py b/quickdraw_bot/train.py index e69de29..a0b8b5f 100644 --- a/quickdraw_bot/train.py +++ b/quickdraw_bot/train.py @@ -0,0 +1,174 @@ +# train.py +# +# author: deng +# date : 20260617 + +import random + +import numpy as np +import torch +from dvclive import Live +from torch.utils.data import DataLoader +from torchmetrics import MetricCollection +from torchmetrics.classification import Accuracy, ConfusionMatrix, F1Score, Precision, Recall +from tqdm import tqdm + +from quickdraw_bot.utils.dataset import QuickDrawDataset +from quickdraw_bot.utils.model import BabyCNN +from quickdraw_bot.utils.utils import load_config + + +class Train: + def __init__(self, config_path: str = './assets/config.yaml'): + self.config = load_config(config_path)['train'] + self._device = torch.device(self.config['device_type']) + + self._ensure_deterministic() + + def _ensure_deterministic(self) -> None: + torch.use_deterministic_algorithms(mode=True, warn_only=True) + random.seed(self.config['random_seed']) + np.random.seed(self.config['random_seed']) + torch.manual_seed(self.config['random_seed']) + + def _get_dataloader(self): + train_dataset = QuickDrawDataset( + data_npz_path=self.config['train_npz'], + enable_data_aug=True, + file_lazy_load=self.config['file_lazy_load'], + return_cate_name=False, + # vis_dir='./tmp' + ) + valid_dataset = QuickDrawDataset( + data_npz_path=self.config['valid_npz'], + enable_data_aug=False, + file_lazy_load=self.config['file_lazy_load'], + return_cate_name=False, + ) + + train_dataloader = DataLoader( + train_dataset, + batch_size=self.config['batch_size'], + shuffle=True, + num_workers=4, + pin_memory=False, # not support for mps + persistent_workers=True + ) + valid_dataloader = DataLoader( + valid_dataset, + batch_size=self.config['batch_size'], + shuffle=False, + num_workers=1, + pin_memory=False, + persistent_workers=True + ) + return train_dataloader, valid_dataloader + + def _get_model(self) -> torch.nn.Module: + model = BabyCNN( + num_classes=self.config['num_of_class'], + dropout_p=0.3 + ).to(self._device) + model.train() + return model + + def _get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer: + if self.config['optimizer_name'] == 'adam': + optimizer = torch.optim.Adam(model.parameters(), lr=self.config['learning_rate']) + elif self.config['optimizer_name'] == 'sgd': + optimizer = torch.optim.SGD(model.parameters(), lr=self.config['learning_rate']) + else: + raise ValueError(f'Unknown optimizer name: {self.config["optimizer_name"]}') + return optimizer + + def _get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler: + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.0001) + if self.config['warmup_epochs'] > 0: + warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=self.config['warmup_epochs']) + scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, scheduler], milestones=[self.config['warmup_epochs']]) + return scheduler + + def _get_loss(self) -> torch.nn.modules.loss._Loss: + loss = torch.nn.CrossEntropyLoss( + label_smoothing=0.1 + ).to(self._device) + return loss + + def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]: + metric_collection = MetricCollection([ + Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1), + Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), + Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), + F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'), + ]).to(self._device) + confusion_matrix = ConfusionMatrix( + task='multiclass', + threshold=0.5, + num_classes=self.config['num_of_class'], + ).to(self._device) + return metric_collection, confusion_matrix + + def run(self): + train_dataloader, valid_dataloader = self._get_dataloader() + model = self._get_model() + optimizer = self._get_optimizer(model) + scheduler = self._get_scheduler(optimizer) + loss = self._get_loss() + metrics, _ = self._get_metrics() + + with Live( + dir='./doc/exp/train', + report='html', + dvcyaml='./assets/dvc.yaml', + exp_message=self.config['exp_msg']) as live: + + for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'): + metrics.reset() + model.train() + total_train_loss = 0. + for inputs, targets in train_dataloader: + inputs = inputs.to(self._device) + targets = targets.to(self._device) + optimizer.zero_grad() + outputs = model(inputs) + train_loss = loss(outputs, targets) + total_train_loss += train_loss.item() + train_loss.backward() + optimizer.step() + metrics.update(outputs, targets) + train_metrics = metrics.compute() + avg_train_loss = total_train_loss / len(train_dataloader) + + metrics.reset() + model.eval() + total_valid_loss = 0. + with torch.no_grad(): + for inputs, targets in valid_dataloader: + inputs = inputs.to(self._device) + targets = targets.to(self._device) + outputs = model(inputs) + valid_loss = loss(outputs, targets) + total_valid_loss += valid_loss.item() + metrics.update(outputs, targets) + valid_metrics = metrics.compute() + avg_valid_loss = total_valid_loss / len(valid_dataloader) + + live.log_metric('train/loss', avg_train_loss) + live.log_metric('train/accuracy', train_metrics['MulticlassAccuracy'].item()) + live.log_metric('train/precision', train_metrics['MulticlassPrecision'].item()) + live.log_metric('train/recall', train_metrics['MulticlassRecall'].item()) + live.log_metric('train/f1', train_metrics['MulticlassF1Score'].item()) + live.log_metric('valid/loss', avg_valid_loss) + live.log_metric('valid/accuracy', valid_metrics['MulticlassAccuracy'].item()) + live.log_metric('valid/precision', valid_metrics['MulticlassPrecision'].item()) + live.log_metric('valid/recall', valid_metrics['MulticlassRecall'].item()) + live.log_metric('valid/f1', valid_metrics['MulticlassF1Score'].item()) + + scheduler.step() + live.next_step() + + torch.save(model, './assets/model.pth') + + +if __name__ == '__main__': + Train().run() \ No newline at end of file diff --git a/quickdraw_bot/utils/dataset.py b/quickdraw_bot/utils/dataset.py new file mode 100644 index 0000000..469ccf8 --- /dev/null +++ b/quickdraw_bot/utils/dataset.py @@ -0,0 +1,88 @@ +# dataset.py +# +# author: deng +# date : 20260617 + +from pathlib import Path + +import numpy as np +import torch +from torchvision.transforms import v2 + + +class QuickDrawDataset(torch.utils.data.Dataset): + def __init__(self, + data_npz_path: str, + image_shape: tuple[int, int, int] = (1, 28, 28), + enable_data_aug: bool = False, + file_lazy_load: bool = False, + return_cate_name: bool = False, + vis_dir: str = None, + ) -> None: + super().__init__() + self._images: torch.Tensor | np.ndarray = [] + self._cate_names: list[str] = [] + self._cate_ids: torch.Tensor = [] + self._transform: callable = None + + self._enable_data_aug = enable_data_aug + self._data_npz_path = data_npz_path + self._image_shape = image_shape + self._file_lazy_load = file_lazy_load + self._return_cate_name = return_cate_name + self._vis_dir = Path(vis_dir) if vis_dir is not None else None + + self._set_data_transform() + self._collect_data() + + def _set_data_transform(self) -> None: + aug_pipeline = [] + if self._enable_data_aug: + aug_pipeline = [ + v2.RandomHorizontalFlip(p=0.2), + v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5), + v2.RandomPerspective(distortion_scale=0.15, p=0.2), + v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2), + v2.RandomErasing(p=0.2, scale=(0.02, 0.2)) + ] + self._transform = v2.Compose([ + *aug_pipeline, + v2.Resize(self._image_shape[1:]), + v2.ToDtype(torch.float32, scale=True), + ]) + + def _collect_data(self) -> None: + if self._file_lazy_load: + self._npz_file = np.load(self._data_npz_path, mmap_mode='r') + self._cate_names = [cate_name.decode() for cate_name in self._npz_file['cate_names']] + self._cate_ids = torch.from_numpy(self._npz_file['cate_ids'].copy()).long() + self._images = self._npz_file['images'] + else: + with np.load(self._data_npz_path, mmap_mode=None) as npz_file: + self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']] + self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long() + self._images = torch.from_numpy(npz_file['images']) + + def __len__(self) -> int: + return len(self._images) + + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: + if self._file_lazy_load: + x = torch.from_numpy(self._images[index]) + else: + x = self._images[index] + x = self._transform(x) + y = self._cate_ids[index] + + if self._vis_dir is not None: + vis_path = self._vis_dir / f'{index:05d}_{self._cate_names[index]}.png' + if not vis_path.exists(): + v2.ToPILImage()(x).save(vis_path) + + if self._return_cate_name: + return x, y, self._cate_names[index] + return x, y + + def set_data_aug(self, enable_data_aug: bool) -> None: + self._enable_data_aug = enable_data_aug + self._set_data_transform() diff --git a/quickdraw_bot/utils/model.py b/quickdraw_bot/utils/model.py new file mode 100644 index 0000000..849a606 --- /dev/null +++ b/quickdraw_bot/utils/model.py @@ -0,0 +1,53 @@ +# model.py +# +# author: deng +# date : 20260617 + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BabyCNN(nn.Module): + def __init__(self, + num_classes: int = 10, + dropout_p: float = 0.5) -> None: + super().__init__() + + # Conv Block 1: 28x28 -> 14x14 + self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(num_features=32) + + # Conv Block 2: 14x14 -> 7x7 + self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(num_features=64) + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.dropout = nn.Dropout(p=dropout_p) + + # FC Layers + self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128) + self.fc2 = nn.Linear(in_features=128, out_features=num_classes) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity='relu') + nn.init.zeros_(m.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(F.relu(self.bn1(self.conv1(x)))) + x = self.pool(F.relu(self.bn2(self.conv2(x)))) + x = x.view(x.size(0), -1) + x = self.dropout(F.relu(self.fc1(x))) + x = self.fc2(x) + return x \ No newline at end of file