From cf4c8a090a326621a2aba99d148c808bcf768fa7 Mon Sep 17 00:00:00 2001 From: deng Date: Thu, 18 Jun 2026 16:20:14 +0800 Subject: [PATCH] convert pt model to onnx --- .gitignore | 7 ++- pyproject.toml | 3 +- quickdraw_bot/assets/config.yaml | 8 ++- quickdraw_bot/assets/model.onnx.dvc | 5 ++ quickdraw_bot/deploy.py | 0 quickdraw_bot/prepare.py | 14 +++-- quickdraw_bot/to_onnx.py | 98 +++++++++++++++++++++++++++++ uv.lock | 35 +++++++++++ 8 files changed, 160 insertions(+), 10 deletions(-) create mode 100644 quickdraw_bot/assets/model.onnx.dvc delete mode 100644 quickdraw_bot/deploy.py create mode 100644 quickdraw_bot/to_onnx.py diff --git a/.gitignore b/.gitignore index 44b89e5..a6a8a74 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,9 @@ quickdraw_bot/tmp dvc/config.local # .DS_Store -.DS_Store \ No newline at end of file +.DS_Store + +# Models +*.pth +*.onnx +*.onnx.data \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b3121e1..b566c13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ dev = [ deploy = [ "onnx~=1.22.0", "onnxruntime~=1.27.0", - "openvino~=2026.2.0" + "onnxscript~=0.7.0", + "openvino~=2026.2.0", ] [tool.ruff] diff --git a/quickdraw_bot/assets/config.yaml b/quickdraw_bot/assets/config.yaml index e3f677d..49b14c4 100644 --- a/quickdraw_bot/assets/config.yaml +++ b/quickdraw_bot/assets/config.yaml @@ -22,6 +22,8 @@ train: eval: test_data_dir: ./data/processed/test model_path: ./assets/model.pth - random_seed: 1 -deploy: - random_seed: 1 \ No newline at end of file +to_onnx: + model_path: ./assets/model.pth + image_size: [28, 28] + cls_map_path: ./data/processed/cate_id_cate_name_map.json + onnx_path: ./assets/model.onnx \ No newline at end of file diff --git a/quickdraw_bot/assets/model.onnx.dvc b/quickdraw_bot/assets/model.onnx.dvc new file mode 100644 index 0000000..df2ae20 --- /dev/null +++ b/quickdraw_bot/assets/model.onnx.dvc @@ -0,0 +1,5 @@ +outs: +- md5: 5a641ce9dc5db7b16a4868773eb968dc + size: 1755967 + hash: md5 + path: model.onnx diff --git a/quickdraw_bot/deploy.py b/quickdraw_bot/deploy.py deleted file mode 100644 index e69de29..0000000 diff --git a/quickdraw_bot/prepare.py b/quickdraw_bot/prepare.py index a0ca22b..70a32f3 100644 --- a/quickdraw_bot/prepare.py +++ b/quickdraw_bot/prepare.py @@ -3,6 +3,7 @@ # author: deng # date : 20260616 +import json import random from pathlib import Path from shutil import rmtree @@ -21,7 +22,7 @@ class Prepare: random.seed(self.config['random_seed']) np.random.seed(self.config['random_seed']) - def _load_dataset(self) -> dict[str, np.ndarray]: + def _load_dataset(self) -> tuple[dict[str, np.ndarray], dict[str, int]]: data: dict[str, list] = {'images': [], 'cate_names': [], 'cate_ids': []} cls_id_map: dict[str, int] = {} raw_data_dir = Path(self.config['data_dir']) / 'raw' @@ -45,7 +46,7 @@ class Prepare: data['images'] = np.array(data['images']).astype(np.uint8) data['cate_names'] = np.array(data['cate_names']).astype('S30') data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16) - return data + return data, cls_id_map def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]: weights = self.config['data_split'] @@ -67,7 +68,7 @@ class Prepare: sets[name] = {key: value[idx] for key, value in data.items()} return sets - def _save_data(self, sets: dict[str, dict[str, np.ndarray]]) -> None: + def _save_data(self, sets: dict[str, dict[str, np.ndarray]], cls_id_map: dict[str, int]) -> None: save_dir = Path(self.config['data_dir']) / 'processed' if save_dir.exists(): rmtree(save_dir) @@ -77,11 +78,14 @@ class Prepare: usage_dir.mkdir() for key, value in data.items(): np.save(f'{usage_dir}/{key}.npy', value) + cate_id_cate_name_map = {v: k for k, v in cls_id_map.items()} + with open(f'{save_dir}/cate_id_cate_name_map.json', 'w') as f: + json.dump(cate_id_cate_name_map, f, indent=2) def run(self): - data = self._load_dataset() + data, cls_id_map = self._load_dataset() sets = self._split_data(data) - self._save_data(sets) + self._save_data(sets, cls_id_map) if __name__ == '__main__': diff --git a/quickdraw_bot/to_onnx.py b/quickdraw_bot/to_onnx.py new file mode 100644 index 0000000..3d5090a --- /dev/null +++ b/quickdraw_bot/to_onnx.py @@ -0,0 +1,98 @@ +# to_onnx.py +# +# author: deng +# date : 20260618 + +import json +import time +from pathlib import Path + +import cv2 +import numpy as np +import onnxruntime as ort +import torch + +from quickdraw_bot.utils.utils import load_config + + +class Pipeline(torch.nn.Module): + def __init__(self, model: torch.nn.Module, input_size: tuple[int, int]): + super().__init__() + self._model = model + self._input_size = input_size + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = inputs.float() / 255.0 + inputs = torch.nn.functional.interpolate(inputs, size=self._input_size, mode='bilinear', align_corners=False) + logits = self._model(inputs) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +class ToONNX: + def __init__(self, config_path: str = './assets/config.yaml'): + self.config = load_config(config_path)['to_onnx'] + + def _get_cls_map(self) -> dict[int, str]: + cls_map_path = Path(self.config['cls_map_path']) + if not cls_map_path.exists(): + return None + with open(cls_map_path, 'r') as f: + cls_map = json.load(f) + return cls_map + + def _get_model(self) -> torch.nn.Module: + model = torch.load(self.config['model_path'], map_location='cpu', weights_only=False) + model.eval() + return model + + def _get_pipeline(self) -> torch.nn.Module: + pipeline = Pipeline( + model=self._get_model(), + input_size=tuple(self.config['image_size']), + ) + return pipeline + + def run(self) -> None: + pipeline = self._get_pipeline() + dummy_input = torch.randint(0, 256, (1, 1, self.config['image_size'][0], self.config['image_size'][1]), dtype=torch.uint8) + torch.onnx.export( + pipeline, + dummy_input, + self.config['onnx_path'], + input_names=['inputs'], + output_names=['outputs'], + dynamic_axes={'inputs': {0: 'batch_size', 2: 'height', 3: 'width'}, 'outputs': {0: 'batch_size'}}, + opset_version=18, + dynamo=False, + verbose=True, + ) + print(f'Done! ONNX file saved to {self.config["onnx_path"]}') + + def test_infer(self, image_path: str) -> None: + + # Get cls map + cls_map: dict[str, str] = self._get_cls_map() + cls_names = list(cls_map.values()) + + # Load Image + image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) + image = image[np.newaxis, np.newaxis, :, :] + + # Infer + onnx_session = ort.InferenceSession( + self.config['onnx_path'], + providers=['CPUExecutionProvider'], + ) + start_time = time.perf_counter() + outputs = onnx_session.run(None, {'inputs': image})[0] + elapsed_time_ms = (time.perf_counter() - start_time) * 1000 + result = [{cls_name: round(float(prob), 3) for cls_name, prob in zip(cls_names, prob_arr)} for prob_arr in outputs] + print(f'Elapsed time: {elapsed_time_ms:.2f} ms') + print(f'Image: {image_path}') + print(f'Result: {result}') + + +if __name__ == '__main__': + ToONNX().run() + ToONNX().test_infer('./assets/favicon.png') diff --git a/uv.lock b/uv.lock index 8ab7e1b..3023ea2 100644 --- a/uv.lock +++ b/uv.lock @@ -1793,6 +1793,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/50/257a880384a1dd502d543b0067945074d63cd17d0840e958355bc8197da8/onnx-1.22.0-cp314-cp314t-win_arm64.whl", hash = "sha256:8e268cdc0547e3949799ffd4a44451dc2b9080b57d0824a2db680b6ec65506f0", size = 17231391, upload-time = "2026-06-15T12:50:03.047Z" }, ] +[[package]] +name = "onnx-ir" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "onnx" }, + { name = "sympy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/e6/672fefb2f108d077f58181a7babf4c0f8d1182a30353ffc9c79c63afc5ee/onnx_ir-0.2.1.tar.gz", hash = "sha256:8b8b10a93f43e65962104de6070c43c5dacb0e3cdfefc7c8059dd83c9db64f35", size = 144279, upload-time = "2026-04-20T20:21:47.735Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/aa/f7a53321c60b9ad9ee184b6018292ed6b5389947592a2c8c09c736bb7f9e/onnx_ir-0.2.1-py3-none-any.whl", hash = "sha256:c7285da889312f91882de2092e298a9eeeefbfc1d1951c49d983992967eb09a7", size = 166792, upload-time = "2026-04-20T20:21:46.357Z" }, +] + [[package]] name = "onnxruntime" version = "1.27.0" @@ -1820,6 +1836,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/f6/2bac21f722aa45d876d4a51f26bd0ef30e704068a3cd5021a5a7cd784271/onnxruntime-1.27.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:370d211e1ceeac4cd5f45301655463ac59e27cdc74d9f7aeb2d19ff4b7a76715", size = 18670781, upload-time = "2026-06-15T22:43:17.151Z" }, ] +[[package]] +name = "onnxscript" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy" }, + { name = "onnx" }, + { name = "onnx-ir" }, + { name = "packaging" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/99/fd948eba63ba65b52265a4cd09a14f96bb9f5b730fcef58876c4358bf406/onnxscript-0.7.0.tar.gz", hash = "sha256:c95ed7b339b02cface56ee27689565c46612e1fc542c562298dddfdad5268dc5", size = 612032, upload-time = "2026-04-20T17:09:19.775Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/ce/2ed92575cc3be4ea1db5f38f16f20765f9b20b69b14d6c1d9972658a8ee9/onnxscript-0.7.0-py3-none-any.whl", hash = "sha256:5b356907d4501e9919f8599c91d8da967406a37b1fac2b40caa55a49acf242ea", size = 714842, upload-time = "2026-04-20T17:09:22.089Z" }, +] + [[package]] name = "opencv-python-headless" version = "4.13.0.92" @@ -2446,6 +2479,7 @@ dependencies = [ deploy = [ { name = "onnx" }, { name = "onnxruntime" }, + { name = "onnxscript" }, { name = "openvino" }, ] dev = [ @@ -2471,6 +2505,7 @@ requires-dist = [ deploy = [ { name = "onnx", specifier = "~=1.22.0" }, { name = "onnxruntime", specifier = "~=1.27.0" }, + { name = "onnxscript", specifier = "~=0.7.0" }, { name = "openvino", specifier = "~=2026.2.0" }, ] dev = [