Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,16 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
candidates | known_break)

# Compute iteration direction
idir = {d: Backward for d in candidates if d.root in scope.d_anti.cause}
# When checking for iteration direction, the user may have specified an LHS
# preceding the RHS, implying backward iteration, even if there is no strict
# reason that this iteration would need to run backward. Check if there is a
# user-specified backward iteration before defaulting to forward to avoid a
# gotcha by using the logical d_anti here.
idir = {d: Backward for d in candidates
if d.root in scope.d_anti_logical.cause}
if maybe_break:
idir.update({d: Forward for d in candidates if d.root in scope.d_flow.cause})
# Default to forward for remaining dimensions
idir.update({d: Forward for d in candidates if d not in idir})

# Enforce iteration direction on each Cluster
Expand Down
53 changes: 37 additions & 16 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,17 @@ def lex_le(self, other):
def lex_lt(self, other):
return self.timestamp < other.timestamp

def distance(self, other):
def distance(self, other, logical=False):
"""
Compute the distance from ``self`` to ``other``.

Parameters
----------
other : TimedAccess
The TimedAccess w.r.t. which the distance is computed.
logical : bool
Compute a logical distance rather than true distance (i.e. ignoring
degenerating indices created by size 1 buffers etc).
"""
if isinstance(self.access, ComponentAccess) and \
isinstance(other.access, ComponentAccess) and \
Expand Down Expand Up @@ -392,7 +395,7 @@ def distance(self, other):
# objects falls back to zero, as any other value would be
# nonsensical
ret.append(S.Zero)
elif degenerating_indices(self[n], other[n], self.function):
elif degenerating_indices(self[n], other[n], self.function, logical=logical):
# Special case: `sai` and `oai` may be different symbolic objects
# but they can be proved to systematically generate the same value
ret.append(S.Zero)
Expand Down Expand Up @@ -786,6 +789,13 @@ def is_storage_related(self, dims=None):
return False


class LogicalDependence(Dependence):

@cached_property
def distance(self):
return self.source.distance(self.sink, logical=True)


class DependenceGroup(set):

@cached_property
Expand Down Expand Up @@ -1111,20 +1121,21 @@ def d_flow(self):
return DependenceGroup(self.d_flow_gen())

@memoized_generator
def d_anti_gen(self):
def d_anti_gen(self, depcls=Dependence):
"""Generate the anti (or "write-after-read") dependences."""
for k, v in self.writes.items():
for w in v:
for r in self.reads_smart_gen(k):
if any(not rule(r, w) for rule in self.rules):
continue

dependence = Dependence(r, w)
dependence = depcls(r, w)

if dependence.is_imaginary:
continue

distance = dependence.distance

try:
is_anti = distance > 0 or (r.lex_lt(w) and distance == 0)
except TypeError:
Expand All @@ -1140,6 +1151,14 @@ def d_anti(self):
"""Anti (or "write-after-read") dependences."""
return DependenceGroup(self.d_anti_gen())

@cached_property
def d_anti_logical(self):
"""
Anti (or "write-after-read") dependences using logical rather than true
distances.
"""
return DependenceGroup(self.d_anti_gen(depcls=LogicalDependence))

@memoized_generator
def d_output_gen(self):
"""Generate the output (or "write-after-write") dependences."""
Expand Down Expand Up @@ -1425,7 +1444,7 @@ def disjoint_test(e0, e1, d, it):
return not bool(i0.intersect(i1))


