"""Command-line helper that dynamically quantizes the MatMul ops of an ONNX model."""

import argparse
from pathlib import Path

from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.shape_inference import quant_pre_process


def quantize_model(input_path: Path, output_path: Path) -> None:
    """Quantize the ONNX model at *input_path* and write it to *output_path*.

    The model is first run through ``quant_pre_process``; the result is
    written to a temporary ``<stem>_preprocessed.onnx`` file next to the
    input. Pre-processing is best-effort: if it raises, the failure is
    printed and the original model is quantized instead.

    Args:
        input_path: Path to the source ONNX model.
        output_path: Destination path for the quantized model.

    Raises:
        Exception: Whatever ``quantize_dynamic`` raises is propagated after
            the temporary pre-processed file has been removed.
    """
    # Temporary sibling file for the pre-processed model. Its name always
    # differs from the input's, because "_preprocessed" is appended to the stem.
    preprocessed_path = input_path.with_name(f"{input_path.stem}_preprocessed.onnx")

    print(f"Pre-processing {input_path.name}...")
    try:
        quant_pre_process(str(input_path), str(preprocessed_path))
        target_input = preprocessed_path
    except Exception as e:
        # Deliberate best-effort: fall back to quantizing the raw model.
        print(f"Pre-processing skipped or failed: {e}")
        target_input = input_path

    print(f"Quantizing {target_input.name}...")
    try:
        quantize_dynamic(
            model_input=str(target_input),
            model_output=str(output_path),
            weight_type=QuantType.QUInt8,
            # Limit quantization to MatMul. This bypasses the Conv layers
            # that cause weight initialization errors, while still optimizing
            # the heavy transformer layers.
            op_types_to_quantize=["MatMul"],
        )
        print(f"Quantization complete: {output_path}")
    finally:
        # Remove the temporary pre-processed file whether or not quantization
        # succeeded. It may not exist (pre-processing can fail before writing
        # anything), so a missing file is not an error.
        try:
            preprocessed_path.unlink()
        except FileNotFoundError:
            pass


def main() -> None:
    """Parse ``--model`` and write ``<stem>_quant.onnx`` next to the input model."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to the ONNX model to quantize")
    args = parser.parse_args()

    in_path = Path(args.model)
    if not in_path.is_file():
        # Fail early with a clear usage error instead of a misleading
        # "Pre-processing skipped or failed" message followed by a crash.
        parser.error(f"model file not found: {in_path}")

    out_path = in_path.with_name(f"{in_path.stem}_quant.onnx")
    quantize_model(in_path, out_path)


if __name__ == "__main__":
    main()