diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index ab688e3c81..99b05c8b75 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -170,9 +170,16 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): candidates | known_break) # Compute iteration direction - idir = {d: Backward for d in candidates if d.root in scope.d_anti.cause} + # When checking for iteration direction, the user may have specified an LHS + # preceding the RHS, implying backward iteration, even if there is no strict + # reason that this iteration would need to run backward. Check if there is a + # user-specified backward iteration before defaulting to forward to avoid a + # gotcha by using the logical d_anti here. + idir = {d: Backward for d in candidates + if d.root in scope.d_anti_logical.cause} if maybe_break: idir.update({d: Forward for d in candidates if d.root in scope.d_flow.cause}) + # Default to forward for remaining dimensions idir.update({d: Forward for d in candidates if d not in idir}) # Enforce iteration direction on each Cluster diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 8357b1b05f..7a841be671 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -320,7 +320,7 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp - def distance(self, other): + def distance(self, other, logical=False): """ Compute the distance from ``self`` to ``other``. @@ -328,6 +328,9 @@ def distance(self, other): ---------- other : TimedAccess The TimedAccess w.r.t. which the distance is computed. + logical : bool + Compute a logical distance rather than true distance (i.e. ignoring + degenerating indices created by size 1 buffers etc). """ if isinstance(self.access, ComponentAccess) and \ isinstance(other.access, ComponentAccess) and \ @@ -392,7 +395,7 @@ def distance(self, other): # objects falls back to zero, as any other value would be # nonsensical ret.append(S.Zero) - elif degenerating_indices(self[n], other[n], self.function): + elif degenerating_indices(self[n], other[n], self.function, logical=logical): # Special case: `sai` and `oai` may be different symbolic objects # but they can be proved to systematically generate the same value ret.append(S.Zero) @@ -786,6 +789,13 @@ def is_storage_related(self, dims=None): return False +class LogicalDependence(Dependence): + + @cached_property + def distance(self): + return self.source.distance(self.sink, logical=True) + + class DependenceGroup(set): @cached_property @@ -1111,7 +1121,7 @@ def d_flow(self): return DependenceGroup(self.d_flow_gen()) @memoized_generator - def d_anti_gen(self): + def d_anti_gen(self, depcls=Dependence): """Generate the anti (or "write-after-read") dependences.""" for k, v in self.writes.items(): for w in v: @@ -1119,12 +1129,13 @@ def d_anti_gen(self): if any(not rule(r, w) for rule in self.rules): continue - dependence = Dependence(r, w) + dependence = depcls(r, w) if dependence.is_imaginary: continue distance = dependence.distance + try: is_anti = distance > 0 or (r.lex_lt(w) and distance == 0) except TypeError: @@ -1140,6 +1151,14 @@ def d_anti(self): """Anti (or "write-after-read") dependences.""" return DependenceGroup(self.d_anti_gen()) + @cached_property + def d_anti_logical(self): + """ + Anti (or "write-after-read") dependences using logical rather than true + distances. + """ + return DependenceGroup(self.d_anti_gen(depcls=LogicalDependence)) + @memoized_generator def d_output_gen(self): """Generate the output (or "write-after-write") dependences.""" @@ -1425,7 +1444,7 @@ def disjoint_test(e0, e1, d, it): return not bool(i0.intersect(i1)) -def degenerating_indices(i0, i1, function): +def degenerating_indices(i0, i1, function, logical=False): """ True if `i0` and `i1` are indices that are possibly symbolically different, but they can be proved to systematically degenerate to the @@ -1440,17 +1459,19 @@ def degenerating_indices(i0, i1, function): # Case 2: SteppingDimension corresponding to buffer of size 1 # Extract dimension from both IndexAccessFunctions -> d0, d1 - try: - d0 = i0.d - except AttributeError: - d0 = i0 - try: - d1 = i1.d - except AttributeError: - d1 = i1 + # Skipped if doing a purely logical check + if not logical: + try: + d0 = i0.d + except AttributeError: + d0 = i0 + try: + d1 = i1.d + except AttributeError: + d1 = i1 - with suppress(AttributeError): - if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1: - return True + with suppress(AttributeError): + if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1: + return True return False diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index b24a9166f6..c10f5ea092 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -142,15 +142,18 @@ class CacheInstancesMeta(type): def __init__(cls: type[InstanceType], *args) -> None: # type: ignore super().__init__(*args) - # Register the cached type + # Register the cached type and eagerly create its cache, bound to its + # own constructor. Eager initialisation avoids a bug where a subclass + # would inherit (and reuse) a parent's cache via MRO lookup if the + # parent happened to be instantiated first. CacheInstancesMeta._cached_types.add(cls) + maxsize = cls._instance_cache_size + cls._instance_cache = lru_cache(maxsize=maxsize)( + super().__call__ + ) def __call__(cls: type[InstanceType], # type: ignore *args, **kwargs) -> InstanceType: - if cls._instance_cache is None: - maxsize = cls._instance_cache_size - cls._instance_cache = lru_cache(maxsize=maxsize)(super().__call__) - args, kwargs = cls._preprocess_args(*args, **kwargs) return cls._instance_cache(*args, **kwargs) diff --git a/tests/test_dimension.py b/tests/test_dimension.py index 8c4781d58e..aae7e4e0cb 100644 --- a/tests/test_dimension.py +++ b/tests/test_dimension.py @@ -17,6 +17,7 @@ from devito.ir.iet import ( Conditional, Expression, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree ) +from devito.ir.support.space import Backward, Forward from devito.symbolics import INT, IntDiv, indexify, retrieve_functions from devito.types import Array, StencilDimension, Symbol from devito.types.basic import Scalar @@ -235,6 +236,26 @@ def test_degenerate_to_zero(self): assert np.all(u.data == 10) + @pytest.mark.parametrize('direction', ['fwd', 'bwd']) + def test_buffer1_direction(self, direction): + grid = Grid(shape=(10, 10)) + + u = TimeFunction(name='u', grid=grid, save=Buffer(1)) + + # Equations technically have no implied time direction as u.forward and u refer + # to the same buffer slot. However, user usage of u.forward and u.backward should + # be picked up by the compiler + if direction == 'fwd': + op = Operator(Eq(u.forward, u + 1)) + else: + op = Operator(Eq(u.backward, u + 1)) + + # Check for time loop direction + trees = retrieve_iteration_tree(op) + direction = Forward if direction == 'fwd' else Backward + for tree in trees: + assert tree[0].direction == direction + class TestSubDimension: diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 24ac679812..2c4758c20d 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -17,6 +17,7 @@ from devito.ir.iet import ( Call, Conditional, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree ) +from devito.ir.support.space import Backward, Forward from devito.mpi import MPI from devito.mpi.distributed import CustomTopology from devito.mpi.routines import ComputeCall, HaloUpdateCall, HaloUpdateList, MPICall @@ -1916,8 +1917,8 @@ def test_haloupdate_buffer1(self, mode): (2, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2')), (1, False, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')), (1, False, 'Eq(v3.backward, v2.laplace + 1)', 3, 2, ('v1', 'v2')), - (1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')), - (2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')), + (1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')), + (2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')), ]) def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode): grid = Grid((65, 65, 65), topology=('*', 1, '*')) @@ -1943,6 +1944,12 @@ def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode): op = Operator(eqns) _ = op.cfunction + # Check for time loop direction + trees = retrieve_iteration_tree(op) + direction = Forward if fwd else Backward + for tree in trees: + assert tree[0].direction == direction + calls, _ = check_halo_exchanges(op, exp0, exp1) for i, v in enumerate(args): assert calls[i].arguments[0] is eval(v)