init
This commit is contained in:
+44
@@ -0,0 +1,44 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from onnxruntime.quantization import quantize_dynamic, QuantType
|
||||
from onnxruntime.quantization.shape_inference import quant_pre_process
|
||||
|
||||
|
||||
def quantize_model(input_path: Path, output_path: Path) -> None:
    """Dynamically quantize the MatMul weights of an ONNX model to uint8.

    The model is first passed through ``quant_pre_process``. If that step
    raises, the failure is printed and the original model is quantized
    instead, i.e. pre-processing is best-effort.

    The pre-processed intermediate lives in a private temporary directory
    instead of next to ``input_path``. That way an unrelated, pre-existing
    ``<stem>_preprocessed.onnx`` sibling file is never overwritten or
    deleted, and the intermediate is always removed, whether quantization
    succeeds or raises.

    Args:
        input_path: Path of the ONNX model to quantize.
        output_path: Path the quantized model is written to.

    Raises:
        OSError: If the temporary working directory cannot be created.
        Exception: Anything raised by ``quantize_dynamic`` propagates to the
            caller; only pre-processing failures are swallowed.
    """
    # Function-scope import so this block does not depend on an edit to the
    # file's top-level import section.
    import tempfile

    with tempfile.TemporaryDirectory(prefix="onnx_quant_") as scratch_dir:
        # Keep the original "<stem>_preprocessed.onnx" file name so the
        # progress message below is unchanged.
        preprocessed_path = Path(scratch_dir) / f"{input_path.stem}_preprocessed.onnx"

        print(f"Pre-processing {input_path.name}...")
        try:
            quant_pre_process(str(input_path), str(preprocessed_path))
        except Exception as e:
            # Deliberate best-effort: fall back to the untouched input model.
            print(f"Pre-processing skipped or failed: {e}")
            target_input = input_path
        else:
            target_input = preprocessed_path

        print(f"Quantizing {target_input.name}...")
        quantize_dynamic(
            model_input=str(target_input),
            model_output=str(output_path),
            weight_type=QuantType.QUInt8,
            # Limit quantization to MatMul. This bypasses the Conv layers
            # that cause weight initialization errors, while still optimizing
            # the heavy transformer layers.
            op_types_to_quantize=["MatMul"],
        )
        print(f"Quantization complete: {output_path}")
    # Leaving the ``with`` block removes the scratch directory and any
    # (possibly partial) pre-processed file inside it.
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: quantize the model given by --model and write
    # the result next to it as "<stem>_quant.onnx".
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        required=True,
        help="Path to the ONNX model to quantize",
    )
    cli_args = parser.parse_args()

    model_path = Path(cli_args.model)
    quantized_path = model_path.with_name(f"{model_path.stem}_quant.onnx")
    quantize_model(model_path, quantized_path)
|
||||
Reference in New Issue
Block a user