diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 26bf189694f97..d7ff730151fc8 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -379,7 +379,12 @@ def emit_line() -> None: generate_methods_table(cl, methods_name, setup_name if generate_full else None, emitter) emit_line() - flags = ["Py_TPFLAGS_DEFAULT", "Py_TPFLAGS_HEAPTYPE", "Py_TPFLAGS_BASETYPE"] + flags = [ + "Py_TPFLAGS_DEFAULT", + "Py_TPFLAGS_HEAPTYPE", + "Py_TPFLAGS_BASETYPE", + "CPy_TPFLAGS_MYPYC_COMPILED", + ] if generate_full and not cl.is_acyclic: flags.append("Py_TPFLAGS_HAVE_GC") if cl.has_method("__call__"): diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index c1202d1c928ca..804924732ca5d 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -404,6 +404,35 @@ def visit_get_attr(self, op: GetAttr) -> None: ) else: # Otherwise, use direct or offset struct access. + # For classes with allow_interpreted_subclasses, an interpreted + # subclass may override class attributes in its __dict__. The + # compiled code reads from instance struct slots, so we check if + # the instance is a compiled type (via tp_flags). If not, fall + # back to Python's generic attribute lookup which respects the MRO. + # We use the CPy_TPFLAGS_MYPYC_COMPILED flag (set on all mypyc-compiled + # types) so that compiled subclasses get direct struct access while only + # interpreted subclasses hit the slow path. + use_fallback = cl.allow_interpreted_subclasses and not cl.is_trait + if use_fallback: + fallback_attr = self.emitter.temp_name() + fallback_result = self.emitter.temp_name() + self.declarations.emit_line(f"PyObject *{fallback_attr};") + self.declarations.emit_line(f"PyObject *{fallback_result};") + self.emit_line(f"if (!(Py_TYPE({obj})->tp_flags & CPy_TPFLAGS_MYPYC_COMPILED)) {{") + self.emit_line(f'{fallback_attr} = PyUnicode_FromString("{op.attr}");') + self.emit_line( + f"{fallback_result} = PyObject_GenericGetAttr((PyObject *){obj}, {fallback_attr});" + ) + self.emit_line(f"Py_DECREF({fallback_attr});") + if attr_rtype.is_unboxed: + self.emitter.emit_unbox( + fallback_result, dest, attr_rtype, raise_exception=False + ) + self.emit_line(f"Py_XDECREF({fallback_result});") + else: + self.emit_line(f"{dest} = {fallback_result};") + self.emit_line("} else {") + attr_expr = self.get_attr_expr(obj, op, decl_cl) self.emitter.emit_line(f"{dest} = {attr_expr};") always_defined = cl.is_always_defined(op.attr) @@ -447,6 +476,9 @@ def visit_get_attr(self, op: GetAttr) -> None: elif not always_defined: self.emitter.emit_line("}") + if use_fallback: + self.emitter.emit_line("}") + def get_attr_with_allow_error_value(self, op: GetAttr) -> None: """Handle GetAttr with allow_error_value=True. diff --git a/mypyc/lib-rt/mypyc_util.h b/mypyc/lib-rt/mypyc_util.h index 6715d67d96572..cc41402bd76d7 100644 --- a/mypyc/lib-rt/mypyc_util.h +++ b/mypyc/lib-rt/mypyc_util.h @@ -146,6 +146,11 @@ typedef PyObject CPyModule; #define CPY_NONE_ERROR 2 #define CPY_NONE 1 +// Flag bit set on all mypyc-compiled types. Used to distinguish compiled +// subclasses (safe for direct struct access) from interpreted subclasses +// (need PyObject_GenericGetAttr fallback) in allow_interpreted_subclasses mode. +#define CPy_TPFLAGS_MYPYC_COMPILED (1UL << 20) + typedef void (*CPyVTableItem)(void); static inline CPyTagged CPyTagged_ShortFromInt(int x) { diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 54e568a477684..22f6f2c466919 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -5774,3 +5774,102 @@ from native import Concrete c = Concrete() assert c.value() == 42 assert c.derived() == 42 + +[case testInterpretedSubclassAttrOverrideWithAllowInterpretedSubclasses] +# Test that interpreted subclasses can override class attributes and the +# compiled base class methods see the overridden values via GenericGetAttr. +from mypy_extensions import mypyc_attr + +@mypyc_attr(allow_interpreted_subclasses=True) +class Base: + VALUE: int = 10 + FLAG: bool = False + + def get_value(self) -> int: + return self.VALUE + + def check_flag(self) -> bool: + return self.FLAG + +[file driver.py] +from native import Base + +# Interpreted subclass that overrides class attributes +class Sub(Base): + VALUE = 42 + FLAG = True + +b = Base() +assert b.get_value() == 10 +assert not b.check_flag() + +s = Sub() +assert s.get_value() == 42, "compiled method doesn't see subclass override" +assert s.check_flag(), "compiled method doesn't see subclass override" + +[case testCompiledSubclassAttrAccessWithAllowInterpretedSubclasses] +# Test that compiled subclasses of a class with allow_interpreted_subclasses=True +# can correctly access parent instance attributes via direct struct access +# (not falling back to PyObject_GenericGetAttr). +from mypy_extensions import mypyc_attr + +@mypyc_attr(allow_interpreted_subclasses=True) +class Base: + def __init__(self, x: int, name: str) -> None: + self.x = x + self.name = name + + def get_x(self) -> int: + return self.x + + def get_name(self) -> str: + return self.name + + def compute(self) -> int: + return self.x * 2 + +@mypyc_attr(allow_interpreted_subclasses=True) +class Child(Base): + def __init__(self, x: int, name: str, y: int) -> None: + super().__init__(x, name) + self.y = y + + def compute(self) -> int: + return self.x + self.y + + def get_both(self) -> int: + return self.x + self.y + +@mypyc_attr(allow_interpreted_subclasses=True) +class GrandChild(Child): + def __init__(self, x: int, name: str, y: int, z: int) -> None: + super().__init__(x, name, y) + self.z = z + + def compute(self) -> int: + return self.x + self.y + self.z + +def test_compiled_subclass_attr_access() -> None: + b = Base(10, "base") + assert b.get_x() == 10 + assert b.get_name() == "base" + assert b.compute() == 20 + + c = Child(10, "child", 5) + assert c.get_x() == 10 + assert c.get_name() == "child" + assert c.compute() == 15 + assert c.get_both() == 15 + + g = GrandChild(10, "grand", 5, 3) + assert g.get_x() == 10 + assert g.get_name() == "grand" + assert g.compute() == 18 + + ref: Base = Child(7, "ref", 3) + assert ref.get_x() == 7 + assert ref.compute() == 10 + +[file driver.py] +from native import test_compiled_subclass_attr_access +test_compiled_subclass_attr_access()