[ONNX] Added support for constant folding onnx::Add and onnx::Sub #35869
Conversation
💊 CircleCI build failures summary and remediations

As of commit 8afbc09 (more details on the Dr. CI page): ✅ None of the build failures appear to be your fault 💚

🚧 2 upstream failures: these were probably caused by upstream breakages.
LGTM
CI failures are unrelated
@pytorchbot retest this please
Just some small comments
test/onnx/test_utility_funs.py
Outdated
operator_export_type=OperatorExportTypes.ONNX)
for node in graph.nodes():
    assert node.kind() != "onnx::Add"
assert len(list(graph.nodes())) == 1
Should we also add a value check to verify the correctness of the constant folding behavior?
Done
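For reference, the requested value check could look roughly like this. This is a hedged, PyTorch-free sketch: `a`, `b`, and `folded_node_value` are hypothetical stand-ins; in the real test the folded value would be read back from the `onnx::Constant` node produced by the exporter.

```python
import math

# Hypothetical stand-ins for the two constant inputs folded at export time.
a, b = 1.5, 2.5

# What a correct constant folder should store on the resulting onnx::Constant
# node (in the real test this value is read back from the folded graph).
folded_node_value = a + b

# Value check: the folded constant must match the eagerly computed reference,
# in addition to the structural check that the onnx::Add node is gone.
assert math.isclose(folded_node_value, a + b)
```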
test/onnx/test_utility_funs.py
Outdated
operator_export_type=OperatorExportTypes.ONNX)
for node in graph.nodes():
    assert node.kind() != "onnx::Sub"
assert len(list(graph.nodes())) == 1
ditto
Done
@@ -233,6 +233,12 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
} else if (node->kind() == onnx::Mul) {
  updated_val = at::mul(inputTensorValues[0], inputTensorValues[1]);
  return c10::optional<at::Tensor>(updated_val);
} else if (node->kind() == onnx::Sub) {
A potentially more generic approach is to abstract Mul/Div/Sub/Add as a single binary op; the pseudocode might look like the following:
if (node->kind() == onnx::Div || node->kind() == onnx::Mul /* ... */) {
  kind2func k2f = { {onnx::Mul, at::mul}, /* ... */ };
  function f = k2f[node->kind()];
  return c10::optional<at::Tensor>(f(inputTensorValues[0], inputTensorValues[1]));
}
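The reviewer's table-driven idea can be sketched in plain Python (a hypothetical illustration, not the actual exporter code; the op names and `fold_binary` helper are made up for this sketch):

```python
import operator

# Hypothetical dispatch table mirroring the reviewer's k2f idea:
# map each ONNX op kind to a two-argument function.
BINARY_OPS = {
    "onnx::Add": operator.add,
    "onnx::Sub": operator.sub,
    "onnx::Mul": operator.mul,
    "onnx::Div": operator.truediv,
}

def fold_binary(kind, lhs, rhs):
    """Return the folded value for a supported binary op, else None."""
    fn = BINARY_OPS.get(kind)
    return fn(lhs, rhs) if fn is not None else None

print(fold_binary("onnx::Sub", 7, 3))  # -> 4
```

One lookup then replaces the chain of `else if` branches, at the cost of requiring every entry in the table to share the same call shape.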
After discussion with Tianyou, I take back my comment, since aten::add/aten::sub take a different number of arguments than aten::div/aten::mul (3 vs. 2). Although we could use bind tricks to unify the interface, that is a bit of an overkill. I think Tianyou's current implementation is good enough.
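The argument-count mismatch and the "bind tricks" mentioned above can be illustrated in plain Python. The `aten_add`/`aten_mul` functions below are hypothetical stand-ins for the aten signatures, not the real bindings: add/sub carry an extra `alpha` scale factor, while mul/div take only two operands, so a uniform table requires binding `alpha` first.

```python
from functools import partial

# Hypothetical stand-ins for the aten signatures (not the real C++ bindings):
# add takes an extra `alpha` scale factor, mul takes only two operands.
def aten_add(a, b, alpha):
    return a + alpha * b

def aten_mul(a, b):
    return a * b

# Binding alpha=1 up front ("bind tricks") gives every op the same
# two-argument shape, so all entries fit in one dispatch table.
uniform_ops = {
    "onnx::Add": partial(aten_add, alpha=1),
    "onnx::Mul": aten_mul,
}

print(uniform_ops["onnx::Add"](2, 3))  # -> 5
```

This works, but as noted it buys little over a couple of explicit `else if` branches.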
Overall looks good. Could you rebase to master and let the ONNX related CI run?
Force-pushed from 2a2363e to 8afbc09
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Done
@houseroad merged this pull request in 8dba98d.
…torch#35869) Summary: Added support for constant folding onnx::Add and onnx::Sub Pull Request resolved: pytorch#35869 Reviewed By: hl475 Differential Revision: D20865640 Pulled By: houseroad fbshipit-source-id: 2b8c1cc196959b5b5b9ce018dbdcb74d59a92d9f