
Optimize FSDP2 Pytest Timings (12 -> 2 mins)#2787

Open
vthumbe1503 wants to merge 22 commits into NVIDIA:main from vthumbe1503:fsdp_pytest_infra_change

Conversation

@vthumbe1503
Collaborator

@vthumbe1503 vthumbe1503 commented Mar 22, 2026

Description

Rearchitect the FSDP2 tests to share a single process and distributed process group via a nested pytest scheme: an outer pytest invokes torchrun, which starts an inner pytest session that runs multiple tests against a shared torch distributed group.
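The scheme above can be sketched as follows (a minimal sketch: `build_torchrun_cmd`, `is_success`, and `run_inner_session` are hypothetical helpers, not names from the PR; the real driver lives in test_torch_fsdp2.py and derives the process count from the machine):

```python
# Sketch of the outer pytest driver that launches one torchrun per suite,
# so all parametrized inner tests share a single NCCL process group.
import subprocess


def build_torchrun_cmd(test_module: str, nproc: int) -> list[str]:
    """One torchrun launch = one inner pytest session."""
    return [
        "torchrun",
        f"--nproc_per_node={nproc}",
        "-m",
        "pytest",
        "-v",
        test_module,
    ]


def is_success(returncode: int) -> bool:
    """pytest exit code 5 (no tests collected) also counts as success."""
    return returncode in (0, 5)


def run_inner_session(test_module: str, nproc: int) -> bool:
    result = subprocess.run(build_torchrun_cmd(test_module, nproc), timeout=600)
    return is_success(result.returncode)
```

The outer driver then needs only one such call per suite instead of one subprocess per parametrized test case.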

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ksivaman and others added 12 commits March 23, 2026 01:10
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* init

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* work finished

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint fixes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: root <pgadzinski@nvidia.com>

* removed warning.warn

Signed-off-by: root <pgadzinski@nvidia.com>

* [PyTorch] Remove dead None-check for num_out_tokens in moe_permute_mask_map_forward

num_out_tokens is typed as int in the custom_op signature and can never
be None; the check was incorrectly carried over from the class-based
upstream version during merge conflict resolution.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: root <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…are detached (NVIDIA#2772)

[PyTorch] Change the restore tensor API to ensure tensors are detached from ctx

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ges it (NVIDIA#2781)

Install pytest in onnx L1 test as Pyt container no longer packages it

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…_descriptors (NVIDIA#2782)

* Fix zero-sized groups in update_tma_descriptors

Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Update test_cast_mxfp8_grouped.cu

Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
This reverts commit e355c38.

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 force-pushed the fsdp_pytest_infra_change branch 2 times, most recently from 9e91caf to 4f0f7f9 Compare March 23, 2026 01:14
@vthumbe1503 vthumbe1503 marked this pull request as ready for review March 23, 2026 01:33
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@greptile-apps
Contributor

greptile-apps bot commented Mar 23, 2026

Greptile Summary

This PR rearchitects the FSDP2 pytest suite so the expensive distributed process-group setup happens once per inner pytest session (via a session-scoped dist_init fixture in conftest.py) rather than once per test (previously each test launched its own torchrun subprocess). The outer pytest in test_torch_fsdp2.py now issues just two subprocess.run calls — one torchrun -m pytest invocation for model tests and one for FusedAdam tests — instead of one subprocess per test-case combination. The reported speedup (12 → 2 min) reflects eliminating the NCCL communicator init overhead multiplied across all parametrized recipes.

Key changes:

  • conftest.py: new session-scoped dist_init + per-test _cleanup (with dist.barrier()) shared across all inner tests.
  • fsdp2_utils.py: shared helper functions extracted from both test modules to avoid the conftest double-import problem flagged in the previous review.
  • run_fsdp2_model.py / run_fsdp2_fused_adam.py: _setup() / dist.init_process_group() calls removed from individual tests; tests now receive distributed state via fixtures; checkpoint/save paths include recipe_name to prevent collisions between parametrized runs.
  • test_torch_fsdp2.py: all per-recipe outer test functions collapsed into two umbrella tests; timeout=600 and explicit returncode in (0, 5) check added (previously flagged timeout absence is resolved).
  • All concerns from prior review threads are addressed (barrier before cleanup, sys.path double-import, subprocess timeout, set_device before get_device_capability).

Confidence Score: 4/5

  • PR is safe to merge; all previously flagged concerns are resolved and remaining items are non-blocking style suggestions.
  • All four concerns from prior review threads (dist.barrier() before cleanup, conftest double-import, subprocess timeout, device capability vs. set_device ordering) are correctly addressed. The new architecture is sound: single NCCL group init per inner pytest session, per-test barrier-guarded cleanup, and recipe-namespaced checkpoint paths. Remaining P2s (--local-ranks-filter=0 hides non-rank-0 diagnostics, missing degenerate-mesh guard in test_distributed, xfail comment clarity) are style-level and don't affect correctness or reliability.
  • tests/pytorch/distributed/test_torch_fsdp2.py (--local-ranks-filter=0 diagnostic visibility) and tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py (degenerate sharding_dims guard).

Important Files Changed

Filename Overview
tests/pytorch/distributed/test_torch_fsdp2.py Outer pytest driver simplified to two subprocess wrappers — correctly accepts exit code 5 and has timeout=600; output from non-rank-0 processes is silenced via --local-ranks-filter=0.
tests/pytorch/distributed/fsdp2_tests/conftest.py Session-scoped dist_init, per-test _cleanup with dist.barrier(), and recipe_name parametrization. Module-level set_device correctly addresses the previous concern about get_device_capability() being called before dist_init.
tests/pytorch/distributed/fsdp2_tests/fsdp2_utils.py Clean extraction of shared utilities (get_recipe_from_string, save/restore_custom_attrs) into a standalone module, resolving the previous conftest double-import concern.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py _setup() removed; all tests now receive recipe_name fixture; checkpoint paths include recipe_name to avoid conflicts across parametrized runs; test_fuse_wgrad_accumulation now uses in-process xfail rather than subprocess CalledProcessError.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py _train() split into _run_training() (shared session) and _train() (standalone); NUM_PROCS-based parametrize for sharding_dims is evaluated at torchrun import time; acknowledged degenerate HSDP (2,1) case on 2-GPU machines.

Sequence Diagram

```mermaid
sequenceDiagram
    participant CI as CI / outer pytest
    participant TP as test_torch_fsdp2.py
    participant TR as torchrun (nproc workers)
    participant IP as inner pytest session
    participant CF as conftest.py fixtures
    participant TF as run_fsdp2_model.py / run_fsdp2_fused_adam.py

    CI->>TP: pytest test_fsdp2_model_tests (or test_fsdp2_fused_adam_tests)
    TP->>TR: subprocess.run(["torchrun", "-m", "pytest", ...], timeout=600)
    TR->>IP: spawn N worker processes, each running pytest session
    IP->>CF: dist_init fixture (session-scoped, autouse)\n init_process_group once
    loop For each parametrized test (recipe × fp8_init × sharding × layer)
        IP->>TF: call test_distributed(recipe_name, ...)
        TF->>TF: _run_training(args) — uses shared PG
        TF-->>IP: pass / fail / xfail
        IP->>CF: _cleanup fixture\n dist.barrier() + gc.collect() + empty_cache()
    end
    IP->>CF: dist_init teardown — destroy_process_group()
    IP-->>TR: exit code (0=pass, 1=fail, 5=no tests)
    TR-->>TP: exit code (max over all workers)
    TP->>TP: assert returncode in (0, 5)
    TP-->>CI: PASS / FAIL
```


Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

pre-commit-ci bot and others added 4 commits March 23, 2026 05:31
Comment on lines +362 to +363
@pytest.mark.parametrize("sharding_dims", [[NUM_PROCS], [2, NUM_PROCS // 2]])
@pytest.mark.parametrize("fp8_init", [False, True])
Contributor

P2 HSDP parametrize produces a degenerate (2, 1) mesh when WORLD_SIZE=2

NUM_PROCS is evaluated at module load time from os.environ.get("WORLD_SIZE", "1"). When running under the outer test which uses nproc_per_node=NUM_PROCS (e.g., NUM_PROCS=2), the second parametrize value becomes [2, 2//2] = [2, 1].

A (2, 1) device mesh creates 2 replicas with 1 GPU per shard group, which means no tensor sharding actually occurs along the "shard" dimension. The FSDP2 HSDP path is exercised with a trivial sharding configuration that behaves identically to DDP, giving no additional sharding coverage over the first [NUM_PROCS] case.

This only tests actual hybrid sharding when NUM_PROCS >= 4. Since the outer test_fsdp2_model_tests only requires NUM_PROCS % 2 == 0, a 2-GPU runner always gets a degenerate HSDP case.
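The mesh arithmetic behind this observation can be checked in isolation (`sharding_dims_cases` is a hypothetical helper mirroring the parametrize expression `[[NUM_PROCS], [2, NUM_PROCS // 2]]`):

```python
# Standalone check of the parametrize arithmetic discussed above.
def sharding_dims_cases(num_procs: int) -> list[list[int]]:
    return [[num_procs], [2, num_procs // 2]]


# On a 2-GPU runner the HSDP case collapses to a 1-wide shard group:
assert sharding_dims_cases(2) == [[2], [2, 1]]
# Genuine hybrid sharding only appears with 4+ GPUs:
assert sharding_dims_cases(4) == [[4], [2, 2]]
```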

Collaborator Author

That doesn't stop us from testing it on machines with >= 4 GPUs.

Collaborator Author

And calling it degenerate doesn't make sense to me.

Contributor

You're right - "degenerate" was the wrong term. The [2, 1] mesh is a valid HSDP configuration that exercises the HSDP code path, and on 4+ GPU machines it creates proper hybrid sharding like [2, 2].

My concern was narrower: on 2-GPU CI runners, both parametrize cases ([2] and [2, 1]) produce equivalent behavior since the sharding dimension is 1 in the second case. But that's a test coverage observation for specific hardware, not a correctness issue with the code.

The implementation is fine - it tests valid configurations and scales properly to larger GPU counts.

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…nsformerEngine into fsdp_pytest_infra_change
[
"torchrun",
f"--nproc_per_node={NUM_PROCS}",
"--local-ranks-filter=0",
Contributor

P2 --local-ranks-filter=0 silences non-rank-0 failure diagnostics

--local-ranks-filter=0 (a torchrun output-filtering flag) suppresses stdout/stderr from every rank except rank 0. A torchrun process still exits non-zero if any worker exits non-zero, so the overall pass/fail signal reaches the outer assert correctly. However, when a test fails on rank 1 (e.g., a non-collective assertion like assert losses[-1] < losses[0]), the traceback and test name from rank 1 are invisible. CI will report the outer test as failed, but the specific failing test case (recipe + parametrize combination) on non-rank-0 processes won't appear in the logs.

Consider either dropping --local-ranks-filter=0 entirely, or adding -p no:terminal on ranks > 0 via PYTEST_ADDOPTS so the tracebacks are still emitted from every rank.

The same applies to test_fsdp2_fused_adam_tests on line 51.
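One way to realize this suggestion is per-rank PYTEST_ADDOPTS, sketched below (an assumption-laden sketch: quiet output with short tracebacks on non-zero ranks is one possible middle ground, and the exact flag choices are illustrative, not taken from the PR):

```python
# Sketch: keep full reporting on rank 0 but still surface failure
# tracebacks from every other rank, instead of silencing them entirely.
import os


def pytest_addopts_for_rank(local_rank: int) -> str:
    if local_rank == 0:
        return "-v"  # full verbose report on rank 0
    return "-q --tb=short"  # failures stay visible, routine output is quiet


os.environ["PYTEST_ADDOPTS"] = pytest_addopts_for_rank(
    int(os.environ.get("LOCAL_RANK", "0"))
)
```

This would replace `--local-ranks-filter=0` rather than combine with it, since the torchrun flag drops the other ranks' output before pytest's options can matter.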

Comment on lines 398 to 451
@@ -400,8 +417,8 @@ def test_fuse_wgrad_accumulation(recipe=None):
writes the gradient directly into main_grad and returns None to autograd,
bypassing FSDP2's reduce-scatter.
"""
world_size, _, device = _setup()

recipe = get_recipe_from_string(recipe_name)
world_size, device = _get_dist_info()
model = _build_model(fp8_init=True, fuse_wgrad_accumulation=True, recipe=recipe)

# Allocate main_grad buffers on the DTensor params
@@ -433,10 +450,8 @@ def test_fuse_wgrad_accumulation(recipe=None):
loss = F.mse_loss(output, target)
loss.backward() # Expected to raise AttributeError
Contributor

P2 xfail in shared NCCL session — FSDP2 backward hooks not flushed

When loss.backward() raises AttributeError, FSDP2's post-backward hooks (reduce-scatter) never fire. Parameters remain in the all-gathered state at the time of the exception. In the old subprocess-based design the whole NCCL communicator was torn down on process exit, providing a clean slate. In the new shared-session design, _cleanup calls gc.collect() and torch.cuda.empty_cache() to release those tensors.

This is safe in practice because:

  • All ranks execute the same Python code and hit the same AttributeError at the same autograd node, so no partial NCCL collective is ever initiated.
  • FSDP2's reshard_after_forward=True (default) already reshards parameters after the forward pass completes, so no all-gathered buffers are held when the backward starts.

However, the fact that the test intentionally leaves FSDP2 backward hooks un-fired is non-obvious. A brief code comment explaining why this is safe would improve maintainability, e.g.:

# NOTE: loss.backward() is expected to raise AttributeError before any
# FSDP2 reduce-scatter hook fires. Because all ranks hit the same
# exception at the same autograd node, no NCCL collective is left
# partially initiated. _cleanup's dist.barrier() will succeed normally.
loss.backward()  # Expected to raise AttributeError

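The in-process expected-failure pattern this thread refers to can be illustrated with a stand-in (`run_backward_with_fused_wgrad` is hypothetical, standing in for the real `loss.backward()` call; `pytest.raises` is used here in place of the test's actual xfail mechanics):

```python
# Stand-in for expecting the AttributeError in-process rather than via a
# subprocess CalledProcessError, as the shared-session design requires.
import pytest


def run_backward_with_fused_wgrad():
    # In the real test this is loss.backward(), which raises before any
    # FSDP2 reduce-scatter hook fires.
    raise AttributeError("main_grad fusion bypasses FSDP2 reduce-scatter")


def test_fuse_wgrad_accumulation_raises():
    # All ranks raise at the same autograd node, so no NCCL collective is
    # left partially initiated and the per-test barrier still succeeds.
    with pytest.raises(AttributeError):
        run_backward_with_fused_wgrad()
```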

Comment on lines +359 to +362
NUM_PROCS = int(os.environ.get("WORLD_SIZE", "1"))


@pytest.mark.parametrize("sharding_dims", [[NUM_PROCS], [2, NUM_PROCS // 2]])
Contributor

P2 NUM_PROCS == 0 bypasses the even-number skip guard

NUM_PROCS = int(os.environ.get("WORLD_SIZE", "1")) is evaluated at import time. When the outer test_fsdp2_model_tests launches torchrun, WORLD_SIZE is set by torchrun, so NUM_PROCS is always ≥ 1 in practice. However, the second parametrize entry [2, NUM_PROCS // 2] produces [2, 0] in the single-GPU edge case (NUM_PROCS == 1). While the outer skip guard (NUM_PROCS % 2 != 0) would prevent running with 1 GPU, there is no guard inside the inner pytest session itself.

Adding a skip mark or an explicit pytest.skip at the top of test_distributed for the degenerate [2, 0] case would make this self-documenting:

@pytest.mark.parametrize("sharding_dims", [[NUM_PROCS], [2, NUM_PROCS // 2]])
...
def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type):
    if 0 in sharding_dims:
        pytest.skip("Degenerate mesh dimension (0); requires NUM_PROCS >= 2")

@vthumbe1503 vthumbe1503 requested a review from ksivaman March 23, 2026 21:05