add encoding to yaml loading and apply torch.no_grad to test

This commit is contained in:
deng 2024-01-04 20:42:58 +08:00
parent 465b163a73
commit 9404557763
8 changed files with 42 additions and 23 deletions

View File

@ -5,8 +5,8 @@ stages:
deps: deps:
- path: prepare.py - path: prepare.py
hash: md5 hash: md5
md5: a1c07d1d5caf6e5288560a189415785c md5: f0fd9944ebe22e1af020f031847991ed
size: 2979 size: 2992
params: params:
params.yaml: params.yaml:
prepare: prepare:
@ -19,21 +19,21 @@ stages:
outs: outs:
- path: data/processed - path: data/processed
hash: md5 hash: md5
md5: f4bf62ffa725ca9144b7852a283dc1da.dir md5: a47a61b2a1709f487fd286e9c54b89fc.dir
size: 295118798 size: 295124946
nfiles: 60000 nfiles: 60001
train: train:
cmd: python train.py cmd: python train.py
deps: deps:
- path: data/processed - path: data/processed
hash: md5 hash: md5
md5: f4bf62ffa725ca9144b7852a283dc1da.dir md5: a47a61b2a1709f487fd286e9c54b89fc.dir
size: 295118798 size: 295124946
nfiles: 60000 nfiles: 60001
- path: train.py - path: train.py
hash: md5 hash: md5
md5: b797ccf2fe61952bbf6d83fa51b0b11f md5: aabaf1a407badf48c97b14a69b0072ea
size: 3407 size: 3420
params: params:
params.yaml: params.yaml:
train: train:
@ -46,3 +46,23 @@ stages:
hash: md5 hash: md5
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3 md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
size: 102592994 size: 102592994
evaluate:
cmd: python evaluate.py
deps:
- path: data/processed
hash: md5
md5: a47a61b2a1709f487fd286e9c54b89fc.dir
size: 295124946
nfiles: 60001
- path: evaluate.py
hash: md5
md5: 8a9a2e95a6b64e632a4f2feac62d294b
size: 1473
- path: model.pt
hash: md5
md5: 8ead2a7cd52d70b359d3cdc3df5e43e3
size: 102592994
params:
params.yaml:
evaluate:
data_dir: data/processed

View File

@ -24,8 +24,6 @@ stages:
- model.pt - model.pt
params: params:
- evaluate - evaluate
outs:
- eval
params: params:
- dvclive/train/params.yaml - dvclive/train/params.yaml
- dvclive/eval/params.yaml - dvclive/eval/params.yaml

View File

@ -1,3 +1,3 @@
{ {
"test_acc": 0.7336928844451904 "test_acc": 0.7336809039115906
} }

View File

@ -1,2 +1,2 @@
step test_acc step test_acc
0 0.7336928844451904 0 0.7336809039115906

1 step test_acc
2 0 0.7336928844451904 0.7336809039115906

View File

@ -10,6 +10,6 @@ metrics.json
| test_acc | | test_acc |
|------------| |------------|
| 0.733693 | | 0.733681 |
![static/test_acc](static/test_acc.png) ![static/test_acc](static/test_acc.png)

View File

@ -21,7 +21,7 @@ def evaluate(params_path: str = 'params.yaml') -> None:
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'. params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
""" """
with open(params_path, 'r') as f: with open(params_path, encoding='utf-8') as f:
params = yaml.safe_load(f) params = yaml.safe_load(f)
data_dir = Path(params['evaluate']['data_dir']) data_dir = Path(params['evaluate']['data_dir'])
@ -39,11 +39,12 @@ def evaluate(params_path: str = 'params.yaml') -> None:
with Live(dir='dvclive/eval', report='md') as live: with Live(dir='dvclive/eval', report='md') as live:
live.log_params(params['evaluate']) live.log_params(params['evaluate'])
for data in test_dataloader: with torch.no_grad():
inputs, labels = data[0].to(device), data[1].to(device) for data in test_dataloader:
outputs = net(inputs) inputs, labels = data[0].to(device), data[1].to(device)
_ = metric(outputs, labels) outputs = net(inputs)
test_acc = metric.compute() _ = metric(outputs, labels)
test_acc = metric.compute()
print(f'test_acc:{test_acc}') print(f'test_acc:{test_acc}')
live.log_metric('test_acc', float(test_acc.cpu())) live.log_metric('test_acc', float(test_acc.cpu()))

View File

@ -19,7 +19,7 @@ def prepare(params_path: str = 'params.yaml') -> None:
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'. params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
""" """
with open(params_path, 'r') as f: with open(params_path, encoding='utf-8') as f:
params = yaml.safe_load(f) params = yaml.safe_load(f)
data_dir = Path(params['prepare']['data_dir']) data_dir = Path(params['prepare']['data_dir'])
save_dir = Path(params['prepare']['save_dir']) save_dir = Path(params['prepare']['save_dir'])

View File

@ -23,7 +23,7 @@ def train(params_path: str = 'params.yaml') -> None:
params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'. params_path (str, optional): path of parameter yaml. Defaults to 'params.yaml'.
""" """
with open(params_path, 'r') as f: with open(params_path, encoding='utf-8') as f:
params = yaml.safe_load(f) params = yaml.safe_load(f)
data_dir = Path(params['train']['data_dir']) data_dir = Path(params['train']['data_dir'])
epochs = params['train']['epochs'] epochs = params['train']['epochs']