convert pt model to onnx

This commit is contained in:
2026-06-18 16:20:14 +08:00
parent 9c46e1c345
commit cf4c8a090a
8 changed files with 160 additions and 10 deletions

5
.gitignore vendored
View File

@ -20,3 +20,8 @@ dvc/config.local
# .DS_Store # .DS_Store
.DS_Store .DS_Store
# Models
*.pth
*.onnx
*.onnx.data

View File

@ -25,7 +25,8 @@ dev = [
deploy = [ deploy = [
"onnx~=1.22.0", "onnx~=1.22.0",
"onnxruntime~=1.27.0", "onnxruntime~=1.27.0",
"openvino~=2026.2.0" "onnxscript~=0.7.0",
"openvino~=2026.2.0",
] ]
[tool.ruff] [tool.ruff]

View File

@ -22,6 +22,8 @@ train:
eval: eval:
test_data_dir: ./data/processed/test test_data_dir: ./data/processed/test
model_path: ./assets/model.pth model_path: ./assets/model.pth
random_seed: 1 to_onnx:
deploy: model_path: ./assets/model.pth
random_seed: 1 image_size: [28, 28]
cls_map_path: ./data/processed/cate_id_cate_name_map.json
onnx_path: ./assets/model.onnx

View File

@ -0,0 +1,5 @@
outs:
- md5: 5a641ce9dc5db7b16a4868773eb968dc
size: 1755967
hash: md5
path: model.onnx

View File

@ -3,6 +3,7 @@
# author: deng # author: deng
# date : 20260616 # date : 20260616
import json
import random import random
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
@ -21,7 +22,7 @@ class Prepare:
random.seed(self.config['random_seed']) random.seed(self.config['random_seed'])
np.random.seed(self.config['random_seed']) np.random.seed(self.config['random_seed'])
def _load_dataset(self) -> dict[str, np.ndarray]: def _load_dataset(self) -> tuple[dict[str, np.ndarray], dict[str, int]]:
data: dict[str, list] = {'images': [], 'cate_names': [], 'cate_ids': []} data: dict[str, list] = {'images': [], 'cate_names': [], 'cate_ids': []}
cls_id_map: dict[str, int] = {} cls_id_map: dict[str, int] = {}
raw_data_dir = Path(self.config['data_dir']) / 'raw' raw_data_dir = Path(self.config['data_dir']) / 'raw'
@ -45,7 +46,7 @@ class Prepare:
data['images'] = np.array(data['images']).astype(np.uint8) data['images'] = np.array(data['images']).astype(np.uint8)
data['cate_names'] = np.array(data['cate_names']).astype('S30') data['cate_names'] = np.array(data['cate_names']).astype('S30')
data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16) data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16)
return data return data, cls_id_map
def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]: def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]:
weights = self.config['data_split'] weights = self.config['data_split']
@ -67,7 +68,7 @@ class Prepare:
sets[name] = {key: value[idx] for key, value in data.items()} sets[name] = {key: value[idx] for key, value in data.items()}
return sets return sets
def _save_data(self, sets: dict[str, dict[str, np.ndarray]]) -> None: def _save_data(self, sets: dict[str, dict[str, np.ndarray]], cls_id_map: dict[str, int]) -> None:
save_dir = Path(self.config['data_dir']) / 'processed' save_dir = Path(self.config['data_dir']) / 'processed'
if save_dir.exists(): if save_dir.exists():
rmtree(save_dir) rmtree(save_dir)
@ -77,11 +78,14 @@ class Prepare:
usage_dir.mkdir() usage_dir.mkdir()
for key, value in data.items(): for key, value in data.items():
np.save(f'{usage_dir}/{key}.npy', value) np.save(f'{usage_dir}/{key}.npy', value)
cate_id_cate_name_map = {v: k for k, v in cls_id_map.items()}
with open(f'{save_dir}/cate_id_cate_name_map.json', 'w') as f:
json.dump(cate_id_cate_name_map, f, indent=2)
def run(self): def run(self):
data = self._load_dataset() data, cls_id_map = self._load_dataset()
sets = self._split_data(data) sets = self._split_data(data)
self._save_data(sets) self._save_data(sets, cls_id_map)
if __name__ == '__main__': if __name__ == '__main__':

