Finding a good balance
This commit is contained in:
+34
@@ -0,0 +1,34 @@
|
||||
# optimize_models.py
|
||||
from onnxruntime.transformers.optimizer import optimize_model
|
||||
from onnxruntime.transformers.fusion_options import FusionOptions
|
||||
|
||||
def optimize_custom(input_path, output_path):
|
||||
print(f"Optimizing {input_path}...")
|
||||
|
||||
# Load default BERT fusion options
|
||||
options = FusionOptions("bert")
|
||||
|
||||
# Disable LayerNorm fusions that break on AdaLN / dynamic biases
|
||||
options.enable_skip_layer_norm = False
|
||||
options.enable_layer_norm = False
|
||||
|
||||
# Run the optimizer
|
||||
optimizer = optimize_model(
|
||||
input=input_path,
|
||||
model_type="bert",
|
||||
optimization_options=options
|
||||
)
|
||||
|
||||
optimizer.save_model_to_file(output_path)
|
||||
print(f"Saved optimized model to {output_path}\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
optimize_custom("outputs/encode.onnx", "outputs/encode_opt.onnx")
|
||||
optimize_custom("outputs/decode.onnx", "outputs/decode_opt.onnx")
|
||||
|
||||
# ssl.onnx (WavLM) is a standard BERT architecture, so we can leave
|
||||
# all standard fusions enabled for maximum speed.
|
||||
print("Optimizing outputs/ssl.onnx...")
|
||||
ssl_opt = optimize_model("outputs/ssl.onnx", model_type="bert")
|
||||
ssl_opt.save_model_to_file("outputs/ssl_opt.onnx")
|
||||
print("Saved optimized model to outputs/ssl_opt.onnx")
|
||||
Reference in New Issue
Block a user