# optimize_models.py
import os

from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model
def optimize_custom(input_path, output_path, *, disable_layer_norm_fusion=True):
    """Optimize an ONNX model with ORT's BERT fusions, minus the LayerNorm ones.

    Runs ``onnxruntime.transformers.optimizer.optimize_model`` with
    ``model_type="bert"`` and writes the result to ``output_path``. By
    default the LayerNormalization and SkipLayerNormalization fusions are
    switched off, because they break on graphs that use AdaLN / dynamic
    biases.

    Args:
        input_path: Path of the source ``.onnx`` model (``str`` or any
            ``os.PathLike``, e.g. ``pathlib.Path``).
        output_path: Destination path for the optimized model (``str`` or
            any ``os.PathLike``).
        disable_layer_norm_fusion: Keyword-only. When True (the default,
            and the original behavior), the LayerNorm and SkipLayerNorm
            fusions are disabled. Pass False to keep every default BERT
            fusion enabled.

    Returns:
        The optimizer object produced by ``optimize_model``, after the model
        has been saved, so callers can inspect it further if they wish.
    """
    # The ORT optimizer helpers work on plain string paths: a non-str
    # ``input`` is treated as an in-memory ModelProto, and
    # save_model_to_file uses str methods on its path. Normalize
    # os.PathLike arguments (e.g. pathlib.Path) up front.
    input_path = os.fspath(input_path)
    output_path = os.fspath(output_path)

    print(f"Optimizing {input_path}...")

    # Start from the default BERT fusion options.
    options = FusionOptions("bert")

    if disable_layer_norm_fusion:
        # Disable LayerNorm fusions that break on AdaLN / dynamic biases.
        options.enable_skip_layer_norm = False
        options.enable_layer_norm = False

    # Run the optimizer.
    optimizer = optimize_model(
        input=input_path,
        model_type="bert",
        optimization_options=options,
    )

    optimizer.save_model_to_file(output_path)
    print(f"Saved optimized model to {output_path}\n")
    return optimizer
if __name__ == "__main__":
    # encode/decode go through optimize_custom, which turns off the
    # LayerNorm fusions (they break on AdaLN / dynamic biases).
    for model_name in ("encode", "decode"):
        optimize_custom(
            f"outputs/{model_name}.onnx",
            f"outputs/{model_name}_opt.onnx",
        )

    # ssl.onnx (WavLM) is a standard BERT architecture, so we can leave
    # all standard fusions enabled for maximum speed.
    print("Optimizing outputs/ssl.onnx...")
    ssl_opt = optimize_model("outputs/ssl.onnx", model_type="bert")
    ssl_opt.save_model_to_file("outputs/ssl_opt.onnx")
    print("Saved optimized model to outputs/ssl_opt.onnx")