98
quickdraw_bot/to_onnx.py Normal file
View File

@ -0,0 +1,98 @@
# to_onnx.py
#
# author: deng
# date : 20260618
import json
import time
from pathlib import Path
import cv2
import numpy as np
import onnxruntime as ort
import torch
from quickdraw_bot.utils.utils import load_config
class Pipeline(torch.nn.Module):
def __init__(self, model: torch.nn.Module, input_size: tuple[int, int]):
super().__init__()
self._model = model
self._input_size = input_size
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = inputs.float() / 255.0
inputs = torch.nn.functional.interpolate(inputs, size=self._input_size, mode='bilinear', align_corners=False)
logits = self._model(inputs)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
class ToONNX:
def __init__(self, config_path: str = './assets/config.yaml'):
self.config = load_config(config_path)['to_onnx']
def _get_cls_map(self) -> dict[int, str]:
cls_map_path = Path(self.config['cls_map_path'])
if not cls_map_path.exists():
return None
with open(cls_map_path, 'r') as f:
cls_map = json.load(f)
return cls_map
def _get_model(self) -> torch.nn.Module:
model = torch.load(self.config['model_path'], map_location='cpu', weights_only=False)
model.eval()
return model
def _get_pipeline(self) -> torch.nn.Module:
pipeline = Pipeline(
model=self._get_model(),
input_size=tuple(self.config['image_size']),
)
return pipeline
def run(self) -> None:
pipeline = self._get_pipeline()
dummy_input = torch.randint(0, 256, (1, 1, self.config['image_size'][0], self.config['image_size'][1]), dtype=torch.uint8)
torch.onnx.export(
pipeline,
dummy_input,
self.config['onnx_path'],
input_names=['inputs'],
output_names=['outputs'],
dynamic_axes={'inputs': {0: 'batch_size', 2: 'height', 3: 'width'}, 'outputs': {0: 'batch_size'}},
opset_version=18,
dynamo=False,
verbose=True,
)
print(f'Done! ONNX file saved to {self.config["onnx_path"]}')
def test_infer(self, image_path: str) -> None:
# Get cls map
cls_map: dict[str, str] = self._get_cls_map()
cls_names = list(cls_map.values())
# Load Image
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
image = image[np.newaxis, np.newaxis, :, :]
# Infer
onnx_session = ort.InferenceSession(
self.config['onnx_path'],
providers=['CPUExecutionProvider'],
)
start_time = time.perf_counter()
outputs = onnx_session.run(None, {'inputs': image})[0]
elapsed_time_ms = (time.perf_counter() - start_time) * 1000
result = [{cls_name: round(float(prob), 3) for cls_name, prob in zip(cls_names, prob_arr)} for prob_arr in outputs]
print(f'Elapsed time: {elapsed_time_ms:.2f} ms')
print(f'Image: {image_path}')
print(f'Result: {result}')
if __name__ == '__main__':
ToONNX().run()
ToONNX().test_infer('./assets/favicon.png')

35
uv.lock generated
View File

