[ONNX] Adding a pass to replace interpolate function with aten::__interpolate #35744
Conversation
…oof/interpolate # Conflicts: # torch/onnx/symbolic_opset12.py
💊 Build failures summary and remediations — as of commit 9df600a (more details on the Dr. CI page):
🕵️ 2 new failures recognized by patterns. The following build failures do not appear to be due to upstream breakages (reran 1 job to discount flakiness): pytorch_xla_linux_xenial_py3_6_clang7_test (1/2), step "Test" (confirmed not flaky by 2 failures).
cc @houseroad @eellison and @fmassa for review.
Generally looks good, a few comments
torch/csrc/jit/api/function_impl.cpp
Outdated
@@ -64,6 +65,7 @@ const c10::FunctionSchema& GraphFunction::getSchema() const {
 void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
   // TODO: Invoke cleanup passes before and after inlining to reduce amount of
   // code we're copying.
+  PreInlineONNX(*graph);
This shouldn't run every time we inline a graph; it should be an explicitly invoked pass that ONNX runs before its own inlining.
ONNX calls the tracer, which internally inlines functions. I'm trying to find the best place to put the pre-inline pass, maybe somewhere in the tracer.
When we looked at the example together, it seemed like the tracer doesn't eagerly inline, so running the pass outside of the tracer should be possible:
@torch.jit.script
def fn(x):
    return x + 2

@torch.jit.script
def fn2(x):
    return fn(x) + 3

def fn3(x):
    return x + fn2(x)

traced = torch.jit.trace(fn3, (torch.rand(3, 4),))
print(traced.graph)

graph(%x : Double(3, 4)):
  %1 : Function = prim::Constant[name="fn2"]()
  %2 : Tensor = prim::CallFunction(%1, %x)
  %3 : int = prim::Constant[value=1]() # test/test_jit.py:16104:0
  %4 : Double(3, 4) = aten::add(%x, %2, %3) # test/test_jit.py:16104:0
  return (%4)
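The trace output above still contains a `prim::CallFunction` node because the tracer doesn't eagerly inline; an explicit inline pass would splice the body of `fn2` in its place. As a toy sketch of that effect (plain Python with a made-up node representation, not the actual JIT IR or inline pass):

```python
# Toy model of graph inlining (illustrative only; not the real JIT IR).
# A "graph" is a list of nodes; a ("call", name) node references another graph.

def inline(graph, functions):
    """Replace each ("call", name) node with the callee's nodes, recursively."""
    out = []
    for node in graph:
        if node[0] == "call":
            # Splice the callee's (recursively inlined) body in place of the call.
            out.extend(inline(functions[node[1]], functions))
        else:
            out.append(node)
    return out

functions = {
    "fn":  [("add", 2)],                 # fn(x):  x + 2
    "fn2": [("call", "fn"), ("add", 3)], # fn2(x): fn(x) + 3
}
fn3 = [("call", "fn2"), ("add", "x")]    # fn3(x): x + fn2(x)

print(inline(fn3, functions))  # [('add', 2), ('add', 3), ('add', 'x')]
```

After inlining there are no call nodes left, which is the state the ONNX exporter wants to operate on.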
Actually, putting this pre-inline pass before calling inline does not fix the graph.
So the problem is with function->optimized_graph().
This optimized graph is inlined within the tracer, and it does not get updated when function->graph is updated.
I may be able to add an API to update the optimized graph.
Here it is:
std::shared_ptr<Graph> optimized_graph() const override {
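The staleness being described can be modeled in a few lines of plain Python (hypothetical names; not the actual GraphFunction API): once the optimized copy is computed and cached, later edits to the source graph are not reflected in it.

```python
# Toy model of a cached "optimized graph" going stale (illustrative only).

class Function:
    def __init__(self, graph):
        self.graph = graph       # the editable source graph
        self._optimized = None   # lazily computed, then cached

    def optimized_graph(self):
        # "Optimization" is just a copy here; the point is the caching.
        if self._optimized is None:
            self._optimized = list(self.graph)
        return self._optimized

f = Function(["interpolate_call"])
f.optimized_graph()                 # cache is populated by the tracer

f.graph[0] = "aten::__interpolate"  # a later pass rewrites the source graph
assert f.optimized_graph() == ["interpolate_call"]  # cache is stale
```

This mirrors why rewriting function->graph alone is not enough when downstream code consumes the cached optimized graph.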
What doesn't work if you put the pass here?
https://github.com/pytorch/pytorch/blob/master/torch/onnx/utils.py#L115
Because this pass modifies the function's graph, not its optimized_graph, and the latter is what the downstream code uses. Let me know if it's easier to have a quick call about this.
…oof/interpolate
…nto neraoof/interpolate
A few more comments.
  }
} break;
case prim::CallMethod: {
  const std::string& name = cur->s(attr::name);
Can you add a test for the new code path?
I'm adding a test in onnxruntime backend tests to cover this case (interpolate call within submodule). Let me know if you think of a more generalized way of testing this.
Looks great! Accepting, but I believe you have one commented-out line in test_pytorch_onnx_onnxruntime.py to remove before we can land this.
} else {
  cur->removeInput(0);
  functionCallSubstitution(fun_type->function()->graph()->block());
  GRAPH_UPDATE(
I think the graph-update logging should be in the other branch, no? (We log when we change something.)
Yeah this makes sense, will update it.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Test is failing.
@houseroad Test failure is fixed. I had to use a test case which is supported in opset 9 and higher, and set the flag to not run the test for lower opsets.
@houseroad @eellison Could you please import this? ONNX tests are passing now. Thanks.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Could you link the PR that removes aten::__interpolate (maybe in the description), so that it is easier to follow along with your PR. Thanks!
@houseroad @eellison Can we merge the PR? Thanks.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@houseroad merged this pull request in f99a28f.
Since aten::__interpolate is removed in #34514, we need a pass that replaces the interpolate function with aten::__interpolate for ONNX export.
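As a rough sketch of what such a substitution pass does (toy node representation in plain Python; the dict fields and helper name are illustrative, not the real JIT IR or the actual pass): function-call nodes that resolve to interpolate are rewritten into aten::__interpolate nodes so the ONNX exporter can match them.

```python
# Toy sketch of the substitution pass (illustrative only; not the real JIT pass).

def function_call_substitution(graph):
    """Rewrite call nodes targeting 'interpolate' into aten::__interpolate nodes."""
    out = []
    for node in graph:
        if node["kind"] == "prim::CallFunction" and node.get("name") == "interpolate":
            # Replace the function call with a dedicated interpolate op,
            # keeping the call's inputs.
            out.append({"kind": "aten::__interpolate", "inputs": node["inputs"]})
        else:
            out.append(node)
    return out

graph = [
    {"kind": "prim::CallFunction", "name": "interpolate", "inputs": ["%x"]},
    {"kind": "aten::add", "inputs": ["%x", "%1"]},
]
print([n["kind"] for n in function_call_substitution(graph)])
# ['aten::__interpolate', 'aten::add']
```

The real pass additionally recurses into submodule blocks (the prim::CallMethod case discussed above), which is what the onnxruntime submodule test covers.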