diff --git a/.gitignore b/.gitignore
index b2f3222..44b89e5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,5 +12,11 @@ wheels/
# Dataset
quickdraw_bot/data
+# Temp files
+quickdraw_bot/tmp
+
# DVC
-dvc/config.local
\ No newline at end of file
+dvc/config.local
+
+# .DS_Store
+.DS_Store
\ No newline at end of file
diff --git a/quickdraw_bot/assets/.gitignore b/quickdraw_bot/assets/.gitignore
new file mode 100644
index 0000000..c225f42
--- /dev/null
+++ b/quickdraw_bot/assets/.gitignore
@@ -0,0 +1 @@
+/model.pth
diff --git a/quickdraw_bot/assets/config.yaml b/quickdraw_bot/assets/config.yaml
index d97616f..9d85a6b 100644
--- a/quickdraw_bot/assets/config.yaml
+++ b/quickdraw_bot/assets/config.yaml
@@ -7,7 +7,20 @@ prepare:
test: 0.1
random_seed: 1
train:
+ device_type: mps
+ train_npz: ./data/processed/train.npz
+ valid_npz: ./data/processed/valid.npz
+ batch_size: 256
+ num_of_class: 20
+ optimizer_name: sgd # sgd, adam
+ learning_rate: 0.001
+ warmup_epochs: 5
+ num_of_epochs: 30
+ file_lazy_load: false
random_seed: 1
+ exp_msg: init train
eval:
+ test_npz: ./data/processed/test.npz
random_seed: 1
-deploy:
\ No newline at end of file
+deploy:
+ random_seed: 1
\ No newline at end of file
diff --git a/quickdraw_bot/assets/dvc.yaml b/quickdraw_bot/assets/dvc.yaml
new file mode 100644
index 0000000..6ef9e39
--- /dev/null
+++ b/quickdraw_bot/assets/dvc.yaml
@@ -0,0 +1,5 @@
+metrics:
+- ../doc/exp/train/metrics.json
+plots:
+- ../doc/exp/train/plots/metrics:
+ x: step
diff --git a/quickdraw_bot/assets/model.pth.dvc b/quickdraw_bot/assets/model.pth.dvc
new file mode 100644
index 0000000..8c3eea2
--- /dev/null
+++ b/quickdraw_bot/assets/model.pth.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: 263f47ef298fee74aed6acc3a316e7ad
+ size: 1701245
+ hash: md5
+ path: model.pth
diff --git a/quickdraw_bot/doc/exp/train/metrics.json b/quickdraw_bot/doc/exp/train/metrics.json
new file mode 100644
index 0000000..9580a0b
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/metrics.json
@@ -0,0 +1,17 @@
+{
+ "train": {
+ "loss": 1.992799331665039,
+ "accuracy": 0.4883750081062317,
+ "precision": 0.48363304138183594,
+ "recall": 0.4885570704936981,
+ "f1": 0.4849509298801422
+ },
+ "valid": {
+ "loss": 1.4377805100211614,
+ "accuracy": 0.7260000109672546,
+ "precision": 0.723875880241394,
+ "recall": 0.7256932258605957,
+ "f1": 0.7193899154663086
+ },
+ "step": 29
+}
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv
new file mode 100644
index 0000000..864aa9f
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv
@@ -0,0 +1,31 @@
+step accuracy
+0 0.04570624977350235
+1 0.08645624667406082
+2 0.15463125705718994
+3 0.20809374749660492
+4 0.2542562484741211
+5 0.29279375076293945
+6 0.3235749900341034
+7 0.346756249666214
+8 0.363993763923645
+9 0.3763374984264374
+10 0.3868750035762787
+11 0.39285001158714294
+12 0.400112509727478
+13 0.4029250144958496
+14 0.403425008058548
+15 0.4108937382698059
+16 0.42086875438690186
+17 0.43145623803138733
+18 0.44205623865127563
+19 0.44743749499320984
+20 0.45317500829696655
+21 0.4583125114440918
+22 0.45945000648498535
+23 0.4590874910354614
+24 0.462799996137619
+25 0.46361875534057617
+26 0.47360000014305115
+27 0.47944375872612
+28 0.4831624925136566
+29 0.4883750081062317
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv
new file mode 100644
index 0000000..981e0a5
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv
@@ -0,0 +1,31 @@
+step f1
+0 0.03369621932506561
+1 0.08265631645917892
+2 0.14480535686016083
+3 0.1954474151134491
+4 0.24400727450847626
+5 0.2852896749973297
+6 0.3165817856788635
+7 0.34051138162612915
+8 0.3579234480857849
+9 0.3706662654876709
+10 0.3813643753528595
+11 0.38740092515945435
+12 0.3947397768497467
+13 0.3978673219680786
+14 0.3983455300331116
+15 0.4061855971813202
+16 0.41634228825569153
+17 0.4270275831222534
+18 0.4378680884838104
+19 0.4433649480342865
+20 0.44917452335357666
+21 0.45438480377197266
+22 0.45568108558654785
+23 0.45533287525177
+24 0.45924824476242065
+25 0.45992523431777954
+26 0.4698502719402313
+27 0.47596246004104614
+28 0.4798404574394226
+29 0.4849509298801422
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv
new file mode 100644
index 0000000..14c58fc
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv
@@ -0,0 +1,31 @@
+step loss
+0 4.0223345439910885
+1 3.1446908485412597
+2 2.845902690887451
+3 2.7091702819824217
+4 2.5876311504364016
+5 2.4898329914093016
+6 2.4078801975250244
+7 2.3469735752105714
+8 2.3044276081085204
+9 2.2692713054656983
+10 2.2433441226959228
+11 2.2251542106628417
+12 2.2075239395141604
+13 2.199652731704712
+14 2.2003354278564453
+15 2.1802532527923586
+16 2.155798588562012
+17 2.127256035041809
+18 2.1059615146636963
+19 2.0936844367980956
+20 2.077068444442749
+21 2.064746246147156
+22 2.0619221328735353
+23 2.058998599433899
+24 2.0547973834991455
+25 2.0490142166137697
+26 2.0298474113464358
+27 2.0157182762145998
+28 2.006533228492737
+29 1.992799331665039
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/precision.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/precision.tsv
new file mode 100644
index 0000000..c862d20
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/precision.tsv
@@ -0,0 +1,31 @@
+step precision
+0 0.055038902908563614
+1 0.08688722550868988
+2 0.14544154703617096
+3 0.19598175585269928
+4 0.24326808750629425
+5 0.28354209661483765
+6 0.3143269121646881
+7 0.33853644132614136
+8 0.3557613492012024
+9 0.3688707649707794
+10 0.37962812185287476
+11 0.3855380415916443
+12 0.3929717540740967
+13 0.3963885009288788
+14 0.39665019512176514
+15 0.40454378724098206
+16 0.41489875316619873
+17 0.42552798986434937
+18 0.43634524941444397
+19 0.44186848402023315
+20 0.4476196765899658
+21 0.4529317021369934
+22 0.4541561007499695
+23 0.45390117168426514
+24 0.45794767141342163
+25 0.45853471755981445
+26 0.46830785274505615
+27 0.4746767580509186
+28 0.47852134704589844
+29 0.48363304138183594
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv
new file mode 100644
index 0000000..e3ea5b0
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv
@@ -0,0 +1,31 @@
+step recall
+0 0.04566790908575058
+1 0.08645831048488617
+2 0.1547394096851349
+3 0.2082621306180954
+4 0.2544386386871338
+5 0.29297617077827454
+6 0.3237552046775818
+7 0.346926212310791
+8 0.364177942276001
+9 0.3765082359313965
+10 0.3870483636856079
+11 0.39303654432296753
+12 0.4002862572669983
+13 0.403104305267334
+14 0.4036043882369995
+15 0.4110710322856903
+16 0.42104363441467285
+17 0.4316273629665375
+18 0.44224199652671814
+19 0.44761422276496887
+20 0.4533519148826599
+21 0.4584938883781433
+22 0.45962250232696533
+23 0.45926111936569214
+24 0.46298152208328247
+25 0.4637938141822815
+26 0.47378331422805786
+27 0.4796329736709595
+28 0.4833483397960663
+29 0.4885570704936981
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv
new file mode 100644
index 0000000..a185a3e
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv
@@ -0,0 +1,31 @@
+step accuracy
+0 0.04039999842643738
+1 0.1996999979019165
+2 0.3260500133037567
+3 0.4186500012874603
+4 0.49950000643730164
+5 0.5444499850273132
+6 0.5720000267028809
+7 0.5917500257492065
+8 0.6093500256538391
+9 0.6218000054359436
+10 0.6304000020027161
+11 0.6385499835014343
+12 0.6404500007629395
+13 0.6446499824523926
+14 0.6463500261306763
+15 0.6567999720573425
+16 0.6672000288963318
+17 0.6779500246047974
+18 0.6841999888420105
+19 0.6896499991416931
+20 0.6960499882698059
+21 0.6970999836921692
+22 0.7006000280380249
+23 0.70169997215271
+24 0.7010999917984009
+25 0.7071499824523926
+26 0.7135000228881836
+27 0.7184500098228455
+28 0.7211499810218811
+29 0.7260000109672546
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv
new file mode 100644
index 0000000..efb4ae1
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv
@@ -0,0 +1,31 @@
+step f1
+0 0.026379385963082314
+1 0.1792948693037033
+2 0.2925271987915039
+3 0.3919405937194824
+4 0.47940874099731445
+5 0.5274306535720825
+6 0.556631863117218
+7 0.576585054397583
+8 0.5947144031524658
+9 0.6082352995872498
+10 0.6178301572799683
+11 0.6262512803077698
+12 0.6278425455093384
+13 0.6331325769424438
+14 0.6344878673553467
+15 0.6453059315681458
+16 0.6570700407028198
+17 0.6685804128646851
+18 0.6745381355285645
+19 0.6804145574569702
+20 0.6874032020568848
+21 0.6882768273353577
+22 0.6920892596244812
+23 0.693403959274292
+24 0.6925287246704102
+25 0.6985740661621094
+26 0.7053770422935486
+27 0.7106728553771973
+28 0.7137280106544495
+29 0.7193899154663086
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv
new file mode 100644
index 0000000..a9c6850
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv
@@ -0,0 +1,31 @@
+step loss
+0 3.590050941781153
+1 2.7546746489367906
+2 2.5154529040372826
+3 2.2843858411040485
+4 2.0942421291447895
+5 1.9563061029096194
+6 1.8598378127134298
+7 1.7950012200995336
+8 1.7452865039245993
+9 1.7100401875338977
+10 1.6846234753162046
+11 1.6670298757432382
+12 1.6579805748372138
+13 1.6494467530069472
+14 1.6449626789817327
+15 1.6121938424774362
+16 1.5864867349214191
+17 1.5629751878448679
+18 1.5430825224405602
+19 1.5310050943229772
+20 1.5175003311302089
+21 1.5117049654827843
+22 1.5065134658089168
+23 1.5018350715878643
+24 1.5016868627524074
+25 1.4859640613386902
+26 1.4698024339313749
+27 1.4584274548518508
+28 1.4490519852577886
+29 1.4377805100211614
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/precision.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/precision.tsv
new file mode 100644
index 0000000..1065909
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/precision.tsv
@@ -0,0 +1,31 @@
+step precision
+0 0.07338898628950119
+1 0.18911437690258026
+2 0.32449209690093994
+3 0.4227263331413269
+4 0.49964380264282227
+5 0.5427083373069763
+6 0.5689945220947266
+7 0.5907210111618042
+8 0.6088747978210449
+9 0.620964527130127
+10 0.6287857294082642
+11 0.6372801065444946
+12 0.6391931772232056
+13 0.643118143081665
+14 0.6445475816726685
+15 0.655581533908844
+16 0.6666973829269409
+17 0.6765724420547485
+18 0.6822240352630615
+19 0.6881765127182007
+20 0.6948975920677185
+21 0.695229172706604
+22 0.6998213529586792
+23 0.7005149126052856
+24 0.6999821662902832
+25 0.7058628797531128
+26 0.7112910151481628
+27 0.7165747880935669
+28 0.7200020551681519
+29 0.723875880241394
diff --git a/quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv b/quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv
new file mode 100644
index 0000000..0c5d6f7
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv
@@ -0,0 +1,31 @@
+step recall
+0 0.04072221741080284
+1 0.19861623644828796
+2 0.32381999492645264
+3 0.41670864820480347
+4 0.4983140230178833
+5 0.5435963273048401
+6 0.5713088512420654
+7 0.5910568237304688
+8 0.608674168586731
+9 0.6212138533592224
+10 0.6297356486320496
+11 0.6379297971725464
+12 0.6398895382881165
+13 0.6441172361373901
+14 0.6458127498626709
+15 0.6563705205917358
+16 0.6666883230209351
+17 0.6775047779083252
+18 0.6838763952255249
+19 0.6894745826721191
+20 0.6957231163978577
+21 0.6967899203300476
+22 0.7002253532409668
+23 0.7013353109359741
+24 0.7007752656936646
+25 0.706775963306427
+26 0.7130882740020752
+27 0.7181775569915771
+28 0.7208318114280701
+29 0.7256932258605957
diff --git a/quickdraw_bot/doc/exp/train/report.html b/quickdraw_bot/doc/exp/train/report.html
new file mode 100644
index 0000000..ef43dc1
--- /dev/null
+++ b/quickdraw_bot/doc/exp/train/report.html
@@ -0,0 +1,114 @@
+
+
+
+
+ DVC Plot
+
+
+
+
+
+
+
+
+
+
+
+
metrics_json
+
+
+
+| train.loss | train.accuracy | train.precision | train.recall | train.f1 | valid.loss | valid.accuracy | valid.precision | valid.recall | valid.f1 | step |
+
+
+| 1.9928 | 0.488375 | 0.483633 | 0.488557 | 0.484951 | 1.43778 | 0.726 | 0.723876 | 0.725693 | 0.71939 | 29 |
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/quickdraw_bot/prepare.py b/quickdraw_bot/prepare.py
index c0f1303..1c17f82 100644
--- a/quickdraw_bot/prepare.py
+++ b/quickdraw_bot/prepare.py
@@ -14,53 +14,72 @@ from quickdraw_bot.utils.utils import load_config
class Prepare:
def __init__(self, config_path: str = './assets/config.yaml'):
self.config = load_config(config_path)['prepare']
- self.set_random_seed()
+ self._set_random_seed()
- def set_random_seed(self):
+ def _set_random_seed(self):
random.seed(self.config['random_seed'])
np.random.seed(self.config['random_seed'])
- def load_dataset(self) -> dict[str, np.ndarray]:
- data: dict[str, np.ndarray] = {}
+ def _load_dataset(self) -> dict[str, np.ndarray]:
+ data: dict[str, list] = {
+ 'images': [],
+ 'cate_names': [],
+ 'cate_ids': []
+ }
+ cls_id_map: dict[str, int] = {}
raw_data_dir = Path(self.config['data_dir']) / 'raw'
- for npy_file in raw_data_dir.glob('*.npy'):
- class_name = npy_file.stem
+ for npy_file in sorted(raw_data_dir.glob('*.npy')):
+ class_name = npy_file.stem.removeprefix('full_numpy_bitmap_')
+ if class_name not in cls_id_map:
+ cls_id_map[class_name] = len(cls_id_map)
images = np.load(npy_file) # shape: (N, 784)
images = images.reshape(-1, 1, 28, 28) # shape: (N, 1, 28, 28)
images = images.astype(np.int8)
if images.shape[0] < self.config['num_of_img_per_class']:
print(f'Class {class_name} has less than {self.config["num_of_img_per_class"]} samples, keep all')
- data[class_name] = images
+ data['images'].extend(images)
+ data['cate_names'].extend([class_name] * images.shape[0])
+ data['cate_ids'].extend([cls_id_map[class_name]] * images.shape[0])
else:
random_indice = np.random.choice(images.shape[0], self.config['num_of_img_per_class'], replace=False)
- data[class_name] = images[random_indice]
+ data['images'].extend(images[random_indice])
+ data['cate_names'].extend([class_name] * self.config['num_of_img_per_class'])
+ data['cate_ids'].extend([cls_id_map[class_name]] * self.config['num_of_img_per_class'])
+ data['images'] = np.array(data['images']).astype(np.uint8)
+ data['cate_names'] = np.array(data['cate_names']).astype('S30')
+ data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16)
return data
- def split_data(self, data: dict[str, np.ndarray]) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
- sets: dict[str, dict[str, list[np.ndarray]]] = {}
- weights = {name: weight for name, weight in self.config['data_split'].items()}
- if sum(weights.values()) != 1.0:
+ def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]:
+ weights = self.config['data_split']
+ if abs(sum(weights.values()) - 1.0) > 1e-6:
raise ValueError('Sum of data_split weights must be 1.0')
- for class_name, images in data.items():
- for image in images:
- selection = np.random.choice(list(weights.keys()), p=list(weights.values()))
- if selection not in sets:
- sets[selection] = {}
- if class_name not in sets[selection]:
- sets[selection][class_name] = []
- sets[selection][class_name].append(image)
+
+ element_count = len(next(iter(data.values())))
+ shuffled_indices = np.random.permutation(element_count)
+
+ sets: dict[str, dict[str, np.ndarray]] = {}
+ start = 0
+ for i, (name, weight) in enumerate(weights.items()):
+ if i == len(weights) - 1:
+ idx = shuffled_indices[start:]
+ else:
+ end = start + round(weight * element_count)
+ idx = shuffled_indices[start:end]
+ start = end
+ sets[name] = {key: value[idx] for key, value in data.items()}
return sets
- def save_npz(self, sets: dict[str, dict[str, list[np.ndarray]]]) -> None:
+ def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
save_dir = Path(self.config['data_dir']) / 'processed'
save_dir.mkdir(exist_ok=True)
- for usage, data_dict in sets.items():
- np.savez(f'{save_dir}/{usage}.npz', **data_dict)
+ for usage, data in sets.items():
+ np.savez(f'{save_dir}/{usage}.npz', **data)
def run(self):
- data = self.load_dataset()
- sets = self.split_data(data)
- self.save_npz(sets)
+ data = self._load_dataset()
+ sets = self._split_data(data)
+ self._save_npz(sets)
if __name__ == '__main__':
diff --git a/quickdraw_bot/train.py b/quickdraw_bot/train.py
index e69de29..a0b8b5f 100644
--- a/quickdraw_bot/train.py
+++ b/quickdraw_bot/train.py
@@ -0,0 +1,174 @@
+# train.py
+#
+# author: deng
+# date : 20260617
+
+import random
+
+import numpy as np
+import torch
+from dvclive import Live
+from torch.utils.data import DataLoader
+from torchmetrics import MetricCollection
+from torchmetrics.classification import Accuracy, ConfusionMatrix, F1Score, Precision, Recall
+from tqdm import tqdm
+
+from quickdraw_bot.utils.dataset import QuickDrawDataset
+from quickdraw_bot.utils.model import BabyCNN
+from quickdraw_bot.utils.utils import load_config
+
+
+# Training pipeline: builds data loaders, model, optimizer, LR scheduler, loss
+# and metrics from the `train` section of the config, runs the epoch loop while
+# logging metrics through DVCLive, and writes the model to ./assets/model.pth.
+class Train:
+ def __init__(self, config_path: str = './assets/config.yaml'):
+ # Only the `train` section of the YAML config is kept.
+ self.config = load_config(config_path)['train']
+ # Device string comes straight from the config (e.g. `device_type: mps`).
+ self._device = torch.device(self.config['device_type'])
+
+ self._ensure_deterministic()
+
+ def _ensure_deterministic(self) -> None:
+ # Ask torch for deterministic kernels; warn_only=True means ops without a
+ # deterministic implementation emit a warning instead of raising.
+ torch.use_deterministic_algorithms(mode=True, warn_only=True)
+ # Seed all three RNGs (Python, NumPy, torch) with the configured seed.
+ random.seed(self.config['random_seed'])
+ np.random.seed(self.config['random_seed'])
+ torch.manual_seed(self.config['random_seed'])
+
+ def _get_dataloader(self):
+ # Returns (train_dataloader, valid_dataloader).
+ # Augmentation is enabled for the training split only.
+ train_dataset = QuickDrawDataset(
+ data_npz_path=self.config['train_npz'],
+ enable_data_aug=True,
+ file_lazy_load=self.config['file_lazy_load'],
+ return_cate_name=False,
+ # vis_dir='./tmp'
+ )
+ valid_dataset = QuickDrawDataset(
+ data_npz_path=self.config['valid_npz'],
+ enable_data_aug=False,
+ file_lazy_load=self.config['file_lazy_load'],
+ return_cate_name=False,
+ )
+
+ # Training batches are reshuffled every epoch and loaded by 4 workers.
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=self.config['batch_size'],
+ shuffle=True,
+ num_workers=4,
+ pin_memory=False, # pinned memory is not supported on MPS
+ persistent_workers=True
+ )
+ # Validation keeps dataset order and uses a single worker.
+ valid_dataloader = DataLoader(
+ valid_dataset,
+ batch_size=self.config['batch_size'],
+ shuffle=False,
+ num_workers=1,
+ pin_memory=False,
+ persistent_workers=True
+ )
+ return train_dataloader, valid_dataloader
+
+ def _get_model(self) -> torch.nn.Module:
+ # BabyCNN with one output per configured class and a fixed dropout of 0.3,
+ # moved to the configured device and left in training mode.
+ model = BabyCNN(
+ num_classes=self.config['num_of_class'],
+ dropout_p=0.3
+ ).to(self._device)
+ model.train()
+ return model
+
+ def _get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
+ # Supported `optimizer_name` values: 'adam' and 'sgd'.
+ # Raises ValueError for any other name.
+ if self.config['optimizer_name'] == 'adam':
+ optimizer = torch.optim.Adam(model.parameters(), lr=self.config['learning_rate'])
+ elif self.config['optimizer_name'] == 'sgd':
+ optimizer = torch.optim.SGD(model.parameters(), lr=self.config['learning_rate'])
+ else:
+ raise ValueError(f'Unknown optimizer name: {self.config["optimizer_name"]}')
+ return optimizer
+
+ def _get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+ # Base schedule: cosine annealing with warm restarts (T_0=10, T_mult=1,
+ # eta_min=1e-4; all three are hard-coded here, not read from the config).
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.0001)
+ # Optional linear warmup from 0.01x to 1x of the base LR over `warmup_epochs`
+ # scheduler steps, after which SequentialLR hands over to the cosine schedule.
+ if self.config['warmup_epochs'] > 0:
+ warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=self.config['warmup_epochs'])
+ scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, scheduler], milestones=[self.config['warmup_epochs']])
+ return scheduler
+
+ def _get_loss(self) -> torch.nn.modules.loss._Loss:
+ # Cross-entropy with label smoothing 0.1.
+ loss = torch.nn.CrossEntropyLoss(
+ label_smoothing=0.1
+ ).to(self._device)
+ return loss
+
+ def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]:
+ # Top-1 accuracy plus macro-averaged precision / recall / F1, all sized by
+ # the configured number of classes and placed on the training device.
+ metric_collection = MetricCollection([
+ Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1),
+ Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
+ Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
+ F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
+ ]).to(self._device)
+ # NOTE(review): run() currently discards this confusion matrix, and
+ # `threshold` presumably has no effect for task='multiclass' - confirm
+ # against the torchmetrics documentation.
+ confusion_matrix = ConfusionMatrix(
+ task='multiclass',
+ threshold=0.5,
+ num_classes=self.config['num_of_class'],
+ ).to(self._device)
+ return metric_collection, confusion_matrix
+
+ def run(self):
+ # Build every training component from the config.
+ train_dataloader, valid_dataloader = self._get_dataloader()
+ model = self._get_model()
+ optimizer = self._get_optimizer(model)
+ scheduler = self._get_scheduler(optimizer)
+ loss = self._get_loss()
+ # The confusion matrix is not used by the loop below.
+ metrics, _ = self._get_metrics()
+
+ # DVCLive writes metrics/plots under ./doc/exp/train, renders an HTML
+ # report and records the metric/plot paths in ./assets/dvc.yaml.
+ with Live(
+ dir='./doc/exp/train',
+ report='html',
+ dvcyaml='./assets/dvc.yaml',
+ exp_message=self.config['exp_msg']) as live:
+
+ for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'):
+ # ---- training phase ----
+ # One MetricCollection is shared by both phases, so it is reset
+ # before each of them.
+ metrics.reset()
+ model.train()
+ total_train_loss = 0.
+ for inputs, targets in train_dataloader:
+ inputs = inputs.to(self._device)
+ targets = targets.to(self._device)
+ optimizer.zero_grad()
+ outputs = model(inputs)
+ train_loss = loss(outputs, targets)
+ total_train_loss += train_loss.item()
+ train_loss.backward()
+ optimizer.step()
+ metrics.update(outputs, targets)
+ train_metrics = metrics.compute()
+ # Mean of the per-batch losses (divides by the number of batches).
+ avg_train_loss = total_train_loss / len(train_dataloader)
+
+ # ---- validation phase ----
+ metrics.reset()
+ model.eval()
+ total_valid_loss = 0.
+ with torch.no_grad():
+ for inputs, targets in valid_dataloader:
+ inputs = inputs.to(self._device)
+ targets = targets.to(self._device)
+ outputs = model(inputs)
+ valid_loss = loss(outputs, targets)
+ total_valid_loss += valid_loss.item()
+ metrics.update(outputs, targets)
+ valid_metrics = metrics.compute()
+ avg_valid_loss = total_valid_loss / len(valid_dataloader)
+
+ # Metric keys are the torchmetrics class names of the collection members.
+ live.log_metric('train/loss', avg_train_loss)
+ live.log_metric('train/accuracy', train_metrics['MulticlassAccuracy'].item())
+ live.log_metric('train/precision', train_metrics['MulticlassPrecision'].item())
+ live.log_metric('train/recall', train_metrics['MulticlassRecall'].item())
+ live.log_metric('train/f1', train_metrics['MulticlassF1Score'].item())
+ live.log_metric('valid/loss', avg_valid_loss)
+ live.log_metric('valid/accuracy', valid_metrics['MulticlassAccuracy'].item())
+ live.log_metric('valid/precision', valid_metrics['MulticlassPrecision'].item())
+ live.log_metric('valid/recall', valid_metrics['MulticlassRecall'].item())
+ live.log_metric('valid/f1', valid_metrics['MulticlassF1Score'].item())
+
+ # Advance the LR schedule and the DVCLive step counter together.
+ scheduler.step()
+ live.next_step()
+
+ # Serialises the whole module object (not just its state_dict).
+ # NOTE(review): loading it back presumably requires the BabyCNN class to
+ # be importable under the same module path - confirm with the consumer.
+ torch.save(model, './assets/model.pth')
+
+
+# Script entry point: train using the default config path of Train.__init__.
+if __name__ == '__main__':
+ Train().run()
\ No newline at end of file
diff --git a/quickdraw_bot/utils/dataset.py b/quickdraw_bot/utils/dataset.py
new file mode 100644
index 0000000..469ccf8
--- /dev/null
+++ b/quickdraw_bot/utils/dataset.py
@@ -0,0 +1,88 @@
+# dataset.py
+#
+# author: deng
+# date : 20260617
+
+from pathlib import Path
+
+import numpy as np
+import torch
+from torchvision.transforms import v2
+
+
+class QuickDrawDataset(torch.utils.data.Dataset):
+    """Map-style dataset backed by a processed QuickDraw ``.npz`` archive.
+
+    The archive is read for three index-aligned arrays: ``images``,
+    ``cate_names`` (byte strings, decoded to ``str``) and ``cate_ids``.
+
+    Args:
+        data_npz_path: Path of the ``.npz`` archive to read.
+        image_shape: Target ``(C, H, W)``; only ``(H, W)`` is used, as the
+            resize target of the transform pipeline.
+        enable_data_aug: Prepend the random augmentation pipeline when true.
+        file_lazy_load: Keep the images as a NumPy array and convert one
+            sample per ``__getitem__`` call, instead of converting the whole
+            array to a tensor up front.
+        return_cate_name: Also return the category name from ``__getitem__``.
+        vis_dir: Optional directory that receives a PNG of each transformed
+            sample the first time it is fetched. Created if it is missing.
+    """
+
+    def __init__(self,
+                 data_npz_path: str,
+                 image_shape: tuple[int, int, int] = (1, 28, 28),
+                 enable_data_aug: bool = False,
+                 file_lazy_load: bool = False,
+                 return_cate_name: bool = False,
+                 vis_dir: str | None = None,
+                 ) -> None:
+        super().__init__()
+        # Placeholders; filled in by _collect_data() / _set_data_transform().
+        self._images: torch.Tensor | np.ndarray = []
+        self._cate_names: list[str] = []
+        self._cate_ids: torch.Tensor = []
+        self._transform: v2.Compose | None = None
+
+        self._enable_data_aug = enable_data_aug
+        self._data_npz_path = data_npz_path
+        self._image_shape = image_shape
+        self._file_lazy_load = file_lazy_load
+        self._return_cate_name = return_cate_name
+        self._vis_dir = Path(vis_dir) if vis_dir is not None else None
+        if self._vis_dir is not None:
+            # Saving a preview into a missing directory would fail on the
+            # first __getitem__ call, so make sure it exists up front.
+            self._vis_dir.mkdir(parents=True, exist_ok=True)
+
+        self._set_data_transform()
+        self._collect_data()
+
+    def _set_data_transform(self) -> None:
+        """(Re)build the transform pipeline from the current augmentation flag."""
+        aug_pipeline = []
+        if self._enable_data_aug:
+            aug_pipeline = [
+                v2.RandomHorizontalFlip(p=0.2),
+                v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5),
+                v2.RandomPerspective(distortion_scale=0.15, p=0.2),
+                v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2),
+                v2.RandomErasing(p=0.2, scale=(0.02, 0.2))
+            ]
+        # Augmentations (if any) run first; every sample is then resized to
+        # (H, W) and converted to float32 with value scaling.
+        self._transform = v2.Compose([
+            *aug_pipeline,
+            v2.Resize(self._image_shape[1:]),
+            v2.ToDtype(torch.float32, scale=True),
+        ])
+
+    def _collect_data(self) -> None:
+        """Read names, ids and images from the archive into memory."""
+        # np.load ignores mmap_mode for .npz archives: each member is read in
+        # full when it is indexed. The archive is therefore always opened in a
+        # context manager, so the dataset holds no open file handle (which
+        # also keeps it picklable for DataLoader worker processes).
+        with np.load(self._data_npz_path) as npz_file:
+            self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']]
+            # Cast in NumPy first: torch.from_numpy rejects unsigned 16-bit
+            # arrays on torch versions without uint16 support.
+            self._cate_ids = torch.from_numpy(npz_file['cate_ids'].astype(np.int64))
+            images = npz_file['images']
+
+        if self._file_lazy_load:
+            # Keep the NumPy array; samples become tensors one at a time.
+            self._images = images
+        else:
+            self._images = torch.from_numpy(images)
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, str]:
+        """Return ``(image, cate_id)``, plus the category name when requested."""
+        if self._file_lazy_load:
+            x = torch.from_numpy(self._images[index])
+        else:
+            x = self._images[index]
+        x = self._transform(x)
+        y = self._cate_ids[index]
+
+        if self._vis_dir is not None:
+            # Each sample is dumped at most once; an existing file is kept.
+            vis_path = self._vis_dir / f'{index:05d}_{self._cate_names[index]}.png'
+            if not vis_path.exists():
+                v2.ToPILImage()(x).save(vis_path)
+
+        if self._return_cate_name:
+            return x, y, self._cate_names[index]
+        return x, y
+
+    def set_data_aug(self, enable_data_aug: bool) -> None:
+        """Switch augmentation on or off and rebuild the transform pipeline."""
+        self._enable_data_aug = enable_data_aug
+        self._set_data_transform()
diff --git a/quickdraw_bot/utils/model.py b/quickdraw_bot/utils/model.py
new file mode 100644
index 0000000..849a606
--- /dev/null
+++ b/quickdraw_bot/utils/model.py
@@ -0,0 +1,53 @@
+# model.py
+#
+# author: deng
+# date : 20260617
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BabyCNN(nn.Module):
+    """Small two-stage convolutional classifier for single-channel images.
+
+    Each stage applies a 3x3 convolution (padding 1), batch normalisation,
+    ReLU and a 2x2 max pool. The flattened features then pass through a
+    128-unit hidden layer with dropout and a final linear layer that emits
+    one logit per class. The hidden layer expects 64 * 7 * 7 features, which
+    matches a 28x28 input halved twice by the pooling.
+    """
+
+    def __init__(self,
+                 num_classes: int = 10,
+                 dropout_p: float = 0.5) -> None:
+        super().__init__()
+
+        # Stage 1: 1 -> 32 channels, 28x28 -> 14x14 after pooling
+        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
+        self.bn1 = nn.BatchNorm2d(32)
+
+        # Stage 2: 32 -> 64 channels, 14x14 -> 7x7 after pooling
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm2d(64)
+
+        # Shared by both stages
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.dropout = nn.Dropout(p=dropout_p)
+
+        # Classifier head
+        self.fc1 = nn.Linear(64 * 7 * 7, 128)
+        self.fc2 = nn.Linear(128, num_classes)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        """Kaiming-normal weights for conv/linear layers; unit batch-norm scale."""
+        for layer in self.modules():
+            if isinstance(layer, nn.BatchNorm2d):
+                nn.init.ones_(layer.weight)
+                nn.init.zeros_(layer.bias)
+                continue
+            if isinstance(layer, nn.Linear):
+                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
+                nn.init.zeros_(layer.bias)
+                continue
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Map a batch of images to per-class logits."""
+        stage1 = self.pool(F.relu(self.bn1(self.conv1(x))))
+        stage2 = self.pool(F.relu(self.bn2(self.conv2(stage1))))
+        # Collapse everything except the batch dimension.
+        flat = stage2.view(stage2.size(0), -1)
+        hidden = self.dropout(F.relu(self.fc1(flat)))
+        return self.fc2(hidden)
\ No newline at end of file