This commit is contained in:
2026-05-30 00:24:53 -05:00
commit 427637a0d3
9 changed files with 2438 additions and 0 deletions
+44
View File
@@ -0,0 +1,44 @@
import argparse
from pathlib import Path
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.shape_inference import quant_pre_process
def quantize_model(input_path: Path, output_path: Path):
# Create temporary path for the pre-processed model
preprocessed_path = input_path.with_name(f"{input_path.stem}_preprocessed.onnx")
print(f"Pre-processing {input_path.name}...")
try:
quant_pre_process(str(input_path), str(preprocessed_path))
target_input = preprocessed_path
except Exception as e:
print(f"Pre-processing skipped or failed: {e}")
target_input = input_path
print(f"Quantizing {target_input.name}...")
try:
quantize_dynamic(
model_input=str(target_input),
model_output=str(output_path),
weight_type=QuantType.QUInt8,
# Limit quantization to MatMul. This bypasses the Conv layers
# that cause weight initialization errors, while still optimizing
# the heavy transformer layers.
op_types_to_quantize=["MatMul"]
)
print(f"Quantization complete: {output_path}")
finally:
# Clean up temporary preprocessed file if it was created
if preprocessed_path.exists() and preprocessed_path != input_path:
preprocessed_path.unlink()
if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("--model", required=True, help="Path to the ONNX model to quantize")
args = p.parse_args()
in_path = Path(args.model)
out_path = in_path.with_name(f"{in_path.stem}_quant.onnx")
quantize_model(in_path, out_path)