convert pt model to onnx
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@ -20,3 +20,8 @@ dvc/config.local
|
||||
|
||||
# .DS_Store
|
||||
.DS_Store
|
||||
|
||||
# Models
|
||||
*.pth
|
||||
*.onnx
|
||||
*.onnx.data
|
||||
@ -25,7 +25,8 @@ dev = [
|
||||
deploy = [
|
||||
"onnx~=1.22.0",
|
||||
"onnxruntime~=1.27.0",
|
||||
"openvino~=2026.2.0"
|
||||
"onnxscript~=0.7.0",
|
||||
"openvino~=2026.2.0",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
@ -22,6 +22,8 @@ train:
|
||||
eval:
|
||||
test_data_dir: ./data/processed/test
|
||||
model_path: ./assets/model.pth
|
||||
random_seed: 1
|
||||
deploy:
|
||||
random_seed: 1
|
||||
to_onnx:
|
||||
model_path: ./assets/model.pth
|
||||
image_size: [28, 28]
|
||||
cls_map_path: ./data/processed/cate_id_cate_name_map.json
|
||||
onnx_path: ./assets/model.onnx
|
||||
5
quickdraw_bot/assets/model.onnx.dvc
Normal file
5
quickdraw_bot/assets/model.onnx.dvc
Normal file
@ -0,0 +1,5 @@
|
||||
outs:
|
||||
- md5: 5a641ce9dc5db7b16a4868773eb968dc
|
||||
size: 1755967
|
||||
hash: md5
|
||||
path: model.onnx
|
||||
@ -3,6 +3,7 @@
|
||||
# author: deng
|
||||
# date : 20260616
|
||||
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
@ -21,7 +22,7 @@ class Prepare:
|
||||
random.seed(self.config['random_seed'])
|
||||
np.random.seed(self.config['random_seed'])
|
||||
|
||||
def _load_dataset(self) -> dict[str, np.ndarray]:
|
||||
def _load_dataset(self) -> tuple[dict[str, np.ndarray], dict[str, int]]:
|
||||
data: dict[str, list] = {'images': [], 'cate_names': [], 'cate_ids': []}
|
||||
cls_id_map: dict[str, int] = {}
|
||||
raw_data_dir = Path(self.config['data_dir']) / 'raw'
|
||||
@ -45,7 +46,7 @@ class Prepare:
|
||||
data['images'] = np.array(data['images']).astype(np.uint8)
|
||||
data['cate_names'] = np.array(data['cate_names']).astype('S30')
|
||||
data['cate_ids'] = np.array(data['cate_ids']).astype(np.uint16)
|
||||
return data
|
||||
return data, cls_id_map
|
||||
|
||||
def _split_data(self, data: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]:
|
||||
weights = self.config['data_split']
|
||||
@ -67,7 +68,7 @@ class Prepare:
|
||||
sets[name] = {key: value[idx] for key, value in data.items()}
|
||||
return sets
|
||||
|
||||
def _save_data(self, sets: dict[str, dict[str, np.ndarray]]) -> None:
|
||||
def _save_data(self, sets: dict[str, dict[str, np.ndarray]], cls_id_map: dict[str, int]) -> None:
|
||||
save_dir = Path(self.config['data_dir']) / 'processed'
|
||||
if save_dir.exists():
|
||||
rmtree(save_dir)
|
||||
@ -77,11 +78,14 @@ class Prepare:
|
||||
usage_dir.mkdir()
|
||||
for key, value in data.items():
|
||||
np.save(f'{usage_dir}/{key}.npy', value)
|
||||
cate_id_cate_name_map = {v: k for k, v in cls_id_map.items()}
|
||||
with open(f'{save_dir}/cate_id_cate_name_map.json', 'w') as f:
|
||||
json.dump(cate_id_cate_name_map, f, indent=2)
|
||||
|
||||
def run(self):
|
||||
data = self._load_dataset()
|
||||
data, cls_id_map = self._load_dataset()
|
||||
sets = self._split_data(data)
|
||||
self._save_data(sets)
|
||||
self._save_data(sets, cls_id_map)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
98
quickdraw_bot/to_onnx.py
Normal file
98
quickdraw_bot/to_onnx.py
Normal file
@ -0,0 +1,98 @@
|
||||
# to_onnx.py
|
||||
#
|
||||
# author: deng
|
||||
# date : 20260618
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
from quickdraw_bot.utils.utils import load_config
|
||||
|
||||
|
||||
class Pipeline(torch.nn.Module):
|
||||
def __init__(self, model: torch.nn.Module, input_size: tuple[int, int]):
|
||||
super().__init__()
|
||||
self._model = model
|
||||
self._input_size = input_size
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = inputs.float() / 255.0
|
||||
inputs = torch.nn.functional.interpolate(inputs, size=self._input_size, mode='bilinear', align_corners=False)
|
||||
logits = self._model(inputs)
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
return probs
|
||||
|
||||
|
||||
class ToONNX:
|
||||
def __init__(self, config_path: str = './assets/config.yaml'):
|
||||
self.config = load_config(config_path)['to_onnx']
|
||||
|
||||
def _get_cls_map(self) -> dict[int, str]:
|
||||
cls_map_path = Path(self.config['cls_map_path'])
|
||||
if not cls_map_path.exists():
|
||||
return None
|
||||
with open(cls_map_path, 'r') as f:
|
||||
cls_map = json.load(f)
|
||||
return cls_map
|
||||
|
||||
def _get_model(self) -> torch.nn.Module:
|
||||
model = torch.load(self.config['model_path'], map_location='cpu', weights_only=False)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def _get_pipeline(self) -> torch.nn.Module:
|
||||
pipeline = Pipeline(
|
||||
model=self._get_model(),
|
||||
input_size=tuple(self.config['image_size']),
|
||||
)
|
||||
return pipeline
|
||||
|
||||
def run(self) -> None:
|
||||
pipeline = self._get_pipeline()
|
||||
dummy_input = torch.randint(0, 256, (1, 1, self.config['image_size'][0], self.config['image_size'][1]), dtype=torch.uint8)
|
||||
torch.onnx.export(
|
||||
pipeline,
|
||||
dummy_input,
|
||||
self.config['onnx_path'],
|
||||
input_names=['inputs'],
|
||||
output_names=['outputs'],
|
||||
dynamic_axes={'inputs': {0: 'batch_size', 2: 'height', 3: 'width'}, 'outputs': {0: 'batch_size'}},
|
||||
opset_version=18,
|
||||
dynamo=False,
|
||||
verbose=True,
|
||||
)
|
||||
print(f'Done! ONNX file saved to {self.config["onnx_path"]}')
|
||||
|
||||
def test_infer(self, image_path: str) -> None:
|
||||
|
||||
# Get cls map
|
||||
cls_map: dict[str, str] = self._get_cls_map()
|
||||
cls_names = list(cls_map.values())
|
||||
|
||||
# Load Image
|
||||
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
||||
image = image[np.newaxis, np.newaxis, :, :]
|
||||
|
||||
# Infer
|
||||
onnx_session = ort.InferenceSession(
|
||||
self.config['onnx_path'],
|
||||
providers=['CPUExecutionProvider'],
|
||||
)
|
||||
start_time = time.perf_counter()
|
||||
outputs = onnx_session.run(None, {'inputs': image})[0]
|
||||
elapsed_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
result = [{cls_name: round(float(prob), 3) for cls_name, prob in zip(cls_names, prob_arr)} for prob_arr in outputs]
|
||||
print(f'Elapsed time: {elapsed_time_ms:.2f} ms')
|
||||
print(f'Image: {image_path}')
|
||||
print(f'Result: {result}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ToONNX().run()
|
||||
ToONNX().test_infer('./assets/favicon.png')
|
||||
35
uv.lock
generated
35
uv.lock
generated
@ -1793,6 +1793,22 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/00/50/257a880384a1dd502d543b0067945074d63cd17d0840e958355bc8197da8/onnx-1.22.0-cp314-cp314t-win_arm64.whl", hash = "sha256:8e268cdc0547e3949799ffd4a44451dc2b9080b57d0824a2db680b6ec65506f0", size = 17231391, upload-time = "2026-06-15T12:50:03.047Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnx-ir"
|
||||
version = "0.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "ml-dtypes" },
|
||||
{ name = "numpy" },
|
||||
{ name = "onnx" },
|
||||
{ name = "sympy" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/35/e6/672fefb2f108d077f58181a7babf4c0f8d1182a30353ffc9c79c63afc5ee/onnx_ir-0.2.1.tar.gz", hash = "sha256:8b8b10a93f43e65962104de6070c43c5dacb0e3cdfefc7c8059dd83c9db64f35", size = 144279, upload-time = "2026-04-20T20:21:47.735Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/aa/f7a53321c60b9ad9ee184b6018292ed6b5389947592a2c8c09c736bb7f9e/onnx_ir-0.2.1-py3-none-any.whl", hash = "sha256:c7285da889312f91882de2092e298a9eeeefbfc1d1951c49d983992967eb09a7", size = 166792, upload-time = "2026-04-20T20:21:46.357Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnxruntime"
|
||||
version = "1.27.0"
|
||||
@ -1820,6 +1836,23 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/f6/2bac21f722aa45d876d4a51f26bd0ef30e704068a3cd5021a5a7cd784271/onnxruntime-1.27.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:370d211e1ceeac4cd5f45301655463ac59e27cdc74d9f7aeb2d19ff4b7a76715", size = 18670781, upload-time = "2026-06-15T22:43:17.151Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnxscript"
|
||||
version = "0.7.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "ml-dtypes" },
|
||||
{ name = "numpy" },
|
||||
{ name = "onnx" },
|
||||
{ name = "onnx-ir" },
|
||||
{ name = "packaging" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9b/99/fd948eba63ba65b52265a4cd09a14f96bb9f5b730fcef58876c4358bf406/onnxscript-0.7.0.tar.gz", hash = "sha256:c95ed7b339b02cface56ee27689565c46612e1fc542c562298dddfdad5268dc5", size = 612032, upload-time = "2026-04-20T17:09:19.775Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/ce/2ed92575cc3be4ea1db5f38f16f20765f9b20b69b14d6c1d9972658a8ee9/onnxscript-0.7.0-py3-none-any.whl", hash = "sha256:5b356907d4501e9919f8599c91d8da967406a37b1fac2b40caa55a49acf242ea", size = 714842, upload-time = "2026-04-20T17:09:22.089Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opencv-python-headless"
|
||||
version = "4.13.0.92"
|
||||
@ -2446,6 +2479,7 @@ dependencies = [
|
||||
deploy = [
|
||||
{ name = "onnx" },
|
||||
{ name = "onnxruntime" },
|
||||
{ name = "onnxscript" },
|
||||
{ name = "openvino" },
|
||||
]
|
||||
dev = [
|
||||
@ -2471,6 +2505,7 @@ requires-dist = [
|
||||
deploy = [
|
||||
{ name = "onnx", specifier = "~=1.22.0" },
|
||||
{ name = "onnxruntime", specifier = "~=1.27.0" },
|
||||
{ name = "onnxscript", specifier = "~=0.7.0" },
|
||||
{ name = "openvino", specifier = "~=2026.2.0" },
|
||||
]
|
||||
dev = [
|
||||
|
||||
Reference in New Issue
Block a user