From d6e80befe75723620f20ffd4e33c805ad625f03a Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 19 Mar 2026 12:16:43 +0000 Subject: [PATCH 1/2] compiler: Tweak use of degenerating_indices when determining iteration directions --- devito/ir/clusters/algorithms.py | 9 ++++- devito/ir/support/basic.py | 68 ++++++++++++++++++++++++-------- tests/test_dimension.py | 21 ++++++++++ tests/test_mpi.py | 11 +++++- 4 files changed, 90 insertions(+), 19 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index ab688e3c81..5159ea3020 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -170,9 +170,16 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): candidates | known_break) # Compute iteration direction - idir = {d: Backward for d in candidates if d.root in scope.d_anti.cause} + # When checking for iteration direction, the user may have specified an LHS + # preceding the RHS, implying backward iteration, even if there is no strict + # reason that this iteration would need to run backward. Check if there is a + # user-specified backward iteration before defaulting to forward to avoid a + # gotcha by using the logical d_anti here. + idir = {d: Backward for d in candidates + if d.root in scope.d_anti_logical.cause_logical} if maybe_break: idir.update({d: Forward for d in candidates if d.root in scope.d_flow.cause}) + # Default to forward for remaining dimensions idir.update({d: Forward for d in candidates if d not in idir}) # Enforce iteration direction on each Cluster diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 8357b1b05f..b27aeab213 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -320,7 +320,7 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp - def distance(self, other): + def distance(self, other, logical=False): """ Compute the distance from ``self`` to ``other``. @@ -328,6 +328,9 @@ def distance(self, other): ---------- other : TimedAccess The TimedAccess w.r.t. which the distance is computed. + logical : bool + Compute a logical distance rather than true distance (i.e. ignoring + degenerating indices created by size 1 buffers etc). """ if isinstance(self.access, ComponentAccess) and \ isinstance(other.access, ComponentAccess) and \ @@ -392,7 +395,7 @@ def distance(self, other): # objects falls back to zero, as any other value would be # nonsensical ret.append(S.Zero) - elif degenerating_indices(self[n], other[n], self.function): + elif degenerating_indices(self[n], other[n], self.function, logical=logical): # Special case: `sai` and `oai` may be different symbolic objects # but they can be proved to systematically generate the same value ret.append(S.Zero) @@ -566,6 +569,10 @@ def timestamp(self): def distance(self): return self.source.distance(self.sink) + @cached_property + def distance_logical(self): + return self.source.distance(self.sink, logical=True) + @cached_property def _defined_findices(self): return frozenset(flatten(i._defines for i in self.findices)) @@ -656,6 +663,19 @@ def cause(self): return i._defines return frozenset() + # TODO: Refactor this + @cached_property + def cause_logical(self): + """Return the findex causing the dependence.""" + for i, j in zip(self.findices, self.distance_logical, strict=False): + try: + if j > 0: + return i._defines + except TypeError: + # Conservatively assume this is an offending dimension + return i._defines + return frozenset() + @cached_property def read(self): if self.is_flow: @@ -792,6 +812,10 @@ class DependenceGroup(set): def cause(self): return frozenset().union(*[i.cause for i in self]) + @cached_property + def cause_logical(self): + return frozenset().union(*[i.cause_logical for i in self]) + @cached_property def functions(self): """Return the DiscreteFunctions inducing a dependence.""" @@ -1111,7 +1135,7 @@ def d_flow(self): return DependenceGroup(self.d_flow_gen()) @memoized_generator - def d_anti_gen(self): + def d_anti_gen(self, logical=False): """Generate the anti (or "write-after-read") dependences.""" for k, v in self.writes.items(): for w in v: @@ -1124,7 +1148,9 @@ def d_anti_gen(self): if dependence.is_imaginary: continue - distance = dependence.distance + distance = dependence.distance_logical \ + if logical else dependence.distance + try: is_anti = distance > 0 or (r.lex_lt(w) and distance == 0) except TypeError: @@ -1140,6 +1166,14 @@ def d_anti(self): """Anti (or "write-after-read") dependences.""" return DependenceGroup(self.d_anti_gen()) + @cached_property + def d_anti_logical(self): + """ + Anti (or "write-after-read") dependences using logical rather than true + distances. + """ + return DependenceGroup(self.d_anti_gen(logical=True)) + @memoized_generator def d_output_gen(self): """Generate the output (or "write-after-write") dependences.""" @@ -1425,7 +1459,7 @@ def disjoint_test(e0, e1, d, it): return not bool(i0.intersect(i1)) -def degenerating_indices(i0, i1, function): +def degenerating_indices(i0, i1, function, logical=False): """ True if `i0` and `i1` are indices that are possibly symbolically different, but they can be proved to systematically degenerate to the @@ -1440,17 +1474,19 @@ def degenerating_indices(i0, i1, function): # Case 2: SteppingDimension corresponding to buffer of size 1 # Extract dimension from both IndexAccessFunctions -> d0, d1 - try: - d0 = i0.d - except AttributeError: - d0 = i0 - try: - d1 = i1.d - except AttributeError: - d1 = i1 + # Skipped if doing a purely logical check + if not logical: + try: + d0 = i0.d + except AttributeError: + d0 = i0 + try: + d1 = i1.d + except AttributeError: + d1 = i1 - with suppress(AttributeError): - if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1: - return True + with suppress(AttributeError): + if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1: + return True return False diff --git a/tests/test_dimension.py b/tests/test_dimension.py index 8c4781d58e..aae7e4e0cb 100644 --- a/tests/test_dimension.py +++ b/tests/test_dimension.py @@ -17,6 +17,7 @@ from devito.ir.iet import ( Conditional, Expression, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree ) +from devito.ir.support.space import Backward, Forward from devito.symbolics import INT, IntDiv, indexify, retrieve_functions from devito.types import Array, StencilDimension, Symbol from devito.types.basic import Scalar @@ -235,6 +236,26 @@ def test_degenerate_to_zero(self): assert np.all(u.data == 10) + @pytest.mark.parametrize('direction', ['fwd', 'bwd']) + def test_buffer1_direction(self, direction): + grid = Grid(shape=(10, 10)) + + u = TimeFunction(name='u', grid=grid, save=Buffer(1)) + + # Equations technically have no implied time direction as u.forward and u refer + # to the same buffer slot. However, user usage of u.forward and u.backward should + # be picked up by the compiler + if direction == 'fwd': + op = Operator(Eq(u.forward, u + 1)) + else: + op = Operator(Eq(u.backward, u + 1)) + + # Check for time loop direction + trees = retrieve_iteration_tree(op) + direction = Forward if direction == 'fwd' else Backward + for tree in trees: + assert tree[0].direction == direction + class TestSubDimension: diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 24ac679812..2c4758c20d 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -17,6 +17,7 @@ from devito.ir.iet import ( Call, Conditional, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree ) +from devito.ir.support.space import Backward, Forward from devito.mpi import MPI from devito.mpi.distributed import CustomTopology from devito.mpi.routines import ComputeCall, HaloUpdateCall, HaloUpdateList, MPICall @@ -1916,8 +1917,8 @@ def test_haloupdate_buffer1(self, mode): (2, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2')), (1, False, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')), (1, False, 'Eq(v3.backward, v2.laplace + 1)', 3, 2, ('v1', 'v2')), - (1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')), - (2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')), + (1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')), + (2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')), ]) def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode): grid = Grid((65, 65, 65), topology=('*', 1, '*')) @@ -1943,6 +1944,12 @@ def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode): op = Operator(eqns) _ = op.cfunction + # Check for time loop direction + trees = retrieve_iteration_tree(op) + direction = Forward if fwd else Backward + for tree in trees: + assert tree[0].direction == direction + calls, _ = check_halo_exchanges(op, exp0, exp1) for i, v in enumerate(args): assert calls[i].arguments[0] is eval(v) From 5326df021b3bc457db6c48f62ec6e53c5e3ce24d Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 19 Mar 2026 17:19:06 +0000 Subject: [PATCH 2/2] misc: Refactoring, tidy up, and caching bugfix --- devito/ir/clusters/algorithms.py | 2 +- devito/ir/support/basic.py | 37 ++++++++++---------------------- devito/tools/memoization.py | 13 ++++++----- 3 files changed, 20 insertions(+), 32 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 5159ea3020..99b05c8b75 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -176,7 +176,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): # user-specified backward iteration before defaulting to forward to avoid a # gotcha by using the logical d_anti here. idir = {d: Backward for d in candidates - if d.root in scope.d_anti_logical.cause_logical} + if d.root in scope.d_anti_logical.cause} if maybe_break: idir.update({d: Forward for d in candidates if d.root in scope.d_flow.cause}) # Default to forward for remaining dimensions diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index b27aeab213..7a841be671 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -569,10 +569,6 @@ def timestamp(self): def distance(self): return self.source.distance(self.sink) - @cached_property - def distance_logical(self): - return self.source.distance(self.sink, logical=True) - @cached_property def _defined_findices(self): return frozenset(flatten(i._defines for i in self.findices)) @@ -663,19 +659,6 @@ def cause(self): return i._defines return frozenset() - # TODO: Refactor this - @cached_property - def cause_logical(self): - """Return the findex causing the dependence.""" - for i, j in zip(self.findices, self.distance_logical, strict=False): - try: - if j > 0: - return i._defines - except TypeError: - # Conservatively assume this is an offending dimension - return i._defines - return frozenset() - @cached_property def read(self): if self.is_flow: @@ -806,16 +789,19 @@ def is_storage_related(self, dims=None): return False +class LogicalDependence(Dependence): + + @cached_property + def distance(self): + return self.source.distance(self.sink, logical=True) + + class DependenceGroup(set): @cached_property def cause(self): return frozenset().union(*[i.cause for i in self]) - @cached_property - def cause_logical(self): - return frozenset().union(*[i.cause_logical for i in self]) - @cached_property def functions(self): """Return the DiscreteFunctions inducing a dependence.""" @@ -1135,7 +1121,7 @@ def d_flow(self): return DependenceGroup(self.d_flow_gen()) @memoized_generator - def d_anti_gen(self, logical=False): + def d_anti_gen(self, depcls=Dependence): """Generate the anti (or "write-after-read") dependences.""" for k, v in self.writes.items(): for w in v: @@ -1143,13 +1129,12 @@ def d_anti_gen(self, logical=False): if any(not rule(r, w) for rule in self.rules): continue - dependence = Dependence(r, w) + dependence = depcls(r, w) if dependence.is_imaginary: continue - distance = dependence.distance_logical \ - if logical else dependence.distance + distance = dependence.distance try: is_anti = distance > 0 or (r.lex_lt(w) and distance == 0) @@ -1172,7 +1157,7 @@ def d_anti_logical(self): Anti (or "write-after-read") dependences using logical rather than true distances. """ - return DependenceGroup(self.d_anti_gen(logical=True)) + return DependenceGroup(self.d_anti_gen(depcls=LogicalDependence)) @memoized_generator def d_output_gen(self): diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index b24a9166f6..c10f5ea092 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -142,15 +142,18 @@ class CacheInstancesMeta(type): def __init__(cls: type[InstanceType], *args) -> None: # type: ignore super().__init__(*args) - # Register the cached type + # Register the cached type and eagerly create its cache, bound to its + # own constructor. Eager initialisation avoids a bug where a subclass + # would inherit (and reuse) a parent's cache via MRO lookup if the + # parent happened to be instantiated first. CacheInstancesMeta._cached_types.add(cls) + maxsize = cls._instance_cache_size + cls._instance_cache = lru_cache(maxsize=maxsize)( + super().__call__ + ) def __call__(cls: type[InstanceType], # type: ignore *args, **kwargs) -> InstanceType: - if cls._instance_cache is None: - maxsize = cls._instance_cache_size - cls._instance_cache = lru_cache(maxsize=maxsize)(super().__call__) - args, kwargs = cls._preprocess_args(*args, **kwargs) return cls._instance_cache(*args, **kwargs)