
Optimize RoPE arithmetic and simplify tensor reshaping#355

Open
Perseus14 wants to merge 1 commit into main from wan-opt

Conversation

@Perseus14
Collaborator

This PR introduces two focused optimizations that improve memory bandwidth utilization and code maintainability. It applies the Rotary Position Embedding (RoPE) using native real arithmetic, which allows better hardware fusion, and refactors a sequence of tensor collapses into a single, explicit reshape operation.

Changes:

  • Replaced the float32 upcasting and jax.lax.complex multiplication with explicit real-number arithmetic (computing the 2D rotation directly), keeping the tensors in their native dtype (e.g., bfloat16).

  • Replaced three sequential jax.lax.collapse operations with a single, explicit hidden_states.reshape(batch_size, -1, num_frames, height, width).
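The two changes above can be sketched as follows. This is a minimal illustration, not the PR's actual code: the helper names (`apply_rope_real`, `flatten_dims`) and the halved-feature pair layout are assumptions; the repository may instead use interleaved pairs. The key idea is that the complex product `(x0 + i*x1) * (cos + i*sin)` expands into a 2D rotation in real arithmetic, so no float32 upcast or `jax.lax.complex` call is needed and the tensors stay in their native dtype (e.g. bfloat16).

```python
import jax
import jax.numpy as jnp


def apply_rope_real(x, cos, sin):
    """Rotate feature pairs by the RoPE angles using real arithmetic.

    Sketch only: expands (x0 + i*x1) * (cos + i*sin) into its real and
    imaginary parts, keeping the computation in x's native dtype so XLA
    can fuse the multiplies/adds with surrounding element-wise ops.
    """
    # Assumed layout: the last axis is split into two halves (x0, x1).
    x0, x1 = jnp.split(x, 2, axis=-1)
    out0 = x0 * cos - x1 * sin   # real part of the rotation
    out1 = x0 * sin + x1 * cos   # imaginary part of the rotation
    return jnp.concatenate([out0, out1], axis=-1)


def flatten_dims(hidden_states, batch_size, num_frames, height, width):
    """Replace chained jax.lax.collapse calls with one explicit reshape.

    The -1 absorbs whatever intermediate dimensions sat between the batch
    and the (frames, height, width) axes, making the target shape explicit.
    """
    return hidden_states.reshape(batch_size, -1, num_frames, height, width)
```

The real-arithmetic form is exactly the per-pair 2D rotation matrix, so it is numerically equivalent to the complex path up to dtype precision, while avoiding the upcast-to-complex64 round trip.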

@Perseus14 Perseus14 requested a review from entrpn as a code owner March 12, 2026 08:19
