Finding a good balance

This commit is contained in:
2026-05-31 00:50:46 -05:00
parent 5578b84fd8
commit bee1ed65a4
8 changed files with 453 additions and 117 deletions
+34
View File
@@ -0,0 +1,34 @@
# optimize_models.py
from onnxruntime.transformers.optimizer import optimize_model
from onnxruntime.transformers.fusion_options import FusionOptions
def optimize_custom(input_path, output_path):
print(f"Optimizing {input_path}...")
# Load default BERT fusion options
options = FusionOptions("bert")
# Disable LayerNorm fusions that break on AdaLN / dynamic biases
options.enable_skip_layer_norm = False
options.enable_layer_norm = False
# Run the optimizer
optimizer = optimize_model(
input=input_path,
model_type="bert",
optimization_options=options
)
optimizer.save_model_to_file(output_path)
print(f"Saved optimized model to {output_path}\n")
if __name__ == "__main__":
optimize_custom("outputs/encode.onnx", "outputs/encode_opt.onnx")
optimize_custom("outputs/decode.onnx", "outputs/decode_opt.onnx")
# ssl.onnx (WavLM) is a standard BERT architecture, so we can leave
# all standard fusions enabled for maximum speed.
print("Optimizing outputs/ssl.onnx...")
ssl_opt = optimize_model("outputs/ssl.onnx", model_type="bert")
ssl_opt.save_model_to_file("outputs/ssl_opt.onnx")
print("Saved optimized model to outputs/ssl_opt.onnx")