[exp] init train
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@ -12,5 +12,11 @@ wheels/
|
|||||||
# Dataset
|
# Dataset
|
||||||
quickdraw_bot/data
|
quickdraw_bot/data
|
||||||
|
|
||||||
|
# Temp files
|
||||||
|
quickdraw_bot/tmp
|
||||||
|
|
||||||
# DVC
|
# DVC
|
||||||
dvc/config.local
|
dvc/config.local
|
||||||
|
|
||||||
|
# .DS_Store
|
||||||
|
.DS_Store
|
||||||
1
quickdraw_bot/assets/.gitignore
vendored
Normal file
1
quickdraw_bot/assets/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/model.pth
|
||||||
@ -7,7 +7,20 @@ prepare:
|
|||||||
test: 0.1
|
test: 0.1
|
||||||
random_seed: 1
|
random_seed: 1
|
||||||
train:
|
train:
|
||||||
|
device_type: mps
|
||||||
|
train_npz: ./data/processed/train.npz
|
||||||
|
valid_npz: ./data/processed/valid.npz
|
||||||
|
batch_size: 256
|
||||||
|
num_of_class: 20
|
||||||
|
optimizer_name: sgd # sgd, adam
|
||||||
|
learning_rate: 0.001
|
||||||
|
warmup_epochs: 5
|
||||||
|
num_of_epochs: 30
|
||||||
|
file_lazy_load: false
|
||||||
random_seed: 1
|
random_seed: 1
|
||||||
|
exp_msg: init train
|
||||||
eval:
|
eval:
|
||||||
|
test_npz: ./data/processed/test.npz
|
||||||
random_seed: 1
|
random_seed: 1
|
||||||
deploy:
|
deploy:
|
||||||
|
random_seed: 1
|
||||||
5
quickdraw_bot/assets/dvc.yaml
Normal file
5
quickdraw_bot/assets/dvc.yaml
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
metrics:
|
||||||
|
- ../doc/exp/train/metrics.json
|
||||||
|
plots:
|
||||||
|
- ../doc/exp/train/plots/metrics:
|
||||||
|
x: step
|
||||||
5
quickdraw_bot/assets/model.pth.dvc
Normal file
5
quickdraw_bot/assets/model.pth.dvc
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
outs:
|
||||||
|
- md5: 263f47ef298fee74aed6acc3a316e7ad
|
||||||
|
size: 1701245
|
||||||
|
hash: md5
|
||||||
|
path: model.pth
|
||||||
17
quickdraw_bot/doc/exp/train/metrics.json
Normal file
17
quickdraw_bot/doc/exp/train/metrics.json
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"train": {
|
||||||
|
"loss": 1.992799331665039,
|
||||||
|
"accuracy": 0.4883750081062317,
|
||||||
|
"precision": 0.48363304138183594,
|
||||||
|
"recall": 0.4885570704936981,
|
||||||
|
"f1": 0.4849509298801422
|
||||||
|
},
|
||||||
|
"valid": {
|
||||||
|
"loss": 1.4377805100211614,
|
||||||
|
"accuracy": 0.7260000109672546,
|
||||||
|
"precision": 0.723875880241394,
|
||||||
|
"recall": 0.7256932258605957,
|
||||||
|
"f1": 0.7193899154663086
|
||||||
|
},
|
||||||
|
"step": 29
|
||||||
|
}
|
||||||
31
quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv
Normal file
31
quickdraw_bot/doc/exp/train/plots/metrics/train/accuracy.tsv
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
step accuracy
|
||||||
|
0 0.04570624977350235
|
||||||
|
1 0.08645624667406082
|
||||||
|
2 0.15463125705718994
|
||||||
|
3 0.20809374749660492
|
||||||
|
4 0.2542562484741211
|
||||||
|
5 0.29279375076293945
|
||||||
|
6 0.3235749900341034
|
||||||
|
7 0.346756249666214
|
||||||
|
8 0.363993763923645
|
||||||
|
9 0.3763374984264374
|
||||||
|
10 0.3868750035762787
|
||||||
|
11 0.39285001158714294
|
||||||
|
12 0.400112509727478
|
||||||
|
13 0.4029250144958496
|
||||||
|
14 0.403425008058548
|
||||||
|
15 0.4108937382698059
|
||||||
|
16 0.42086875438690186
|
||||||
|
17 0.43145623803138733
|
||||||
|
18 0.44205623865127563
|
||||||
|
19 0.44743749499320984
|
||||||
|
20 0.45317500829696655
|
||||||
|
21 0.4583125114440918
|
||||||
|
22 0.45945000648498535
|
||||||
|
23 0.4590874910354614
|
||||||
|
24 0.462799996137619
|
||||||
|
25 0.46361875534057617
|
||||||
|
26 0.47360000014305115
|
||||||
|
27 0.47944375872612
|
||||||
|
28 0.4831624925136566
|
||||||
|
29 0.4883750081062317
|
||||||
|
31
quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv
Normal file
31
quickdraw_bot/doc/exp/train/plots/metrics/train/f1.tsv
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
step f1
|
||||||
|
0 0.03369621932506561
|
||||||
|
1 0.08265631645917892
|
||||||
|
2 0.14480535686016083
|
||||||
|
3 0.1954474151134491
|
||||||
|
4 0.24400727450847626
|
||||||
|
5 0.2852896749973297
|
||||||
|
6 0.3165817856788635
|
||||||
|
7 0.34051138162612915
|
||||||
|
8 0.3579234480857849
|
||||||
|
9 0.3706662654876709
|
||||||
|
10 0.3813643753528595
|
||||||
|
11 0.38740092515945435
|
||||||
|
12 0.3947397768497467
|
||||||
|
13 0.3978673219680786
|
||||||
|
14 0.3983455300331116
|
||||||
|
15 0.4061855971813202
|
||||||
|
16 0.41634228825569153
|
||||||
|
17 0.4270275831222534
|
||||||
|
18 0.4378680884838104
|
||||||
|
19 0.4433649480342865
|
||||||
|
20 0.44917452335357666
|
||||||
|
21 0.45438480377197266
|
||||||
|
22 0.45568108558654785
|
||||||
|
23 0.45533287525177
|
||||||
|
24 0.45924824476242065
|
||||||
|
25 0.45992523431777954
|
||||||
|
26 0.4698502719402313
|
||||||
|
27 0.47596246004104614
|
||||||
|
28 0.4798404574394226
|
||||||
|
29 0.4849509298801422
|
||||||
|
31
quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv
Normal file
31
quickdraw_bot/doc/exp/train/plots/metrics/train/loss.tsv
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
step loss
|
||||||
|
0 4.0223345439910885
|
||||||
|
1 3.1446908485412597
|
||||||
|
2 2.845902690887451
|
||||||
|
3 2.7091702819824217
|
||||||
|
4 2.5876311504364016
|
||||||
|
5 2.4898329914093016
|
||||||
|
6 2.4078801975250244
|
||||||
|
7 2.3469735752105714
|
||||||
|
8 2.3044276081085204
|
||||||
|
9 2.2692713054656983
|
||||||
|
10 2.2433441226959228
|
||||||
|
11 2.2251542106628417
|
||||||
|
12 2.2075239395141604
|
||||||
|
13 2.199652731704712
|
||||||
|
14 2.2003354278564453
|
||||||
|
15 2.1802532527923586
|
||||||
|
16 2.155798588562012
|
||||||
|
17 2.127256035041809
|
||||||
|
18 2.1059615146636963
|
||||||
|
19 2.0936844367980956
|
||||||
|
20 2.077068444442749
|
||||||
|
21 2.064746246147156
|
||||||
|
22 2.0619221328735353
|
||||||
|
23 2.058998599433899
|
||||||
|
24 2.0547973834991455
|
||||||
|
25 2.0490142166137697
|
||||||
|
26 2.0298474113464358
|
||||||
|
27 2.0157182762145998
|
||||||
|
28 2.006533228492737
|
||||||
|
29 1.992799331665039
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
step precision
|
||||||
|
0 0.055038902908563614
|
||||||
|
1 0.08688722550868988
|
||||||
|
2 0.14544154703617096
|
||||||
|
3 0.19598175585269928
|
||||||
|
4 0.24326808750629425
|
||||||
|
5 0.28354209661483765
|
||||||
|
6 0.3143269121646881
|
||||||
|
7 0.33853644132614136
|
||||||
|
8 0.3557613492012024
|
||||||
|
9 0.3688707649707794
|
||||||
|
10 0.37962812185287476
|
||||||
|
11 0.3855380415916443
|
||||||
|
12 0.3929717540740967
|
||||||
|
13 0.3963885009288788
|
||||||
|
14 0.39665019512176514
|
||||||
|
15 0.40454378724098206
|
||||||
|
16 0.41489875316619873
|
||||||
|
17 0.42552798986434937
|
||||||
|
18 0.43634524941444397
|
||||||
|
19 0.44186848402023315
|
||||||
|
20 0.4476196765899658
|
||||||
|
21 0.4529317021369934
|
||||||
|
22 0.4541561007499695
|
||||||
|
23 0.45390117168426514
|
||||||
|
24 0.45794767141342163
|
||||||
|
25 0.45853471755981445
|
||||||
|
26 0.46830785274505615
|
||||||
|
27 0.4746767580509186
|
||||||
|
28 0.47852134704589844
|
||||||
|
29 0.48363304138183594
|
||||||
|
31
quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv
Normal file
31
quickdraw_bot/doc/exp/train/plots/metrics/train/recall.tsv
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
step recall
|
||||||
|
0 0.04566790908575058
|
||||||
|
1 0.08645831048488617
|
||||||
|
2 0.1547394096851349
|
||||||
|
3 0.2082621306180954
|
||||||
|
4 0.2544386386871338
|
||||||
|
5 0.29297617077827454
|
||||||
|
6 0.3237552046775818
|
||||||
|
7 0.346926212310791
|
||||||
|
8 0.364177942276001
|
||||||
|
9 0.3765082359313965
|
||||||
|
10 0.3870483636856079
|
||||||
|
11 0.39303654432296753
|
||||||
|
12 0.4002862572669983
|
||||||
|
13 0.403104305267334
|
||||||
|
14 0.4036043882369995
|
||||||
|
15 0.4110710322856903
|
||||||
|
16 0.42104363441467285
|
||||||
|
17 0.4316273629665375
|
||||||
|
18 0.44224199652671814
|
||||||
|
19 0.44761422276496887
|
||||||
|
20 0.4533519148826599
|
||||||
|
21 0.4584938883781433
|
||||||
|
22 0.45962250232696533
|
||||||
|
23 0.45926111936569214
|
||||||
|
24 0.46298152208328247
|
||||||
|
25 0.4637938141822815
|
||||||
|
26 0.47378331422805786
|
||||||
|
27 0.4796329736709595
|
||||||
|
28 0.4833483397960663
|
||||||
|
29 0.4885570704936981
|
||||||
|
31
quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv
Normal file
31
quickdraw_bot/doc/exp/train/plots/metrics/valid/accuracy.tsv
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
step accuracy
|
||||||
|
0 0.04039999842643738
|
||||||
|
1 0.1996999979019165
|
||||||
|
2 0.3260500133037567
|
||||||
|
3 0.4186500012874603
|
||||||
|
4 0.49950000643730164
|
||||||
|
5 0.5444499850273132
|
||||||
|
6 0.5720000267028809
|
||||||
|
7 0.5917500257492065
|
||||||
|
8 0.6093500256538391
|
||||||
|
9 0.6218000054359436
|
||||||
|
10 0.6304000020027161
|
||||||
|
11 0.6385499835014343
|
||||||
|
12 0.6404500007629395
|
||||||
|
13 0.6446499824523926
|
||||||
|
14 0.6463500261306763
|
||||||
|
15 0.6567999720573425
|
||||||
|
16 0.6672000288963318
|
||||||
|
17 0.6779500246047974
|
||||||
|
18 0.6841999888420105
|
||||||
|
19 0.6896499991416931
|
||||||
|
20 0.6960499882698059
|
||||||
|
21 0.6970999836921692
|
||||||
|
22 0.7006000280380249
|
||||||
|
23 0.70169997215271
|
||||||
|
24 0.7010999917984009
|
||||||
|
25 0.7071499824523926
|
||||||
|
26 0.7135000228881836
|
||||||
|
27 0.7184500098228455
|
||||||
|
28 0.7211499810218811
|
||||||
|
29 0.7260000109672546
|
||||||
|
31
quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv
Normal file
31
quickdraw_bot/doc/exp/train/plots/metrics/valid/f1.tsv
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
step f1
|
||||||
|
0 0.026379385963082314
|
||||||
|
1 0.1792948693037033
|
||||||
|
2 0.2925271987915039
|
||||||
|
3 0.3919405937194824
|
||||||
|
4 0.47940874099731445
|
||||||
|
5 0.5274306535720825
|
||||||
|
6 0.556631863117218
|
||||||
|
7 0.576585054397583
|
||||||
|
8 0.5947144031524658
|
||||||
|
9 0.6082352995872498
|
||||||
|
10 0.6178301572799683
|
||||||
|
11 0.6262512803077698
|
||||||
|
12 0.6278425455093384
|
||||||
|
13 0.6331325769424438
|
||||||
|
14 0.6344878673553467
|
||||||
|
15 0.6453059315681458
|
||||||
|
16 0.6570700407028198
|
||||||
|
17 0.6685804128646851
|
||||||
|
18 0.6745381355285645
|
||||||
|
19 0.6804145574569702
|
||||||
|
20 0.6874032020568848
|
||||||
|
21 0.6882768273353577
|
||||||
|
22 0.6920892596244812
|
||||||
|
23 0.693403959274292
|
||||||
|
24 0.6925287246704102
|
||||||
|
25 0.6985740661621094
|
||||||
|
26 0.7053770422935486
|
||||||
|
27 0.7106728553771973
|
||||||
|
28 0.7137280106544495
|
||||||
|
29 0.7193899154663086
|
||||||
|
31
quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv
Normal file
31
quickdraw_bot/doc/exp/train/plots/metrics/valid/loss.tsv
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
step loss
|
||||||
|
0 3.590050941781153
|
||||||
|
1 2.7546746489367906
|
||||||
|
2 2.5154529040372826
|
||||||
|
3 2.2843858411040485
|
||||||
|
4 2.0942421291447895
|
||||||
|
5 1.9563061029096194
|
||||||
|
6 1.8598378127134298
|
||||||
|
7 1.7950012200995336
|
||||||
|
8 1.7452865039245993
|
||||||
|
9 1.7100401875338977
|
||||||
|
10 1.6846234753162046
|
||||||
|
11 1.6670298757432382
|
||||||
|
12 1.6579805748372138
|
||||||
|
13 1.6494467530069472
|
||||||
|
14 1.6449626789817327
|
||||||
|
15 1.6121938424774362
|
||||||
|
16 1.5864867349214191
|
||||||
|
17 1.5629751878448679
|
||||||
|
18 1.5430825224405602
|
||||||
|
19 1.5310050943229772
|
||||||
|
20 1.5175003311302089
|
||||||
|
21 1.5117049654827843
|
||||||
|
22 1.5065134658089168
|
||||||
|
23 1.5018350715878643
|
||||||
|
24 1.5016868627524074
|
||||||
|
25 1.4859640613386902
|
||||||
|
26 1.4698024339313749
|
||||||
|
27 1.4584274548518508
|
||||||
|
28 1.4490519852577886
|
||||||
|
29 1.4377805100211614
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
step precision
|
||||||
|
0 0.07338898628950119
|
||||||
|
1 0.18911437690258026
|
||||||
|
2 0.32449209690093994
|
||||||
|
3 0.4227263331413269
|
||||||
|
4 0.49964380264282227
|
||||||
|
5 0.5427083373069763
|
||||||
|
6 0.5689945220947266
|
||||||
|
7 0.5907210111618042
|
||||||
|
8 0.6088747978210449
|
||||||
|
9 0.620964527130127
|
||||||
|
10 0.6287857294082642
|
||||||
|
11 0.6372801065444946
|
||||||
|
12 0.6391931772232056
|
||||||
|
13 0.643118143081665
|
||||||
|
14 0.6445475816726685
|
||||||
|
15 0.655581533908844
|
||||||
|
16 0.6666973829269409
|
||||||
|
17 0.6765724420547485
|
||||||
|
18 0.6822240352630615
|
||||||
|
19 0.6881765127182007
|
||||||
|
20 0.6948975920677185
|
||||||
|
21 0.695229172706604
|
||||||
|
22 0.6998213529586792
|
||||||
|
23 0.7005149126052856
|
||||||
|
24 0.6999821662902832
|
||||||
|
25 0.7058628797531128
|
||||||
|
26 0.7112910151481628
|
||||||
|
27 0.7165747880935669
|
||||||
|
28 0.7200020551681519
|
||||||
|
29 0.723875880241394
|
||||||
|
31
quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv
Normal file
31
quickdraw_bot/doc/exp/train/plots/metrics/valid/recall.tsv
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
step recall
|
||||||
|
0 0.04072221741080284
|
||||||
|
1 0.19861623644828796
|
||||||
|
2 0.32381999492645264
|
||||||
|
3 0.41670864820480347
|
||||||
|
4 0.4983140230178833
|
||||||
|
5 0.5435963273048401
|
||||||
|
6 0.5713088512420654
|
||||||
|
7 0.5910568237304688
|
||||||
|
8 0.608674168586731
|
||||||
|
9 0.6212138533592224
|
||||||
|
10 0.6297356486320496
|
||||||
|
11 0.6379297971725464
|
||||||
|
12 0.6398895382881165
|
||||||
|
13 0.6441172361373901
|
||||||
|
14 0.6458127498626709
|
||||||
|
15 0.6563705205917358
|
||||||
|
16 0.6666883230209351
|
||||||
|
17 0.6775047779083252
|
||||||
|
18 0.6838763952255249
|
||||||
|
19 0.6894745826721191
|
||||||
|
20 0.6957231163978577
|
||||||
|
21 0.6967899203300476
|
||||||
|
22 0.7002253532409668
|
||||||
|
23 0.7013353109359741
|
||||||
|
24 0.7007752656936646
|
||||||
|
25 0.706775963306427
|
||||||
|
26 0.7130882740020752
|
||||||
|
27 0.7181775569915771
|
||||||
|
28 0.7208318114280701
|
||||||
|
29 0.7256932258605957
|
||||||
|
114
quickdraw_bot/doc/exp/train/report.html
Normal file
114
quickdraw_bot/doc/exp/train/report.html
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta http-equiv="refresh" content="5">
|
||||||
|
<title>DVC Plot</title>
|
||||||
|
|
||||||
|
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/vega@5.20.2"></script>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/vega-lite@5.2.0"></script>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/vega-embed@6.18.2"></script>
|
||||||
|
|
||||||
|
<style>
|
||||||
|
table {
|
||||||
|
border-spacing: 15px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
|
||||||
|
<div id="metrics_json" style="text-align: center; padding: 10x">
|
||||||
|
<p>metrics_json</p>
|
||||||
|
<div style="display: flex;justify-content: center;">
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th style="text-align: right;"> train.loss</th><th style="text-align: right;"> train.accuracy</th><th style="text-align: right;"> train.precision</th><th style="text-align: right;"> train.recall</th><th style="text-align: right;"> train.f1</th><th style="text-align: right;"> valid.loss</th><th style="text-align: right;"> valid.accuracy</th><th style="text-align: right;"> valid.precision</th><th style="text-align: right;"> valid.recall</th><th style="text-align: right;"> valid.f1</th><th style="text-align: right;"> step</th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr><td style="text-align: right;"> 1.9928</td><td style="text-align: right;"> 0.488375</td><td style="text-align: right;"> 0.483633</td><td style="text-align: right;"> 0.488557</td><td style="text-align: right;"> 0.484951</td><td style="text-align: right;"> 1.43778</td><td style="text-align: right;"> 0.726</td><td style="text-align: right;"> 0.723876</td><td style="text-align: right;"> 0.725693</td><td style="text-align: right;"> 0.71939</td><td style="text-align: right;"> 29</td></tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id = "static_valid_f1">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "f1": "0.026379385963082314", "rev": "workspace"}, {"step": "1", "f1": "0.1792948693037033", "rev": "workspace"}, {"step": "2", "f1": "0.2925271987915039", "rev": "workspace"}, {"step": "3", "f1": "0.3919405937194824", "rev": "workspace"}, {"step": "4", "f1": "0.47940874099731445", "rev": "workspace"}, {"step": "5", "f1": "0.5274306535720825", "rev": "workspace"}, {"step": "6", "f1": "0.556631863117218", "rev": "workspace"}, {"step": "7", "f1": "0.576585054397583", "rev": "workspace"}, {"step": "8", "f1": "0.5947144031524658", "rev": "workspace"}, {"step": "9", "f1": "0.6082352995872498", "rev": "workspace"}, {"step": "10", "f1": "0.6178301572799683", "rev": "workspace"}, {"step": "11", "f1": "0.6262512803077698", "rev": "workspace"}, {"step": "12", "f1": "0.6278425455093384", "rev": "workspace"}, {"step": "13", "f1": "0.6331325769424438", "rev": "workspace"}, {"step": "14", "f1": "0.6344878673553467", "rev": "workspace"}, {"step": "15", "f1": "0.6453059315681458", "rev": "workspace"}, {"step": "16", "f1": "0.6570700407028198", "rev": "workspace"}, {"step": "17", "f1": "0.6685804128646851", "rev": "workspace"}, {"step": "18", "f1": "0.6745381355285645", "rev": "workspace"}, {"step": "19", "f1": "0.6804145574569702", "rev": "workspace"}, {"step": "20", "f1": "0.6874032020568848", "rev": "workspace"}, {"step": "21", "f1": "0.6882768273353577", "rev": "workspace"}, {"step": "22", "f1": "0.6920892596244812", "rev": "workspace"}, {"step": "23", "f1": "0.693403959274292", "rev": "workspace"}, {"step": "24", "f1": "0.6925287246704102", "rev": "workspace"}, {"step": "25", "f1": "0.6985740661621094", "rev": "workspace"}, {"step": "26", "f1": "0.7053770422935486", "rev": "workspace"}, {"step": "27", "f1": "0.7106728553771973", "rev": "workspace"}, {"step": "28", "f1": "0.7137280106544495", "rev": "workspace"}, {"step": "29", "f1": "0.7193899154663086", "rev": "workspace"}]}, "title": {"text": "valid/f1", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "f1", "type": "quantitative", "title": "f1", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "f1", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "f1", "type": "quantitative", "title": "f1", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "f1", "type": "quantitative", "title": "f1", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "f1", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_valid_f1', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_valid_loss">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "loss": "3.590050941781153", "rev": "workspace"}, {"step": "1", "loss": "2.7546746489367906", "rev": "workspace"}, {"step": "2", "loss": "2.5154529040372826", "rev": "workspace"}, {"step": "3", "loss": "2.2843858411040485", "rev": "workspace"}, {"step": "4", "loss": "2.0942421291447895", "rev": "workspace"}, {"step": "5", "loss": "1.9563061029096194", "rev": "workspace"}, {"step": "6", "loss": "1.8598378127134298", "rev": "workspace"}, {"step": "7", "loss": "1.7950012200995336", "rev": "workspace"}, {"step": "8", "loss": "1.7452865039245993", "rev": "workspace"}, {"step": "9", "loss": "1.7100401875338977", "rev": "workspace"}, {"step": "10", "loss": "1.6846234753162046", "rev": "workspace"}, {"step": "11", "loss": "1.6670298757432382", "rev": "workspace"}, {"step": "12", "loss": "1.6579805748372138", "rev": "workspace"}, {"step": "13", "loss": "1.6494467530069472", "rev": "workspace"}, {"step": "14", "loss": "1.6449626789817327", "rev": "workspace"}, {"step": "15", "loss": "1.6121938424774362", "rev": "workspace"}, {"step": "16", "loss": "1.5864867349214191", "rev": "workspace"}, {"step": "17", "loss": "1.5629751878448679", "rev": "workspace"}, {"step": "18", "loss": "1.5430825224405602", "rev": "workspace"}, {"step": "19", "loss": "1.5310050943229772", "rev": "workspace"}, {"step": "20", "loss": "1.5175003311302089", "rev": "workspace"}, {"step": "21", "loss": "1.5117049654827843", "rev": "workspace"}, {"step": "22", "loss": "1.5065134658089168", "rev": "workspace"}, {"step": "23", "loss": "1.5018350715878643", "rev": "workspace"}, {"step": "24", "loss": "1.5016868627524074", "rev": "workspace"}, {"step": "25", "loss": "1.4859640613386902", "rev": "workspace"}, {"step": "26", "loss": "1.4698024339313749", "rev": "workspace"}, {"step": "27", "loss": "1.4584274548518508", "rev": "workspace"}, {"step": "28", "loss": "1.4490519852577886", "rev": "workspace"}, {"step": "29", "loss": "1.4377805100211614", "rev": "workspace"}]}, "title": {"text": "valid/loss", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "loss", "type": "quantitative", "title": "loss", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "loss", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "loss", "type": "quantitative", "title": "loss", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "loss", "type": "quantitative", "title": "loss", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "loss", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_valid_loss', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_valid_accuracy">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "accuracy": "0.04039999842643738", "rev": "workspace"}, {"step": "1", "accuracy": "0.1996999979019165", "rev": "workspace"}, {"step": "2", "accuracy": "0.3260500133037567", "rev": "workspace"}, {"step": "3", "accuracy": "0.4186500012874603", "rev": "workspace"}, {"step": "4", "accuracy": "0.49950000643730164", "rev": "workspace"}, {"step": "5", "accuracy": "0.5444499850273132", "rev": "workspace"}, {"step": "6", "accuracy": "0.5720000267028809", "rev": "workspace"}, {"step": "7", "accuracy": "0.5917500257492065", "rev": "workspace"}, {"step": "8", "accuracy": "0.6093500256538391", "rev": "workspace"}, {"step": "9", "accuracy": "0.6218000054359436", "rev": "workspace"}, {"step": "10", "accuracy": "0.6304000020027161", "rev": "workspace"}, {"step": "11", "accuracy": "0.6385499835014343", "rev": "workspace"}, {"step": "12", "accuracy": "0.6404500007629395", "rev": "workspace"}, {"step": "13", "accuracy": "0.6446499824523926", "rev": "workspace"}, {"step": "14", "accuracy": "0.6463500261306763", "rev": "workspace"}, {"step": "15", "accuracy": "0.6567999720573425", "rev": "workspace"}, {"step": "16", "accuracy": "0.6672000288963318", "rev": "workspace"}, {"step": "17", "accuracy": "0.6779500246047974", "rev": "workspace"}, {"step": "18", "accuracy": "0.6841999888420105", "rev": "workspace"}, {"step": "19", "accuracy": "0.6896499991416931", "rev": "workspace"}, {"step": "20", "accuracy": "0.6960499882698059", "rev": "workspace"}, {"step": "21", "accuracy": "0.6970999836921692", "rev": "workspace"}, {"step": "22", "accuracy": "0.7006000280380249", "rev": "workspace"}, {"step": "23", "accuracy": "0.70169997215271", "rev": "workspace"}, {"step": "24", "accuracy": "0.7010999917984009", "rev": "workspace"}, {"step": "25", "accuracy": "0.7071499824523926", "rev": "workspace"}, {"step": "26", "accuracy": "0.7135000228881836", "rev": "workspace"}, {"step": "27", "accuracy": "0.7184500098228455", "rev": "workspace"}, {"step": "28", "accuracy": "0.7211499810218811", "rev": "workspace"}, {"step": "29", "accuracy": "0.7260000109672546", "rev": "workspace"}]}, "title": {"text": "valid/accuracy", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "accuracy", "type": "quantitative", "title": "accuracy", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "accuracy", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "accuracy", "type": "quantitative", "title": "accuracy", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "accuracy", "type": "quantitative", "title": "accuracy", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "accuracy", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_valid_accuracy', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_valid_recall">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "recall": "0.04072221741080284", "rev": "workspace"}, {"step": "1", "recall": "0.19861623644828796", "rev": "workspace"}, {"step": "2", "recall": "0.32381999492645264", "rev": "workspace"}, {"step": "3", "recall": "0.41670864820480347", "rev": "workspace"}, {"step": "4", "recall": "0.4983140230178833", "rev": "workspace"}, {"step": "5", "recall": "0.5435963273048401", "rev": "workspace"}, {"step": "6", "recall": "0.5713088512420654", "rev": "workspace"}, {"step": "7", "recall": "0.5910568237304688", "rev": "workspace"}, {"step": "8", "recall": "0.608674168586731", "rev": "workspace"}, {"step": "9", "recall": "0.6212138533592224", "rev": "workspace"}, {"step": "10", "recall": "0.6297356486320496", "rev": "workspace"}, {"step": "11", "recall": "0.6379297971725464", "rev": "workspace"}, {"step": "12", "recall": "0.6398895382881165", "rev": "workspace"}, {"step": "13", "recall": "0.6441172361373901", "rev": "workspace"}, {"step": "14", "recall": "0.6458127498626709", "rev": "workspace"}, {"step": "15", "recall": "0.6563705205917358", "rev": "workspace"}, {"step": "16", "recall": "0.6666883230209351", "rev": "workspace"}, {"step": "17", "recall": "0.6775047779083252", "rev": "workspace"}, {"step": "18", "recall": "0.6838763952255249", "rev": "workspace"}, {"step": "19", "recall": "0.6894745826721191", "rev": "workspace"}, {"step": "20", "recall": "0.6957231163978577", "rev": "workspace"}, {"step": "21", "recall": "0.6967899203300476", "rev": "workspace"}, {"step": "22", "recall": "0.7002253532409668", "rev": "workspace"}, {"step": "23", "recall": "0.7013353109359741", "rev": "workspace"}, {"step": "24", "recall": "0.7007752656936646", "rev": "workspace"}, {"step": "25", "recall": "0.706775963306427", "rev": "workspace"}, {"step": "26", "recall": "0.7130882740020752", "rev": "workspace"}, {"step": "27", "recall": "0.7181775569915771", "rev": "workspace"}, {"step": "28", "recall": "0.7208318114280701", "rev": "workspace"}, {"step": "29", "recall": "0.7256932258605957", "rev": "workspace"}]}, "title": {"text": "valid/recall", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "recall", "type": "quantitative", "title": "recall", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "recall", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "recall", "type": "quantitative", "title": "recall", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "recall", "type": "quantitative", "title": "recall", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "recall", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_valid_recall', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_valid_precision">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "precision": "0.07338898628950119", "rev": "workspace"}, {"step": "1", "precision": "0.18911437690258026", "rev": "workspace"}, {"step": "2", "precision": "0.32449209690093994", "rev": "workspace"}, {"step": "3", "precision": "0.4227263331413269", "rev": "workspace"}, {"step": "4", "precision": "0.49964380264282227", "rev": "workspace"}, {"step": "5", "precision": "0.5427083373069763", "rev": "workspace"}, {"step": "6", "precision": "0.5689945220947266", "rev": "workspace"}, {"step": "7", "precision": "0.5907210111618042", "rev": "workspace"}, {"step": "8", "precision": "0.6088747978210449", "rev": "workspace"}, {"step": "9", "precision": "0.620964527130127", "rev": "workspace"}, {"step": "10", "precision": "0.6287857294082642", "rev": "workspace"}, {"step": "11", "precision": "0.6372801065444946", "rev": "workspace"}, {"step": "12", "precision": "0.6391931772232056", "rev": "workspace"}, {"step": "13", "precision": "0.643118143081665", "rev": "workspace"}, {"step": "14", "precision": "0.6445475816726685", "rev": "workspace"}, {"step": "15", "precision": "0.655581533908844", "rev": "workspace"}, {"step": "16", "precision": "0.6666973829269409", "rev": "workspace"}, {"step": "17", "precision": "0.6765724420547485", "rev": "workspace"}, {"step": "18", "precision": "0.6822240352630615", "rev": "workspace"}, {"step": "19", "precision": "0.6881765127182007", "rev": "workspace"}, {"step": "20", "precision": "0.6948975920677185", "rev": "workspace"}, {"step": "21", "precision": "0.695229172706604", "rev": "workspace"}, {"step": "22", "precision": "0.6998213529586792", "rev": "workspace"}, {"step": "23", "precision": "0.7005149126052856", "rev": "workspace"}, {"step": "24", "precision": "0.6999821662902832", "rev": "workspace"}, {"step": "25", "precision": "0.7058628797531128", "rev": "workspace"}, {"step": "26", "precision": "0.7112910151481628", "rev": "workspace"}, {"step": "27", "precision": "0.7165747880935669", "rev": "workspace"}, {"step": "28", "precision": "0.7200020551681519", "rev": "workspace"}, {"step": "29", "precision": "0.723875880241394", "rev": "workspace"}]}, "title": {"text": "valid/precision", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "precision", "type": "quantitative", "title": "precision", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "precision", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "precision", "type": "quantitative", "title": "precision", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "precision", "type": "quantitative", "title": "precision", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "precision", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_valid_precision', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_train_f1">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "f1": "0.03369621932506561", "rev": "workspace"}, {"step": "1", "f1": "0.08265631645917892", "rev": "workspace"}, {"step": "2", "f1": "0.14480535686016083", "rev": "workspace"}, {"step": "3", "f1": "0.1954474151134491", "rev": "workspace"}, {"step": "4", "f1": "0.24400727450847626", "rev": "workspace"}, {"step": "5", "f1": "0.2852896749973297", "rev": "workspace"}, {"step": "6", "f1": "0.3165817856788635", "rev": "workspace"}, {"step": "7", "f1": "0.34051138162612915", "rev": "workspace"}, {"step": "8", "f1": "0.3579234480857849", "rev": "workspace"}, {"step": "9", "f1": "0.3706662654876709", "rev": "workspace"}, {"step": "10", "f1": "0.3813643753528595", "rev": "workspace"}, {"step": "11", "f1": "0.38740092515945435", "rev": "workspace"}, {"step": "12", "f1": "0.3947397768497467", "rev": "workspace"}, {"step": "13", "f1": "0.3978673219680786", "rev": "workspace"}, {"step": "14", "f1": "0.3983455300331116", "rev": "workspace"}, {"step": "15", "f1": "0.4061855971813202", "rev": "workspace"}, {"step": "16", "f1": "0.41634228825569153", "rev": "workspace"}, {"step": "17", "f1": "0.4270275831222534", "rev": "workspace"}, {"step": "18", "f1": "0.4378680884838104", "rev": "workspace"}, {"step": "19", "f1": "0.4433649480342865", "rev": "workspace"}, {"step": "20", "f1": "0.44917452335357666", "rev": "workspace"}, {"step": "21", "f1": "0.45438480377197266", "rev": "workspace"}, {"step": "22", "f1": "0.45568108558654785", "rev": "workspace"}, {"step": "23", "f1": "0.45533287525177", "rev": "workspace"}, {"step": "24", "f1": "0.45924824476242065", "rev": "workspace"}, {"step": "25", "f1": "0.45992523431777954", "rev": "workspace"}, {"step": "26", "f1": "0.4698502719402313", "rev": "workspace"}, {"step": "27", "f1": "0.47596246004104614", "rev": "workspace"}, {"step": "28", "f1": "0.4798404574394226", "rev": "workspace"}, {"step": "29", "f1": "0.4849509298801422", "rev": "workspace"}]}, "title": {"text": "train/f1", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "f1", "type": "quantitative", "title": "f1", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "f1", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "f1", "type": "quantitative", "title": "f1", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "f1", "type": "quantitative", "title": "f1", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "f1", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_train_f1', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_train_loss">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "loss": "4.0223345439910885", "rev": "workspace"}, {"step": "1", "loss": "3.1446908485412597", "rev": "workspace"}, {"step": "2", "loss": "2.845902690887451", "rev": "workspace"}, {"step": "3", "loss": "2.7091702819824217", "rev": "workspace"}, {"step": "4", "loss": "2.5876311504364016", "rev": "workspace"}, {"step": "5", "loss": "2.4898329914093016", "rev": "workspace"}, {"step": "6", "loss": "2.4078801975250244", "rev": "workspace"}, {"step": "7", "loss": "2.3469735752105714", "rev": "workspace"}, {"step": "8", "loss": "2.3044276081085204", "rev": "workspace"}, {"step": "9", "loss": "2.2692713054656983", "rev": "workspace"}, {"step": "10", "loss": "2.2433441226959228", "rev": "workspace"}, {"step": "11", "loss": "2.2251542106628417", "rev": "workspace"}, {"step": "12", "loss": "2.2075239395141604", "rev": "workspace"}, {"step": "13", "loss": "2.199652731704712", "rev": "workspace"}, {"step": "14", "loss": "2.2003354278564453", "rev": "workspace"}, {"step": "15", "loss": "2.1802532527923586", "rev": "workspace"}, {"step": "16", "loss": "2.155798588562012", "rev": "workspace"}, {"step": "17", "loss": "2.127256035041809", "rev": "workspace"}, {"step": "18", "loss": "2.1059615146636963", "rev": "workspace"}, {"step": "19", "loss": "2.0936844367980956", "rev": "workspace"}, {"step": "20", "loss": "2.077068444442749", "rev": "workspace"}, {"step": "21", "loss": "2.064746246147156", "rev": "workspace"}, {"step": "22", "loss": "2.0619221328735353", "rev": "workspace"}, {"step": "23", "loss": "2.058998599433899", "rev": "workspace"}, {"step": "24", "loss": "2.0547973834991455", "rev": "workspace"}, {"step": "25", "loss": "2.0490142166137697", "rev": "workspace"}, {"step": "26", "loss": "2.0298474113464358", "rev": "workspace"}, {"step": "27", "loss": "2.0157182762145998", "rev": "workspace"}, {"step": "28", "loss": "2.006533228492737", "rev": "workspace"}, {"step": "29", "loss": "1.992799331665039", "rev": "workspace"}]}, "title": {"text": "train/loss", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "loss", "type": "quantitative", "title": "loss", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "loss", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "loss", "type": "quantitative", "title": "loss", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "loss", "type": "quantitative", "title": "loss", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "loss", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_train_loss', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_train_accuracy">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "accuracy": "0.04570624977350235", "rev": "workspace"}, {"step": "1", "accuracy": "0.08645624667406082", "rev": "workspace"}, {"step": "2", "accuracy": "0.15463125705718994", "rev": "workspace"}, {"step": "3", "accuracy": "0.20809374749660492", "rev": "workspace"}, {"step": "4", "accuracy": "0.2542562484741211", "rev": "workspace"}, {"step": "5", "accuracy": "0.29279375076293945", "rev": "workspace"}, {"step": "6", "accuracy": "0.3235749900341034", "rev": "workspace"}, {"step": "7", "accuracy": "0.346756249666214", "rev": "workspace"}, {"step": "8", "accuracy": "0.363993763923645", "rev": "workspace"}, {"step": "9", "accuracy": "0.3763374984264374", "rev": "workspace"}, {"step": "10", "accuracy": "0.3868750035762787", "rev": "workspace"}, {"step": "11", "accuracy": "0.39285001158714294", "rev": "workspace"}, {"step": "12", "accuracy": "0.400112509727478", "rev": "workspace"}, {"step": "13", "accuracy": "0.4029250144958496", "rev": "workspace"}, {"step": "14", "accuracy": "0.403425008058548", "rev": "workspace"}, {"step": "15", "accuracy": "0.4108937382698059", "rev": "workspace"}, {"step": "16", "accuracy": "0.42086875438690186", "rev": "workspace"}, {"step": "17", "accuracy": "0.43145623803138733", "rev": "workspace"}, {"step": "18", "accuracy": "0.44205623865127563", "rev": "workspace"}, {"step": "19", "accuracy": "0.44743749499320984", "rev": "workspace"}, {"step": "20", "accuracy": "0.45317500829696655", "rev": "workspace"}, {"step": "21", "accuracy": "0.4583125114440918", "rev": "workspace"}, {"step": "22", "accuracy": "0.45945000648498535", "rev": "workspace"}, {"step": "23", "accuracy": "0.4590874910354614", "rev": "workspace"}, {"step": "24", "accuracy": "0.462799996137619", "rev": "workspace"}, {"step": "25", "accuracy": "0.46361875534057617", "rev": "workspace"}, {"step": "26", "accuracy": "0.47360000014305115", "rev": "workspace"}, {"step": "27", "accuracy": "0.47944375872612", "rev": "workspace"}, {"step": "28", "accuracy": "0.4831624925136566", "rev": "workspace"}, {"step": "29", "accuracy": "0.4883750081062317", "rev": "workspace"}]}, "title": {"text": "train/accuracy", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "accuracy", "type": "quantitative", "title": "accuracy", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "accuracy", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "accuracy", "type": "quantitative", "title": "accuracy", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "accuracy", "type": "quantitative", "title": "accuracy", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "accuracy", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_train_accuracy', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_train_recall">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "recall": "0.04566790908575058", "rev": "workspace"}, {"step": "1", "recall": "0.08645831048488617", "rev": "workspace"}, {"step": "2", "recall": "0.1547394096851349", "rev": "workspace"}, {"step": "3", "recall": "0.2082621306180954", "rev": "workspace"}, {"step": "4", "recall": "0.2544386386871338", "rev": "workspace"}, {"step": "5", "recall": "0.29297617077827454", "rev": "workspace"}, {"step": "6", "recall": "0.3237552046775818", "rev": "workspace"}, {"step": "7", "recall": "0.346926212310791", "rev": "workspace"}, {"step": "8", "recall": "0.364177942276001", "rev": "workspace"}, {"step": "9", "recall": "0.3765082359313965", "rev": "workspace"}, {"step": "10", "recall": "0.3870483636856079", "rev": "workspace"}, {"step": "11", "recall": "0.39303654432296753", "rev": "workspace"}, {"step": "12", "recall": "0.4002862572669983", "rev": "workspace"}, {"step": "13", "recall": "0.403104305267334", "rev": "workspace"}, {"step": "14", "recall": "0.4036043882369995", "rev": "workspace"}, {"step": "15", "recall": "0.4110710322856903", "rev": "workspace"}, {"step": "16", "recall": "0.42104363441467285", "rev": "workspace"}, {"step": "17", "recall": "0.4316273629665375", "rev": "workspace"}, {"step": "18", "recall": "0.44224199652671814", "rev": "workspace"}, {"step": "19", "recall": "0.44761422276496887", "rev": "workspace"}, {"step": "20", "recall": "0.4533519148826599", "rev": "workspace"}, {"step": "21", "recall": "0.4584938883781433", "rev": "workspace"}, {"step": "22", "recall": "0.45962250232696533", "rev": "workspace"}, {"step": "23", "recall": "0.45926111936569214", "rev": "workspace"}, {"step": "24", "recall": "0.46298152208328247", "rev": "workspace"}, {"step": "25", "recall": "0.4637938141822815", "rev": "workspace"}, {"step": "26", "recall": "0.47378331422805786", "rev": "workspace"}, {"step": "27", "recall": "0.4796329736709595", "rev": "workspace"}, {"step": "28", "recall": "0.4833483397960663", "rev": "workspace"}, {"step": "29", "recall": "0.4885570704936981", "rev": "workspace"}]}, "title": {"text": "train/recall", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "recall", "type": "quantitative", "title": "recall", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "recall", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "recall", "type": "quantitative", "title": "recall", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "recall", "type": "quantitative", "title": "recall", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "recall", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_train_recall', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div id = "static_train_precision">
|
||||||
|
<script type = "text/javascript">
|
||||||
|
var spec = {"$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": [{"step": "0", "precision": "0.055038902908563614", "rev": "workspace"}, {"step": "1", "precision": "0.08688722550868988", "rev": "workspace"}, {"step": "2", "precision": "0.14544154703617096", "rev": "workspace"}, {"step": "3", "precision": "0.19598175585269928", "rev": "workspace"}, {"step": "4", "precision": "0.24326808750629425", "rev": "workspace"}, {"step": "5", "precision": "0.28354209661483765", "rev": "workspace"}, {"step": "6", "precision": "0.3143269121646881", "rev": "workspace"}, {"step": "7", "precision": "0.33853644132614136", "rev": "workspace"}, {"step": "8", "precision": "0.3557613492012024", "rev": "workspace"}, {"step": "9", "precision": "0.3688707649707794", "rev": "workspace"}, {"step": "10", "precision": "0.37962812185287476", "rev": "workspace"}, {"step": "11", "precision": "0.3855380415916443", "rev": "workspace"}, {"step": "12", "precision": "0.3929717540740967", "rev": "workspace"}, {"step": "13", "precision": "0.3963885009288788", "rev": "workspace"}, {"step": "14", "precision": "0.39665019512176514", "rev": "workspace"}, {"step": "15", "precision": "0.40454378724098206", "rev": "workspace"}, {"step": "16", "precision": "0.41489875316619873", "rev": "workspace"}, {"step": "17", "precision": "0.42552798986434937", "rev": "workspace"}, {"step": "18", "precision": "0.43634524941444397", "rev": "workspace"}, {"step": "19", "precision": "0.44186848402023315", "rev": "workspace"}, {"step": "20", "precision": "0.4476196765899658", "rev": "workspace"}, {"step": "21", "precision": "0.4529317021369934", "rev": "workspace"}, {"step": "22", "precision": "0.4541561007499695", "rev": "workspace"}, {"step": "23", "precision": "0.45390117168426514", "rev": "workspace"}, {"step": "24", "precision": "0.45794767141342163", "rev": "workspace"}, {"step": "25", "precision": "0.45853471755981445", "rev": "workspace"}, {"step": "26", "precision": "0.46830785274505615", "rev": "workspace"}, {"step": "27", "precision": "0.4746767580509186", "rev": "workspace"}, {"step": "28", "precision": "0.47852134704589844", "rev": "workspace"}, {"step": "29", "precision": "0.48363304138183594", "rev": "workspace"}]}, "title": {"text": "train/precision", "anchor": "middle"}, "width": 300, "height": 300, "params": [{"name": "smooth", "value": 0.001, "bind": {"input": "range", "min": 0.001, "max": 1, "step": 0.001}}], "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "color": {"field": "rev", "scale": {"domain": ["workspace"], "range": ["#945dd6"]}}, "strokeDash": {}}, "layer": [{"layer": [{"params": [{"name": "grid", "select": "interval", "bind": "scales"}], "mark": "line"}, {"transform": [{"filter": {"param": "hover", "empty": false}}], "mark": "point"}], "encoding": {"y": {"field": "precision", "type": "quantitative", "title": "precision", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}, "transform": [{"loess": "precision", "on": "step", "groupby": ["rev"], "bandwidth": {"signal": "smooth"}}]}, {"mark": {"type": "line", "opacity": 0.2}, "encoding": {"x": {"field": "step", "type": "quantitative", "title": "step"}, "y": {"field": "precision", "type": "quantitative", "title": "precision", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"mark": {"type": "circle", "size": 10}, "encoding": {"x": {"aggregate": "max", "field": "step", "type": "quantitative", "title": "step"}, "y": {"aggregate": {"argmax": "step"}, "field": "precision", "type": "quantitative", "title": "precision", "scale": {"zero": false}}, "color": {"field": "rev", "type": "nominal"}}}, {"transform": [{"calculate": "datum.rev", "as": "pivot_field"}, {"pivot": "pivot_field", "op": "mean", "value": "precision", "groupby": ["step"]}], "mark": {"type": "rule", "tooltip": {"content": "data"}, "stroke": "grey"}, "encoding": {"opacity": {"condition": {"value": 0.3, "param": "hover", "empty": false}, "value": 0}}, "params": [{"name": "hover", "select": {"type": "point", "fields": ["step"], "nearest": true, "on": "mouseover", "clear": "mouseout"}}]}]};
|
||||||
|
vegaEmbed('#static_train_precision', spec);
|
||||||
|
</script>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@ -14,53 +14,72 @@ from quickdraw_bot.utils.utils import load_config
|
|||||||
class Prepare:
|
class Prepare:
|
||||||
def __init__(self, config_path: str = './assets/config.yaml'):
|
def __init__(self, config_path: str = './assets/config.yaml'):
|
||||||
self.config = load_config(config_path)['prepare']
|
self.config = load_config(config_path)['prepare']
|
||||||
self.set_random_seed()
|
self._set_random_seed()
|
||||||
|
|
||||||
def set_random_seed(self):
|
def _set_random_seed(self):
|
||||||
random.seed(self.config['random_seed'])
|
random.seed(self.config['random_seed'])
|
||||||
np.random.seed(self.config['random_seed'])
|
np.random.seed(self.config['random_seed'])
|
||||||
|
|
||||||
def load_dataset(self) -> dict[str, np.ndarray]:
|
def _load_dataset(self) -> dict[str, np.ndarray]:
|
||||||
data: dict[str, np.ndarray] = {}
|
data: dict[str, list] = {
|
||||||
|
'images': [],
|
||||||
|
'cate_names': [],
|
||||||
|
'cate_ids': []
|
||||||
|
}
|
||||||
|
cls_id_map: dict[str, int] = {}
|
||||||
raw_data_dir = Path(self.config['data_dir']) / 'raw'
|
raw_data_dir = Path(self.config['data_dir']) / 'raw'
|
||||||
for npy_file in raw_data_dir.glob('*.npy'):
|
for npy_file in sorted(raw_data_dir.glob('*.npy')):
|
||||||
class_name = npy_file.stem
|
class_name = npy_file.stem.removeprefix('full_numpy_bitmap_')
|
||||||
|
if class_name not in cls_id_map:
|
||||||
|
cls_id_map[class_name] = len(cls_id_map)
|
||||||
images = np.load(npy_file) # shape: (N, 784)
|
images = np.load(npy_file) # shape: (N, 784)
|
||||||
images = images.reshape(-1, 1, 28, 28) # shape: (N, 1, 28, 28)
|
images = images.reshape(-1, 1, 28, 28) # shape: (N, 1, 28, 28)
|
||||||
images = images.astype(np.int8)
|
images = images.astype(np.int8)
|
||||||
if images.shape[0] < self.config['num_of_img_per_class']:
|
if images.shape[0] < self.config['num_of_img_per_class']:
|
||||||
print(f'Class {class_name} has less than {self.config["num_of_img_per_class"]} samples, keep all')
|
print(f'Class {class_name} has less than {self.config["num_of_img_per_class"]} samples, keep all')
|
||||||
data[class_name] = images
|
data['images'].extend(images)
|
||||||
|
data['cate_names'].extend([class_name] * images.shape[0])
|
||||||
|
data['cate_ids'].extend([cls_id_map[class_name]] * images.shape[0])
|
||||||
else:
|
else:
|
||||||
random_indice = np.random.choice(images.shape[0], self.config['num_of_img_per_class'], replace=False)
|
random_indice = np.random.choice(images.shape[0], self.config['num_of_img_per_class'], replace=False)
|
||||||
data[class_name] = images[random_indice]
|
data['images'].extend(images[random_indice])
|
||||||
|
data['cate_names'].extend([class_name] * self.config['num_of_img_per_class'])
|
||||||
|
data['cate_ids'].extend([cls_id_map[class_name]] * self.config['num_of_img_per_class'])
|
||||||
|
data['images'] = np.array(data['images']).astype(np.uint8)
|
||||||
|
data['cate_names'] = np.array(data['cate_names']).astype('S30')
|
||||||
|
data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def split_data(self, data: dict[str, np.ndarray]) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
|
def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]:
|
||||||
sets: dict[str, dict[str, list[np.ndarray]]] = {}
|
weights = self.config['data_split']
|
||||||
weights = {name: weight for name, weight in self.config['data_split'].items()}
|
if abs(sum(weights.values()) - 1.0) > 1e-6:
|
||||||
if sum(weights.values()) != 1.0:
|
|
||||||
raise ValueError('Sum of data_split weights must be 1.0')
|
raise ValueError('Sum of data_split weights must be 1.0')
|
||||||
for class_name, images in data.items():
|
|
||||||
for image in images:
|
element_count = len(next(iter(data.values())))
|
||||||
selection = np.random.choice(list(weights.keys()), p=list(weights.values()))
|
shuffled_indices = np.random.permutation(element_count)
|
||||||
if selection not in sets:
|
|
||||||
sets[selection] = {}
|
sets: dict[str, dict[str, np.ndarray]] = {}
|
||||||
if class_name not in sets[selection]:
|
start = 0
|
||||||
sets[selection][class_name] = []
|
for i, (name, weight) in enumerate(weights.items()):
|
||||||
sets[selection][class_name].append(image)
|
if i == len(weights) - 1:
|
||||||
|
idx = shuffled_indices[start:]
|
||||||
|
else:
|
||||||
|
end = start + round(weight * element_count)
|
||||||
|
idx = shuffled_indices[start:end]
|
||||||
|
start = end
|
||||||
|
sets[name] = {key: value[idx] for key, value in data.items()}
|
||||||
return sets
|
return sets
|
||||||
|
|
||||||
def save_npz(self, sets: dict[str, dict[str, list[np.ndarray]]]) -> None:
|
def _save_npz(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
|
||||||
save_dir = Path(self.config['data_dir']) / 'processed'
|
save_dir = Path(self.config['data_dir']) / 'processed'
|
||||||
save_dir.mkdir(exist_ok=True)
|
save_dir.mkdir(exist_ok=True)
|
||||||
for usage, data_dict in sets.items():
|
for usage, data in sets.items():
|
||||||
np.savez(f'{save_dir}/{usage}.npz', **data_dict)
|
np.savez(f'{save_dir}/{usage}.npz', **data)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
data = self.load_dataset()
|
data = self._load_dataset()
|
||||||
sets = self.split_data(data)
|
sets = self._split_data(data)
|
||||||
self.save_npz(sets)
|
self._save_npz(sets)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@ -0,0 +1,174 @@
|
|||||||
|
# train.py
|
||||||
|
#
|
||||||
|
# author: deng
|
||||||
|
# date : 20260617
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from dvclive import Live
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchmetrics import MetricCollection
|
||||||
|
from torchmetrics.classification import Accuracy, ConfusionMatrix, F1Score, Precision, Recall
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from quickdraw_bot.utils.dataset import QuickDrawDataset
|
||||||
|
from quickdraw_bot.utils.model import BabyCNN
|
||||||
|
from quickdraw_bot.utils.utils import load_config
|
||||||
|
|
||||||
|
|
||||||
|
class Train:
|
||||||
|
def __init__(self, config_path: str = './assets/config.yaml'):
|
||||||
|
self.config = load_config(config_path)['train']
|
||||||
|
self._device = torch.device(self.config['device_type'])
|
||||||
|
|
||||||
|
self._ensure_deterministic()
|
||||||
|
|
||||||
|
def _ensure_deterministic(self) -> None:
|
||||||
|
torch.use_deterministic_algorithms(mode=True, warn_only=True)
|
||||||
|
random.seed(self.config['random_seed'])
|
||||||
|
np.random.seed(self.config['random_seed'])
|
||||||
|
torch.manual_seed(self.config['random_seed'])
|
||||||
|
|
||||||
|
def _get_dataloader(self):
|
||||||
|
train_dataset = QuickDrawDataset(
|
||||||
|
data_npz_path=self.config['train_npz'],
|
||||||
|
enable_data_aug=True,
|
||||||
|
file_lazy_load=self.config['file_lazy_load'],
|
||||||
|
return_cate_name=False,
|
||||||
|
# vis_dir='./tmp'
|
||||||
|
)
|
||||||
|
valid_dataset = QuickDrawDataset(
|
||||||
|
data_npz_path=self.config['valid_npz'],
|
||||||
|
enable_data_aug=False,
|
||||||
|
file_lazy_load=self.config['file_lazy_load'],
|
||||||
|
return_cate_name=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=self.config['batch_size'],
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=4,
|
||||||
|
pin_memory=False, # not support for mps
|
||||||
|
persistent_workers=True
|
||||||
|
)
|
||||||
|
valid_dataloader = DataLoader(
|
||||||
|
valid_dataset,
|
||||||
|
batch_size=self.config['batch_size'],
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=1,
|
||||||
|
pin_memory=False,
|
||||||
|
persistent_workers=True
|
||||||
|
)
|
||||||
|
return train_dataloader, valid_dataloader
|
||||||
|
|
||||||
|
def _get_model(self) -> torch.nn.Module:
|
||||||
|
model = BabyCNN(
|
||||||
|
num_classes=self.config['num_of_class'],
|
||||||
|
dropout_p=0.3
|
||||||
|
).to(self._device)
|
||||||
|
model.train()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _get_optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
|
||||||
|
if self.config['optimizer_name'] == 'adam':
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=self.config['learning_rate'])
|
||||||
|
elif self.config['optimizer_name'] == 'sgd':
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=self.config['learning_rate'])
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown optimizer name: {self.config["optimizer_name"]}')
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def _get_scheduler(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
|
||||||
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.0001)
|
||||||
|
if self.config['warmup_epochs'] > 0:
|
||||||
|
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=self.config['warmup_epochs'])
|
||||||
|
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, scheduler], milestones=[self.config['warmup_epochs']])
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
def _get_loss(self) -> torch.nn.modules.loss._Loss:
|
||||||
|
loss = torch.nn.CrossEntropyLoss(
|
||||||
|
label_smoothing=0.1
|
||||||
|
).to(self._device)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _get_metrics(self) -> tuple[MetricCollection, ConfusionMatrix]:
|
||||||
|
metric_collection = MetricCollection([
|
||||||
|
Accuracy(task='multiclass', num_classes=self.config['num_of_class'], top_k=1),
|
||||||
|
Precision(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||||
|
Recall(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||||
|
F1Score(task='multiclass', num_classes=self.config['num_of_class'], average='macro'),
|
||||||
|
]).to(self._device)
|
||||||
|
confusion_matrix = ConfusionMatrix(
|
||||||
|
task='multiclass',
|
||||||
|
threshold=0.5,
|
||||||
|
num_classes=self.config['num_of_class'],
|
||||||
|
).to(self._device)
|
||||||
|
return metric_collection, confusion_matrix
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
train_dataloader, valid_dataloader = self._get_dataloader()
|
||||||
|
model = self._get_model()
|
||||||
|
optimizer = self._get_optimizer(model)
|
||||||
|
scheduler = self._get_scheduler(optimizer)
|
||||||
|
loss = self._get_loss()
|
||||||
|
metrics, _ = self._get_metrics()
|
||||||
|
|
||||||
|
with Live(
|
||||||
|
dir='./doc/exp/train',
|
||||||
|
report='html',
|
||||||
|
dvcyaml='./assets/dvc.yaml',
|
||||||
|
exp_message=self.config['exp_msg']) as live:
|
||||||
|
|
||||||
|
for epoch in tqdm(range(self.config['num_of_epochs']), desc='Training Epoch'):
|
||||||
|
metrics.reset()
|
||||||
|
model.train()
|
||||||
|
total_train_loss = 0.
|
||||||
|
for inputs, targets in train_dataloader:
|
||||||
|
inputs = inputs.to(self._device)
|
||||||
|
targets = targets.to(self._device)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(inputs)
|
||||||
|
train_loss = loss(outputs, targets)
|
||||||
|
total_train_loss += train_loss.item()
|
||||||
|
train_loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
metrics.update(outputs, targets)
|
||||||
|
train_metrics = metrics.compute()
|
||||||
|
avg_train_loss = total_train_loss / len(train_dataloader)
|
||||||
|
|
||||||
|
metrics.reset()
|
||||||
|
model.eval()
|
||||||
|
total_valid_loss = 0.
|
||||||
|
with torch.no_grad():
|
||||||
|
for inputs, targets in valid_dataloader:
|
||||||
|
inputs = inputs.to(self._device)
|
||||||
|
targets = targets.to(self._device)
|
||||||
|
outputs = model(inputs)
|
||||||
|
valid_loss = loss(outputs, targets)
|
||||||
|
total_valid_loss += valid_loss.item()
|
||||||
|
metrics.update(outputs, targets)
|
||||||
|
valid_metrics = metrics.compute()
|
||||||
|
avg_valid_loss = total_valid_loss / len(valid_dataloader)
|
||||||
|
|
||||||
|
live.log_metric('train/loss', avg_train_loss)
|
||||||
|
live.log_metric('train/accuracy', train_metrics['MulticlassAccuracy'].item())
|
||||||
|
live.log_metric('train/precision', train_metrics['MulticlassPrecision'].item())
|
||||||
|
live.log_metric('train/recall', train_metrics['MulticlassRecall'].item())
|
||||||
|
live.log_metric('train/f1', train_metrics['MulticlassF1Score'].item())
|
||||||
|
live.log_metric('valid/loss', avg_valid_loss)
|
||||||
|
live.log_metric('valid/accuracy', valid_metrics['MulticlassAccuracy'].item())
|
||||||
|
live.log_metric('valid/precision', valid_metrics['MulticlassPrecision'].item())
|
||||||
|
live.log_metric('valid/recall', valid_metrics['MulticlassRecall'].item())
|
||||||
|
live.log_metric('valid/f1', valid_metrics['MulticlassF1Score'].item())
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
live.next_step()
|
||||||
|
|
||||||
|
torch.save(model, './assets/model.pth')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
Train().run()
|
||||||
88
quickdraw_bot/utils/dataset.py
Normal file
88
quickdraw_bot/utils/dataset.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
# dataset.py
|
||||||
|
#
|
||||||
|
# author: deng
|
||||||
|
# date : 20260617
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
|
||||||
|
|
||||||
|
class QuickDrawDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self,
|
||||||
|
data_npz_path: str,
|
||||||
|
image_shape: tuple[int, int, int] = (1, 28, 28),
|
||||||
|
enable_data_aug: bool = False,
|
||||||
|
file_lazy_load: bool = False,
|
||||||
|
return_cate_name: bool = False,
|
||||||
|
vis_dir: str = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._images: torch.Tensor | np.ndarray = []
|
||||||
|
self._cate_names: list[str] = []
|
||||||
|
self._cate_ids: torch.Tensor = []
|
||||||
|
self._transform: callable = None
|
||||||
|
|
||||||
|
self._enable_data_aug = enable_data_aug
|
||||||
|
self._data_npz_path = data_npz_path
|
||||||
|
self._image_shape = image_shape
|
||||||
|
self._file_lazy_load = file_lazy_load
|
||||||
|
self._return_cate_name = return_cate_name
|
||||||
|
self._vis_dir = Path(vis_dir) if vis_dir is not None else None
|
||||||
|
|
||||||
|
self._set_data_transform()
|
||||||
|
self._collect_data()
|
||||||
|
|
||||||
|
def _set_data_transform(self) -> None:
|
||||||
|
aug_pipeline = []
|
||||||
|
if self._enable_data_aug:
|
||||||
|
aug_pipeline = [
|
||||||
|
v2.RandomHorizontalFlip(p=0.2),
|
||||||
|
v2.RandomApply([v2.RandomAffine(degrees=(-30, 30), translate=(0.2, 0.2), scale=(0.8, 1.2), shear=(-10, 10))], p=0.5),
|
||||||
|
v2.RandomPerspective(distortion_scale=0.15, p=0.2),
|
||||||
|
v2.RandomApply([v2.ElasticTransform(alpha=15.0, sigma=3.0)], p=0.2),
|
||||||
|
v2.RandomErasing(p=0.2, scale=(0.02, 0.2))
|
||||||
|
]
|
||||||
|
self._transform = v2.Compose([
|
||||||
|
*aug_pipeline,
|
||||||
|
v2.Resize(self._image_shape[1:]),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def _collect_data(self) -> None:
|
||||||
|
if self._file_lazy_load:
|
||||||
|
self._npz_file = np.load(self._data_npz_path, mmap_mode='r')
|
||||||
|
self._cate_names = [cate_name.decode() for cate_name in self._npz_file['cate_names']]
|
||||||
|
self._cate_ids = torch.from_numpy(self._npz_file['cate_ids'].copy()).long()
|
||||||
|
self._images = self._npz_file['images']
|
||||||
|
else:
|
||||||
|
with np.load(self._data_npz_path, mmap_mode=None) as npz_file:
|
||||||
|
self._cate_names = [cate_name.decode() for cate_name in npz_file['cate_names']]
|
||||||
|
self._cate_ids = torch.from_numpy(npz_file['cate_ids']).long()
|
||||||
|
self._images = torch.from_numpy(npz_file['images'])
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._images)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self._file_lazy_load:
|
||||||
|
x = torch.from_numpy(self._images[index])
|
||||||
|
else:
|
||||||
|
x = self._images[index]
|
||||||
|
x = self._transform(x)
|
||||||
|
y = self._cate_ids[index]
|
||||||
|
|
||||||
|
if self._vis_dir is not None:
|
||||||
|
vis_path = self._vis_dir / f'{index:05d}_{self._cate_names[index]}.png'
|
||||||
|
if not vis_path.exists():
|
||||||
|
v2.ToPILImage()(x).save(vis_path)
|
||||||
|
|
||||||
|
if self._return_cate_name:
|
||||||
|
return x, y, self._cate_names[index]
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
def set_data_aug(self, enable_data_aug: bool) -> None:
|
||||||
|
self._enable_data_aug = enable_data_aug
|
||||||
|
self._set_data_transform()
|
||||||
53
quickdraw_bot/utils/model.py
Normal file
53
quickdraw_bot/utils/model.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# model.py
|
||||||
|
#
|
||||||
|
# author: deng
|
||||||
|
# date : 20260617
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class BabyCNN(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
num_classes: int = 10,
|
||||||
|
dropout_p: float = 0.5) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Conv Block 1: 28x28 -> 14x14
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
|
||||||
|
self.bn1 = nn.BatchNorm2d(num_features=32)
|
||||||
|
|
||||||
|
# Conv Block 2: 14x14 -> 7x7
|
||||||
|
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
|
||||||
|
self.bn2 = nn.BatchNorm2d(num_features=64)
|
||||||
|
|
||||||
|
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||||
|
self.dropout = nn.Dropout(p=dropout_p)
|
||||||
|
|
||||||
|
# FC Layers
|
||||||
|
self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
|
||||||
|
self.fc2 = nn.Linear(in_features=128, out_features=num_classes)
|
||||||
|
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.ones_(m.weight)
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
|
||||||
|
nn.init.zeros_(m.bias)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.pool(F.relu(self.bn1(self.conv1(x))))
|
||||||
|
x = self.pool(F.relu(self.bn2(self.conv2(x))))
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
x = self.dropout(F.relu(self.fc1(x)))
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
|
||||||
Reference in New Issue
Block a user