@ -1793,6 +1793,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/00/50/257a880384a1dd502d543b0067945074d63cd17d0840e958355bc8197da8/onnx-1.22.0-cp314-cp314t-win_arm64.whl", hash = "sha256:8e268cdc0547e3949799ffd4a44451dc2b9080b57d0824a2db680b6ec65506f0", size = 17231391, upload-time = "2026-06-15T12:50:03.047Z" }, { url = "https://files.pythonhosted.org/packages/00/50/257a880384a1dd502d543b0067945074d63cd17d0840e958355bc8197da8/onnx-1.22.0-cp314-cp314t-win_arm64.whl", hash = "sha256:8e268cdc0547e3949799ffd4a44451dc2b9080b57d0824a2db680b6ec65506f0", size = 17231391, upload-time = "2026-06-15T12:50:03.047Z" },
] ]
[[package]]
name = "onnx-ir"
version = "0.2.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ml-dtypes" },
{ name = "numpy" },
{ name = "onnx" },
{ name = "sympy" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/35/e6/672fefb2f108d077f58181a7babf4c0f8d1182a30353ffc9c79c63afc5ee/onnx_ir-0.2.1.tar.gz", hash = "sha256:8b8b10a93f43e65962104de6070c43c5dacb0e3cdfefc7c8059dd83c9db64f35", size = 144279, upload-time = "2026-04-20T20:21:47.735Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8c/aa/f7a53321c60b9ad9ee184b6018292ed6b5389947592a2c8c09c736bb7f9e/onnx_ir-0.2.1-py3-none-any.whl", hash = "sha256:c7285da889312f91882de2092e298a9eeeefbfc1d1951c49d983992967eb09a7", size = 166792, upload-time = "2026-04-20T20:21:46.357Z" },
]
[[package]] [[package]]
name = "onnxruntime" name = "onnxruntime"
version = "1.27.0" version = "1.27.0"
@ -1820,6 +1836,23 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/f6/2bac21f722aa45d876d4a51f26bd0ef30e704068a3cd5021a5a7cd784271/onnxruntime-1.27.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:370d211e1ceeac4cd5f45301655463ac59e27cdc74d9f7aeb2d19ff4b7a76715", size = 18670781, upload-time = "2026-06-15T22:43:17.151Z" }, { url = "https://files.pythonhosted.org/packages/b7/f6/2bac21f722aa45d876d4a51f26bd0ef30e704068a3cd5021a5a7cd784271/onnxruntime-1.27.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:370d211e1ceeac4cd5f45301655463ac59e27cdc74d9f7aeb2d19ff4b7a76715", size = 18670781, upload-time = "2026-06-15T22:43:17.151Z" },
] ]
[[package]]
name = "onnxscript"
version = "0.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ml-dtypes" },
{ name = "numpy" },
{ name = "onnx" },
{ name = "onnx-ir" },
{ name = "packaging" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9b/99/fd948eba63ba65b52265a4cd09a14f96bb9f5b730fcef58876c4358bf406/onnxscript-0.7.0.tar.gz", hash = "sha256:c95ed7b339b02cface56ee27689565c46612e1fc542c562298dddfdad5268dc5", size = 612032, upload-time = "2026-04-20T17:09:19.775Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b9/ce/2ed92575cc3be4ea1db5f38f16f20765f9b20b69b14d6c1d9972658a8ee9/onnxscript-0.7.0-py3-none-any.whl", hash = "sha256:5b356907d4501e9919f8599c91d8da967406a37b1fac2b40caa55a49acf242ea", size = 714842, upload-time = "2026-04-20T17:09:22.089Z" },
]
[[package]] [[package]]
name = "opencv-python-headless" name = "opencv-python-headless"
version = "4.13.0.92" version = "4.13.0.92"
@ -2446,6 +2479,7 @@ dependencies = [
deploy = [ deploy = [
{ name = "onnx" }, { name = "onnx" },
{ name = "onnxruntime" }, { name = "onnxruntime" },
{ name = "onnxscript" },
{ name = "openvino" }, { name = "openvino" },
] ]
dev = [ dev = [
@ -2471,6 +2505,7 @@ requires-dist = [
deploy = [ deploy = [
{ name = "onnx", specifier = "~=1.22.0" }, { name = "onnx", specifier = "~=1.22.0" },
{ name = "onnxruntime", specifier = "~=1.27.0" }, { name = "onnxruntime", specifier = "~=1.27.0" },
{ name = "onnxscript", specifier = "~=0.7.0" },
{ name = "openvino", specifier = "~=2026.2.0" }, { name = "openvino", specifier = "~=2026.2.0" },
] ]
dev = [ dev = [