diff --git a/clang/lib/AST/ByteCode/Interp.h b/clang/lib/AST/ByteCode/Interp.h index 427f694319b6c..9accbbc1605a9 100644 --- a/clang/lib/AST/ByteCode/Interp.h +++ b/clang/lib/AST/ByteCode/Interp.h @@ -1095,8 +1095,9 @@ inline bool CmpHelperEQ(InterpState &S, CodePtr OpPC, CompareFn Fn) { } if (Pointer::hasSameBase(LHS, RHS)) { - size_t A = LHS.computeOffsetForComparison(); - size_t B = RHS.computeOffsetForComparison(); + size_t A = LHS.computeOffsetForComparison(S.getASTContext()); + size_t B = RHS.computeOffsetForComparison(S.getASTContext()); + S.Stk.push(BoolT::from(Fn(Compare(A, B)))); return true; } diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp index 90f41bed39440..c1749d9634167 100644 --- a/clang/lib/AST/ByteCode/Pointer.cpp +++ b/clang/lib/AST/ByteCode/Pointer.cpp @@ -362,7 +362,13 @@ void Pointer::print(llvm::raw_ostream &OS) const { } } -size_t Pointer::computeOffsetForComparison() const { +/// Compute an offset that can be used to compare this pointer to another one +/// with the same base. To get accurate results, we basically _have to_ compute +/// the lvalue offset using the ASTRecordLayout. +/// +/// FIXME: We're still mixing values from the record layout with our internal +/// offsets, which will inevitably lead to cryptic errors. +size_t Pointer::computeOffsetForComparison(const ASTContext &ASTCtx) const { switch (StorageKind) { case Storage::Int: return Int.Value + Offset; @@ -378,7 +384,6 @@ size_t Pointer::computeOffsetForComparison() const { size_t Result = 0; Pointer P = *this; while (true) { - if (P.isVirtualBaseClass()) { Result += getInlineDesc()->Offset; P = P.getBase(); @@ -400,28 +405,29 @@ size_t Pointer::computeOffsetForComparison() const { if (P.isRoot()) { if (P.isOnePastEnd()) - ++Result; + Result += + ASTCtx.getTypeSizeInChars(P.getDeclDesc()->getType()).getQuantity(); break; } - if (const Record *R = P.getBase().getRecord(); R && R->isUnion()) { - if (P.isOnePastEnd()) - ++Result; - // Direct child of a union - all have offset 0. - P = P.getBase(); - continue; - } + assert(P.getField()); + const Record *R = P.getBase().getRecord(); + assert(R); + + const ASTRecordLayout &Layout = ASTCtx.getASTRecordLayout(R->getDecl()); + Result += ASTCtx + .toCharUnitsFromBits( + Layout.getFieldOffset(P.getField()->getFieldIndex())) + .getQuantity(); - // Fields, etc. - Result += P.getInlineDesc()->Offset; if (P.isOnePastEnd()) - ++Result; + Result += + ASTCtx.getTypeSizeInChars(P.getField()->getType()).getQuantity(); P = P.getBase(); if (P.isRoot()) break; } - return Result; } diff --git a/clang/lib/AST/ByteCode/Pointer.h b/clang/lib/AST/ByteCode/Pointer.h index 0978090ba8b19..9032bdade850f 100644 --- a/clang/lib/AST/ByteCode/Pointer.h +++ b/clang/lib/AST/ByteCode/Pointer.h @@ -784,7 +784,7 @@ class Pointer { /// Compute an integer that can be used to compare this pointer to /// another one. This is usually NOT the same as the pointer offset /// regarding the AST record layout. - size_t computeOffsetForComparison() const; + size_t computeOffsetForComparison(const ASTContext &ASTCtx) const; private: friend class Block; diff --git a/clang/lib/AST/ByteCode/Record.h b/clang/lib/AST/ByteCode/Record.h index 8245eeff2f20d..7b66c3b263e38 100644 --- a/clang/lib/AST/ByteCode/Record.h +++ b/clang/lib/AST/ByteCode/Record.h @@ -61,14 +61,6 @@ class Record final { unsigned getSize() const { return BaseSize; } /// Returns the full size of the record, including records. unsigned getFullSize() const { return BaseSize + VirtualSize; } - /// Returns a field. - const Field *getField(const FieldDecl *FD) const; - /// Returns a base descriptor. - const Base *getBase(const RecordDecl *FD) const; - /// Returns a base descriptor. - const Base *getBase(QualType T) const; - /// Returns a virtual base descriptor. - const Base *getVirtualBase(const RecordDecl *RD) const; /// Returns the destructor of the record, if any. const CXXDestructorDecl *getDestructor() const { if (const auto *CXXDecl = dyn_cast(Decl)) @@ -87,6 +79,8 @@ class Record final { unsigned getNumFields() const { return Fields.size(); } const Field *getField(unsigned I) const { return &Fields[I]; } + /// Returns a field. + const Field *getField(const FieldDecl *FD) const; using const_base_iter = BaseList::const_iterator; llvm::iterator_range bases() const { @@ -98,6 +92,10 @@ class Record final { assert(I < getNumBases()); return &Bases[I]; } + /// Returns a base descriptor. + const Base *getBase(QualType T) const; + /// Returns a base descriptor. + const Base *getBase(const RecordDecl *FD) const; using const_virtual_iter = VirtualBaseList::const_iterator; llvm::iterator_range virtual_bases() const { @@ -106,6 +104,8 @@ class Record final { unsigned getNumVirtualBases() const { return VirtualBases.size(); } const Base *getVirtualBase(unsigned I) const { return &VirtualBases[I]; } + /// Returns a virtual base descriptor. + const Base *getVirtualBase(const RecordDecl *RD) const; void dump(llvm::raw_ostream &OS, unsigned Indentation = 0, unsigned Offset = 0) const; diff --git a/clang/test/AST/ByteCode/cxx20.cpp b/clang/test/AST/ByteCode/cxx20.cpp index 227f34cee80ff..2abe8dd120e6f 100644 --- a/clang/test/AST/ByteCode/cxx20.cpp +++ b/clang/test/AST/ByteCode/cxx20.cpp @@ -1225,3 +1225,58 @@ namespace ConditionalTemporaries { static_assert(foo(false)== 13); static_assert(foo(true)== 12); } + +namespace PointerCmp { + struct K { + struct { + double a; + alignas(8) int b; + } m; + char c; + }; + constexpr K k{1,2, 3}; + static_assert((void*)(&k.m.a + 1) == (void*)&k.m.b); + static_assert((void*)(&k.m + 1) == (void*)&k.c); + static_assert((void*)(&k + 1) != (void*)(&k.c + 1)); + + struct K2 { + struct { + int a; + alignas(8) int b; + } m; + double c; + }; + constexpr K2 k2{1,2, 3}; + static_assert((void*)(&k2.m.a + 1) != (void*)&k2.m.b); + /// static_assert((void*)(&k2.m.a + 1) < (void*)&k2.m.b); FIXME + static_assert((void*)(&k2.m + 1) == (void*)&k2.c); + static_assert((void*)(&k2 + 1) == (void*)(&k2.c + 1)); + + + struct tuple { + int a; + int b; + }; + + constexpr tuple tpl{1,2}; + static_assert((void*)&tpl == (void*)&tpl.a); + + + struct B { + int a; + }; + + struct tuple2 : public B { + int b; + }; + constexpr tuple2 tpl2{1,2}; + static_assert((void*)&tpl2 == (void*)&tpl2.a); + + struct A { + int i[3]; + double c; + }; + constexpr A a{1,2, 3}; + static_assert((void*)(&a.i + 1) != (void*)(&a.i[1])); // expected-error {{static assertion failed}} + static_assert((void*)(&a.i[2] + 1) == (void*)(&a.i[3])); +}