
[ONNX] Adding a pass to replace interpolate function with aten::__interpolate #35744

Closed · wants to merge 53 commits

Conversation

@neginraoof (Contributor) commented Mar 31, 2020

Since aten::__interpolate was removed in #34514, we need a pass that replaces calls to the interpolate function with aten::__interpolate for ONNX export.
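
For context, a minimal sketch of the export scenario this pass targets (the model, input sizes, and opset version below are illustrative, not taken from this PR):

import torch
import torch.nn.functional as F

class UpsampleModel(torch.nn.Module):
    def forward(self, x):
        # After #34514, interpolate traces as a scripted function call rather
        # than aten::__interpolate, so ONNX export needs a substitution pass.
        return F.interpolate(x, scale_factor=2.0, mode="nearest")

# Illustrative export call; the substitution pass runs internally during export.
torch.onnx.export(UpsampleModel(), torch.randn(1, 3, 8, 8),
                  "upsample.onnx", opset_version=11)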

@facebook-github-bot added the oncall: jit label (Add this issue/PR to JIT oncall triage queue) on Mar 31, 2020
@dr-ci (bot) commented Mar 31, 2020

💊 Build failures summary and remediations

As of commit 9df600a (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakages (reran 1 job to discount flakiness):

See CircleCI build pytorch_xla_linux_xenial_py3_6_clang7_test (1/2)

Step: "Test" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 2 failures>

Apr 13 22:31:42 ERROR [0.001s]: TestViewOpsXLA (unittest.loader._FailedTest)
Apr 13 22:31:39 + run_dynamic python3 /var/lib/jenkins/workspace/xla/test/../../test/test_torch.py -v TestViewOpsXLA 
Apr 13 22:31:39 + XLA_EXPERIMENTAL=nonzero:masked_select 
Apr 13 22:31:39 + python3 /var/lib/jenkins/workspace/xla/test/../../test/test_torch.py -v TestViewOpsXLA 
Apr 13 22:31:42 Test results will be stored in test-reports/python-unittest 
Apr 13 22:31:42  
Apr 13 22:31:42 Running tests... 
Apr 13 22:31:42 ---------------------------------------------------------------------- 
Apr 13 22:31:42   TestViewOpsXLA (unittest.loader._FailedTest) ... ERROR (0.001s) 
Apr 13 22:31:42  
Apr 13 22:31:42 ====================================================================== 
Apr 13 22:31:42 ERROR [0.001s]: TestViewOpsXLA (unittest.loader._FailedTest) 
Apr 13 22:31:42 ---------------------------------------------------------------------- 
Apr 13 22:31:42 AttributeError: module '__main__' has no attribute 'TestViewOpsXLA' 
Apr 13 22:31:42  
Apr 13 22:31:42 ---------------------------------------------------------------------- 
Apr 13 22:31:42 Ran 1 test in 0.001s 
Apr 13 22:31:42  
Apr 13 22:31:42 FAILED (errors=1) 
Apr 13 22:31:42  
Apr 13 22:31:42 Generating XML reports... 
Apr 13 22:31:42 Generated XML report: test-reports/python-unittest/TEST-unittest.loader._FailedTest-20200413223142.xml 

See CircleCI build pytorch_ios_11_2_1_x86_64_build (2/2)

Step: "Checkout code" (full log | pattern match details | 🔁 rerun) <confirmed not flaky by 4 failures>

fatal: the remote end hung up unexpectedly
Cloning into '.'... 
Warning: Permanently added the RSA host key for IP address '140.82.112.3' to the list of known hosts.  
remote: Enumerating objects: 377185         remote: Enumerating objects: 280, done.         
remote: Counting objects:  98% (275/280)         remote: Counting objects:  99% (278/280)         remote: Counting objects: 100% (280/280)         remote: Counting objects: 100% (280/280), done.         
remote: Compressing objects:  98% (208/212)         remote: Compressing objects:  99% (210/212)         remote: Compressing objects: 100% (212/212)         remote: Compressing objects: 100% (212/212), done.         
Receiving objects:   1% (3775/377465) Receiving objects:   2% (7550/377465) packet_write_wait: Connection to 140.82.112.3 port 22: Broken pipe  
fatal: the remote end hung up unexpectedly 
fatal: early EOF 
fatal: index-pack failed 


@neginraoof (Author):

cc @houseroad @eellison and @fmassa for review.
Thanks.

@vincentqb added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Mar 31, 2020
@eellison eellison self-requested a review March 31, 2020 20:30
@eellison (Contributor) left a comment

Generally looks good, a few comments

@@ -64,6 +65,7 @@ const c10::FunctionSchema& GraphFunction::getSchema() const {
void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
// TODO: Invoke cleanup passes before and after inlining to reduce amount of
// code we're copying.
PreInlineONNX(*graph);
Contributor:

This shouldn't run every time we run inlining on a graph; it should only be an explicitly invoked pass that ONNX runs before its own inlining.

Contributor (Author):

ONNX calls the tracer, which internally inlines functions. I'm trying to see where the best place is to put the pre-inline pass. Maybe somewhere in the tracer.

Contributor:

When we looked at the example together, it seemed like the tracer doesn't eagerly inline, so it should be possible to do this outside of the tracer.

@torch.jit.script
def fn(x):
    return x + 2
@torch.jit.script
def fn2(x):
    return fn(x) + 3
def fn3(x):
    return x + fn2(x)
traced = torch.jit.trace(fn3, (torch.rand(3, 4),))
print(traced.graph)

graph(%x : Double(3, 4)):
  %1 : Function = prim::Constant[name="fn2"]()
  %2 : Tensor = prim::CallFunction(%1, %x)
  %3 : int = prim::Constant[value=1]() # test/test_jit.py:16104:0
  %4 : Double(3, 4) = aten::add(%x, %2, %3) # test/test_jit.py:16104:0
  return (%4)


Contributor (Author):

Actually, putting this pre-inline pass before calling inline does not fix the graph. The problem is with function->optimized_graph: the optimized graph is already inlined within the tracer, and it does not get updated when function->graph is updated. I may be able to add an API to update the optimized graph.

Here it is:

std::shared_ptr<Graph> optimized_graph() const override {

@neginraoof (Author), Apr 2, 2020:

Because this pass modifies the function graph, not the function's optimized_graph, and the latter is what the downstream code uses. Let me know if it's easier to have a quick call about this.

torch/csrc/jit/passes/onnx/preinline_onnx.cpp (two outdated review threads, resolved)
@eellison (Contributor) left a comment

A few more comments.

torch/csrc/jit/passes/onnx/preclude_inlining.cpp (three outdated review threads, resolved)
}
} break;
case prim::CallMethod: {
const std::string& name = cur->s(attr::name);
Contributor:

Can you add a test for the new code path?

Contributor (Author):

I'm adding a test to the onnxruntime backend tests to cover this case (an interpolate call within a submodule). Let me know if you think of a more generalized way of testing this.
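
For reference, a rough sketch of what such a test could look like (module names and shapes are illustrative, and the exact test in this PR may differ; run_test is the comparison helper used by test_pytorch_onnx_onnxruntime.py):

import torch
import torch.nn.functional as F

class InnerUpsample(torch.nn.Module):
    def forward(self, x):
        return F.interpolate(x, scale_factor=2.0, mode="nearest")

class OuterModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The interpolate call sits inside a submodule, exercising the
        # prim::CallMethod path of the substitution pass.
        self.inner = InnerUpsample()

    def forward(self, x):
        return self.inner(x) + 1

# Inside the ONNX Runtime test class, something like:
#     self.run_test(OuterModel(), torch.randn(1, 3, 8, 8))
# which exports the model and compares ONNX Runtime output to eager PyTorch.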

@eellison (Contributor) left a comment

Looks great!!! Accepting, but I believe you have one commented-out line in test_pytorch_onnx_onnxruntime.py to remove before we can land this.

} else {
cur->removeInput(0);
functionCallSubstitution(fun_type->function()->graph()->block());
GRAPH_UPDATE(
Contributor:

I think the graph-update logging should be in the other branch, no? (We log when we change something.)

@neginraoof (Author), Apr 9, 2020:

Yeah this makes sense, will update it.

torch/csrc/jit/passes/onnx/function_substitution.cpp (outdated review thread, resolved)
test/onnx/test_pytorch_onnx_onnxruntime.py (outdated review thread, resolved)
torch/csrc/jit/passes/onnx/function_substitution.cpp (outdated review thread, resolved)
@facebook-github-bot left a comment

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@houseroad (Member):

test is failing.

@neginraoof (Author):

@houseroad The test failure is fixed. I had to use a test case that is supported in opset 9 and higher, and set the flag so that the test does not run for lower opsets.
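
A hedged sketch of how such an opset gate might look in the test file (the decorator name, its import location, and the test body are my assumptions based on the ONNX test utilities, not taken verbatim from this PR):

import torch
import torch.nn.functional as F
# Assumed helper; a similar decorator lives in test/onnx/test_pytorch_common.py.
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion

class TestONNXRuntime:  # stand-in for the real test class
    @skipIfUnsupportedMinOpsetVersion(9)  # skip this test for opsets below 9
    def test_interpolate_in_submodule(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return F.interpolate(x, scale_factor=2.0, mode="nearest")

        # run_test exports to ONNX and checks parity with ONNX Runtime.
        self.run_test(Model(), torch.randn(1, 3, 8, 8))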

@neginraoof (Author) commented Apr 10, 2020

@houseroad @eellison Could you please import this? ONNX tests are passing now. Thanks.

@facebook-github-bot left a comment

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@KsenijaS (Contributor):

Could you link the PR that removes aten::__interpolate (maybe in the description), so that it is easier to follow along with your PR? Thanks!

@neginraoof (Author):

@houseroad @eellison Can we merge the PR? Thanks

@facebook-github-bot left a comment

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot:

@houseroad merged this pull request in f99a28f.

Labels: Merged · oncall: jit (Add this issue/PR to JIT oncall triage queue) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)