def degenerating_indices(i0, i1, function):
def degenerating_indices(i0, i1, function, logical=False):
"""
True if `i0` and `i1` are indices that are possibly symbolically
different, but they can be proved to systematically degenerate to the
Expand All @@ -1440,17 +1459,19 @@ def degenerating_indices(i0, i1, function):

# Case 2: SteppingDimension corresponding to buffer of size 1
# Extract dimension from both IndexAccessFunctions -> d0, d1
try:
d0 = i0.d
except AttributeError:
d0 = i0
try:
d1 = i1.d
except AttributeError:
d1 = i1
# Skipped if doing a purely logical check
if not logical:
try:
d0 = i0.d
except AttributeError:
d0 = i0
try:
d1 = i1.d
except AttributeError:
d1 = i1

with suppress(AttributeError):
if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1:
return True
with suppress(AttributeError):
if d0 is d1 and d0.is_Stepping and function._size_domain[d0] == 1:
return True

return False
13 changes: 8 additions & 5 deletions devito/tools/memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,18 @@ class CacheInstancesMeta(type):
def __init__(cls: type[InstanceType], *args) -> None: # type: ignore
super().__init__(*args)

# Register the cached type
# Register the cached type and eagerly create its cache, bound to its
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a Claude fix 😅 LogicalDependence objects kept getting constructed as Dependence objects due to the caching

# own constructor. Eager initialisation avoids a bug where a subclass
# would inherit (and reuse) a parent's cache via MRO lookup if the
# parent happened to be instantiated first.
CacheInstancesMeta._cached_types.add(cls)
maxsize = cls._instance_cache_size
cls._instance_cache = lru_cache(maxsize=maxsize)(
super().__call__
)

def __call__(cls: type[InstanceType], # type: ignore
*args, **kwargs) -> InstanceType:
if cls._instance_cache is None:
maxsize = cls._instance_cache_size
cls._instance_cache = lru_cache(maxsize=maxsize)(super().__call__)

args, kwargs = cls._preprocess_args(*args, **kwargs)
return cls._instance_cache(*args, **kwargs)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from devito.ir.iet import (
Conditional, Expression, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree
)
from devito.ir.support.space import Backward, Forward
from devito.symbolics import INT, IntDiv, indexify, retrieve_functions
from devito.types import Array, StencilDimension, Symbol
from devito.types.basic import Scalar
Expand Down Expand Up @@ -235,6 +236,26 @@ def test_degenerate_to_zero(self):

assert np.all(u.data == 10)

@pytest.mark.parametrize('direction', ['fwd', 'bwd'])
def test_buffer1_direction(self, direction):
grid = Grid(shape=(10, 10))

u = TimeFunction(name='u', grid=grid, save=Buffer(1))

# Equations technically have no implied time direction as u.forward and u refer
# to the same buffer slot. However, user usage of u.forward and u.backward should
# be picked up by the compiler
if direction == 'fwd':
op = Operator(Eq(u.forward, u + 1))
else:
op = Operator(Eq(u.backward, u + 1))

# Check for time loop direction
trees = retrieve_iteration_tree(op)
direction = Forward if direction == 'fwd' else Backward
for tree in trees:
assert tree[0].direction == direction


class TestSubDimension:

Expand Down
11 changes: 9 additions & 2 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from devito.ir.iet import (
Call, Conditional, FindNodes, FindSymbols, Iteration, retrieve_iteration_tree
)
from devito.ir.support.space import Backward, Forward
from devito.mpi import MPI
from devito.mpi.distributed import CustomTopology
from devito.mpi.routines import ComputeCall, HaloUpdateCall, HaloUpdateList, MPICall
Expand Down Expand Up @@ -1916,8 +1917,8 @@ def test_haloupdate_buffer1(self, mode):
(2, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2')),
(1, False, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')),
(1, False, 'Eq(v3.backward, v2.laplace + 1)', 3, 2, ('v1', 'v2')),
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')),
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 3, ('v2', 'v1', 'v2')),
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')),
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2')),
])
def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode):
grid = Grid((65, 65, 65), topology=('*', 1, '*'))
Expand All @@ -1943,6 +1944,12 @@ def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode):
op = Operator(eqns)
_ = op.cfunction

# Check for time loop direction
trees = retrieve_iteration_tree(op)
direction = Forward if fwd else Backward
for tree in trees:
assert tree[0].direction == direction

calls, _ = check_halo_exchanges(op, exp0, exp1)
for i, v in enumerate(args):
assert calls[i].arguments[0] is eval(v)
Expand Down
Loading