44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
import argparse
|
|
from pathlib import Path
|
|
from onnxruntime.quantization import quantize_dynamic, QuantType
|
|
from onnxruntime.quantization.shape_inference import quant_pre_process
|
|
|
|
|
|
def quantize_model(
    input_path: Path,
    output_path: Path,
    *,
    weight_type: QuantType = QuantType.QUInt8,
    op_types_to_quantize=("MatMul",),
):
    """Pre-process and dynamically quantize an ONNX model.

    The model is first run through ``quant_pre_process``. If that step raises,
    it is skipped and the original model is quantized directly (best effort).
    The result of ``quantize_dynamic`` is written to ``output_path``.

    Args:
        input_path: Path (or path-like/str) of the ONNX model to quantize.
        output_path: Path (or path-like/str) where the quantized model is
            written. Its parent directory must exist and be writable.
        weight_type: Weight quantization type passed to ``quantize_dynamic``
            (default ``QuantType.QUInt8``, the previous hard-coded value).
        op_types_to_quantize: Operator types to quantize. A single string is
            treated as one op type. Defaults to ``("MatMul",)``, which skips
            the Conv layers that caused weight initialization errors while
            still optimizing the heavy transformer layers.

    Raises:
        FileNotFoundError: If ``input_path`` is not an existing file.
        Exception: Whatever ``quantize_dynamic`` raises is propagated.
    """
    import tempfile  # local import: only this function needs it

    input_path = Path(input_path)
    output_path = Path(output_path)
    if not input_path.is_file():
        raise FileNotFoundError(f"Input model not found: {input_path}")

    if isinstance(op_types_to_quantize, str):
        # Avoid list("MatMul") -> ['M', 'a', 't', ...].
        op_types = [op_types_to_quantize]
    else:
        op_types = list(op_types_to_quantize)

    # Stage the pre-processed model in a private temporary directory rather
    # than a fixed "<stem>_preprocessed.onnx" sibling of the input. The old
    # scheme overwrote, and then unconditionally deleted, any existing user
    # file with that name. The temp dir sits next to the output, which must
    # be writable anyway. It is removed on exit even if quantization raises,
    # and that removal includes any partial or external-data files.
    with tempfile.TemporaryDirectory(prefix="quant_pre_", dir=output_path.parent) as tmp_dir:
        preprocessed_path = Path(tmp_dir) / f"{input_path.stem}_preprocessed.onnx"

        print(f"Pre-processing {input_path.name}...")
        try:
            quant_pre_process(str(input_path), str(preprocessed_path))
            target_input = preprocessed_path
        except Exception as e:
            # Deliberate best effort: pre-processing is an optional
            # optimization, so fall back to quantizing the raw model.
            print(f"Pre-processing skipped or failed: {e}")
            target_input = input_path

        print(f"Quantizing {target_input.name}...")
        quantize_dynamic(
            model_input=str(target_input),
            model_output=str(output_path),
            weight_type=weight_type,
            op_types_to_quantize=op_types,
        )
        print(f"Quantization complete: {output_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: quantize the model given by --model.
    p = argparse.ArgumentParser(
        description="Pre-process and dynamically quantize an ONNX model."
    )
    p.add_argument("--model", required=True, help="Path to the ONNX model to quantize")
    p.add_argument(
        "--output",
        default=None,
        help="Where to write the quantized model "
        "(default: <model stem>_quant.onnx next to the input)",
    )
    args = p.parse_args()

    in_path = Path(args.model)
    if not in_path.is_file():
        # Fail fast with a CLI-style message instead of an ONNX traceback.
        p.error(f"--model: file not found: {in_path}")

    if args.output:
        out_path = Path(args.output)
    else:
        out_path = in_path.with_name(f"{in_path.stem}_quant.onnx")

    # Never overwrite the source model with its own quantized version.
    if out_path.resolve() == in_path.resolve():
        p.error("--output must differ from --model")

    quantize_model(in_path, out_path)