diff --git a/include/dxc/DXIL/DxilConstants.h b/include/dxc/DXIL/DxilConstants.h index 8d44e58487..12caf1af60 100644 --- a/include/dxc/DXIL/DxilConstants.h +++ b/include/dxc/DXIL/DxilConstants.h @@ -192,6 +192,18 @@ enum class ComponentType : uint32_t { LastEntry }; +enum class MatrixUse : uint32_t { + A = 0, + B = 1, + Accumulator = 2, +}; + +enum class MatrixScope : uint32_t { + Thread = 0, + Wave = 1, + ThreadGroup = 2, +}; + // Must match D3D_INTERPOLATION_MODE enum class InterpolationMode : uint8_t { Undefined = 0, diff --git a/include/dxc/dxcapi.internal.h b/include/dxc/dxcapi.internal.h index 276c9d6793..e2ab229b2f 100644 --- a/include/dxc/dxcapi.internal.h +++ b/include/dxc/dxcapi.internal.h @@ -130,17 +130,19 @@ enum LEGAL_INTRINSIC_COMPTYPES { LICOMPTYPE_HIT_OBJECT = 51, LICOMPTYPE_RAY_QUERY = 52, - LICOMPTYPE_LINALG = 53, // f32, partial-precision-f32, f16, + LICOMPTYPE_LINALG_MATRIX = 53, + + LICOMPTYPE_LINALG = 54, // f32, partial-precision-f32, f16, // i32, i16, u32, u16, // int8_4packed, uint8_4packed - LICOMPTYPE_BUILTIN_TRIANGLE_POSITIONS = 54, + LICOMPTYPE_BUILTIN_TRIANGLE_POSITIONS = 55, #ifdef ENABLE_SPIRV_CODEGEN - LICOMPTYPE_VK_BUFFER_POINTER = 55, - LICOMPTYPE_COUNT = 56 + LICOMPTYPE_VK_BUFFER_POINTER = 56, + LICOMPTYPE_COUNT = 57 #else - LICOMPTYPE_COUNT = 55 + LICOMPTYPE_COUNT = 56 #endif }; diff --git a/tools/clang/include/clang/AST/ASTContext.h b/tools/clang/include/clang/AST/ASTContext.h index 94093a2733..169efa26ba 100644 --- a/tools/clang/include/clang/AST/ASTContext.h +++ b/tools/clang/include/clang/AST/ASTContext.h @@ -15,6 +15,7 @@ #ifndef LLVM_CLANG_AST_ASTCONTEXT_H #define LLVM_CLANG_AST_ASTCONTEXT_H +#include "dxc/DXIL/DxilConstants.h" #include "clang/AST/ASTTypeTraits.h" #include "clang/AST/CanonicalType.h" #include "clang/AST/CommentCommandTraits.h" @@ -130,6 +131,12 @@ class ASTContext : public RefCountedBase { mutable llvm::FoldingSet AtomicTypes; llvm::FoldingSet AttributedTypes; + // HLSL Change Start + llvm::FoldingSet AttrLinAlgMatrixTypes; + llvm::FoldingSet + DepAttrLinAlgMatrixTypes; + // HLSL Change End + mutable llvm::FoldingSet QualifiedTemplateNames; mutable llvm::FoldingSet DependentTemplateNames; mutable llvm::FoldingSet @@ -1156,6 +1163,19 @@ class ASTContext : public RefCountedBase { QualType modifiedType, QualType equivalentType); + // HLSL Change Start + QualType getAttributedLinAlgMatrixType(QualType WrappedTy, + hlsl::DXIL::ComponentType ComponentTy, + size_t Rows, size_t Cols, + hlsl::DXIL::MatrixUse Use, + hlsl::DXIL::MatrixScope Scope); + + QualType getDependentAttributedLinAlgMatrixType(QualType WrappedTy, + Expr *ComponentTyExpr, + Expr *RowsExpr, + Expr *ColsExpr, Expr *UseExpr, + Expr *ScopeExpr); + // HLSL Change End QualType getSubstTemplateTypeParmType(const TemplateTypeParmType *Replaced, QualType Replacement) const; QualType getSubstTemplateTypeParmPackType( diff --git a/tools/clang/include/clang/AST/DataRecursiveASTVisitor.h b/tools/clang/include/clang/AST/DataRecursiveASTVisitor.h index 50955d8ec3..78242b9f47 100644 --- a/tools/clang/include/clang/AST/DataRecursiveASTVisitor.h +++ b/tools/clang/include/clang/AST/DataRecursiveASTVisitor.h @@ -900,6 +900,21 @@ DEF_TRAVERSE_TYPE(AutoType, { TRY_TO(TraverseType(T->getDeducedType())); }) DEF_TRAVERSE_TYPE(RecordType, {}) DEF_TRAVERSE_TYPE(EnumType, {}) + +// HLSL Change Start +DEF_TRAVERSE_TYPE(AttributedLinAlgMatrixType, + { TRY_TO(TraverseType(T->getWrappedType())); }) + +DEF_TRAVERSE_TYPE(DependentAttributedLinAlgMatrixType, { + TRY_TO(TraverseType(T->getWrappedType())); + TRY_TO(TraverseStmt(T->getComponentTyExpr())); + TRY_TO(TraverseStmt(T->getRowsExpr())); + TRY_TO(TraverseStmt(T->getColsExpr())); + TRY_TO(TraverseStmt(T->getUseExpr())); + TRY_TO(TraverseStmt(T->getScopeExpr())); +}) +// HLSL Change End + DEF_TRAVERSE_TYPE(TemplateTypeParmType, {}) DEF_TRAVERSE_TYPE(SubstTemplateTypeParmType, {}) DEF_TRAVERSE_TYPE(SubstTemplateTypeParmPackType, {}) @@ -1119,6 +1134,10 @@ DEF_TRAVERSE_TYPELOC(AutoType, { DEF_TRAVERSE_TYPELOC(RecordType, {}) DEF_TRAVERSE_TYPELOC(EnumType, {}) +// HLSL Change Start +DEF_TRAVERSE_TYPELOC(AttributedLinAlgMatrixType, {}) +DEF_TRAVERSE_TYPELOC(DependentAttributedLinAlgMatrixType, {}) +// HLSL Change End DEF_TRAVERSE_TYPELOC(TemplateTypeParmType, {}) DEF_TRAVERSE_TYPELOC(SubstTemplateTypeParmType, {}) DEF_TRAVERSE_TYPELOC(SubstTemplateTypeParmPackType, {}) diff --git a/tools/clang/include/clang/AST/RecursiveASTVisitor.h b/tools/clang/include/clang/AST/RecursiveASTVisitor.h index c8c79664c8..a6e0fae9a1 100644 --- a/tools/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/tools/clang/include/clang/AST/RecursiveASTVisitor.h @@ -982,6 +982,20 @@ DEF_TRAVERSE_TYPE(InjectedClassNameType, {}) DEF_TRAVERSE_TYPE(AttributedType, { TRY_TO(TraverseType(T->getModifiedType())); }) +// HLSL Change Start +DEF_TRAVERSE_TYPE(AttributedLinAlgMatrixType, + { TRY_TO(TraverseType(T->getWrappedType())); }) + +DEF_TRAVERSE_TYPE(DependentAttributedLinAlgMatrixType, { + TRY_TO(TraverseType(T->getWrappedType())); + TRY_TO(TraverseStmt(T->getComponentTyExpr())); + TRY_TO(TraverseStmt(T->getRowsExpr())); + TRY_TO(TraverseStmt(T->getColsExpr())); + TRY_TO(TraverseStmt(T->getUseExpr())); + TRY_TO(TraverseStmt(T->getScopeExpr())); +}) +// HLSL Change End + DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); }) DEF_TRAVERSE_TYPE(ElaboratedType, { @@ -1206,6 +1220,11 @@ DEF_TRAVERSE_TYPELOC(ParenType, { TRY_TO(TraverseTypeLoc(TL.getInnerLoc())); }) DEF_TRAVERSE_TYPELOC(AttributedType, { TRY_TO(TraverseTypeLoc(TL.getModifiedLoc())); }) +// HLSL Change Start +DEF_TRAVERSE_TYPELOC(AttributedLinAlgMatrixType, {}) +DEF_TRAVERSE_TYPELOC(DependentAttributedLinAlgMatrixType, {}) +// HLSL Change End + DEF_TRAVERSE_TYPELOC(ElaboratedType, { if (TL.getQualifierLoc()) { TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc())); diff --git a/tools/clang/include/clang/AST/Type.h b/tools/clang/include/clang/AST/Type.h index 7f23fd5fdf..a449b6b229 100644 --- a/tools/clang/include/clang/AST/Type.h +++ b/tools/clang/include/clang/AST/Type.h @@ -14,6 +14,7 @@ #ifndef LLVM_CLANG_AST_TYPE_H #define LLVM_CLANG_AST_TYPE_H +#include "dxc/DXIL/DxilConstants.h" #include "clang/AST/NestedNameSpecifier.h" #include "clang/AST/TemplateName.h" #include "clang/Basic/AddressSpaces.h" @@ -1699,8 +1700,15 @@ class Type : public ExtQualsTypeCommonBase { bool isOpenCLSpecificType() const; // Any OpenCL specific type + // HLSL Change Start bool isLinAlgMatrixType() const; // HLSL __builtin_LinAlgMatrix - + bool isAttributedLinAlgMatrixType() + const; // HLSL attributed __builtin_LinAlgMatrix + bool isDependentAttributedLinAlgMatrixType() + const; // HLSL attributed __builtin_LinAlgMatrix with dependent + // parameters + // HLSL Change End + /// Determines if this type, which must satisfy /// isObjCLifetimeType(), is implicitly __unsafe_unretained rather /// than implicitly __strong. @@ -3736,6 +3744,108 @@ class AttributedType : public Type, public llvm::FoldingSetNode { } }; +// HLSL Change Start + +class AttributedLinAlgMatrixType : public Type, public llvm::FoldingSetNode { + friend class ASTContext; // ASTContext creates these + + QualType WrappedType; // should be __builtin_LinAlgMatrix + hlsl::DXIL::ComponentType ComponentTy; + size_t Rows, Cols; + hlsl::DXIL::MatrixUse Use; + hlsl::DXIL::MatrixScope Scope; + + AttributedLinAlgMatrixType(QualType WrappedTy, + hlsl::DXIL::ComponentType ComponentTy, size_t Rows, + size_t Cols, hlsl::DXIL::MatrixUse Use, + hlsl::DXIL::MatrixScope Scope) + : Type(AttributedLinAlgMatrix, QualType(), /*Dependent*/ false, + /*InstantiationDependent*/ false, /*VariablyModified*/ false, + /*ContainsUnexpandedParameterPack*/ false), + WrappedType(WrappedTy), ComponentTy(ComponentTy), Rows(Rows), + Cols(Cols), Use(Use), Scope(Scope) {} + +public: + QualType getWrappedType() const { return WrappedType; } + + hlsl::DXIL::ComponentType getComponentType() const { return ComponentTy; } + size_t getRows() const { return Rows; } + size_t getCols() const { return Cols; } + hlsl::DXIL::MatrixUse getUse() const { return Use; } + hlsl::DXIL::MatrixScope getScope() const { return Scope; } + + void appendMangledAttributes(llvm::raw_ostream &OS) const; + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, WrappedType, ComponentTy, Rows, Cols, Use, Scope); + } + + static void Profile(llvm::FoldingSetNodeID &ID, QualType WrappedTy, + hlsl::DXIL::ComponentType ComponentTy, size_t Rows, + size_t Cols, hlsl::DXIL::MatrixUse Use, + hlsl::DXIL::MatrixScope Scope) { + ID.AddPointer(WrappedTy.getAsOpaquePtr()); + ID.AddInteger(static_cast(ComponentTy)); + ID.AddInteger(static_cast(Rows)); + ID.AddInteger(static_cast(Cols)); + ID.AddInteger(static_cast(Use)); + ID.AddInteger(static_cast(Scope)); + } + + static bool classof(const Type *T) { + return T->getTypeClass() == AttributedLinAlgMatrix; + } +}; + +class DependentAttributedLinAlgMatrixType : public Type, + public llvm::FoldingSetNode { + const ASTContext &Context; + QualType WrappedType; // should be __builtin_LinAlgMatrix + Expr *ComponentTyExpr; + Expr *RowsExpr; + Expr *ColsExpr; + Expr *UseExpr; + Expr *ScopeExpr; + + DependentAttributedLinAlgMatrixType(const ASTContext &Context, + QualType WrappedType, + Expr *ComponentTyExpr, Expr *RowsExpr, + Expr *ColsExpr, Expr *UseExpr, + Expr *ScopeExpr); + + friend class ASTContext; + +public: + QualType getWrappedType() const { return WrappedType; } + Expr *getComponentTyExpr() const { return ComponentTyExpr; } + Expr *getRowsExpr() const { return RowsExpr; } + Expr *getColsExpr() const { return ColsExpr; } + Expr *getUseExpr() const { return UseExpr; } + Expr *getScopeExpr() const { return ScopeExpr; } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + static bool classof(const Type *T) { + return T->getTypeClass() == DependentAttributedLinAlgMatrix; + } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, Context, getWrappedType(), getComponentTyExpr(), getRowsExpr(), + getColsExpr(), getUseExpr(), getScopeExpr()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context, + QualType WrappedType, Expr *ComponentTyExpr, + Expr *RowsExpr, Expr *ColsExpr, Expr *UseExpr, + Expr *ScopeExpr); +}; + +// HLSL Change End + class TemplateTypeParmType : public Type, public llvm::FoldingSetNode { // Helper data collector for canonical types. struct CanonicalTTPTInfo { @@ -5426,6 +5536,14 @@ inline bool Type::isEventT() const { inline bool Type::isLinAlgMatrixType() const { return isSpecificBuiltinType(BuiltinType::LinAlgMatrix); } + +inline bool Type::isAttributedLinAlgMatrixType() const { + return isa(this); +} + +inline bool Type::isDependentAttributedLinAlgMatrixType() const { + return isa(this); +} // HLSL Change Ends inline bool Type::isImageType() const { diff --git a/tools/clang/include/clang/AST/TypeLoc.h b/tools/clang/include/clang/AST/TypeLoc.h index 40d92537df..e8a04cf65b 100644 --- a/tools/clang/include/clang/AST/TypeLoc.h +++ b/tools/clang/include/clang/AST/TypeLoc.h @@ -823,6 +823,17 @@ class AttributedTypeLoc : public ConcreteTypeLoc {}; + +class DependentAttributedLinAlgMatrixTypeLoc + : public InheritingConcreteTypeLoc {}; +// HLSL Change End struct ObjCObjectTypeLocInfo { SourceLocation TypeArgsLAngleLoc; diff --git a/tools/clang/include/clang/AST/TypeNodes.def b/tools/clang/include/clang/AST/TypeNodes.def index 2549f0bf50..f2664e66c1 100644 --- a/tools/clang/include/clang/AST/TypeNodes.def +++ b/tools/clang/include/clang/AST/TypeNodes.def @@ -92,6 +92,10 @@ TYPE(Record, TagType) TYPE(Enum, TagType) NON_CANONICAL_TYPE(Elaborated, Type) NON_CANONICAL_TYPE(Attributed, Type) +// HLSL Change Start +TYPE(AttributedLinAlgMatrix, Type) +DEPENDENT_TYPE(DependentAttributedLinAlgMatrix, Type) +// HLSL Change End DEPENDENT_TYPE(TemplateTypeParm, Type) NON_CANONICAL_TYPE(SubstTemplateTypeParm, Type) DEPENDENT_TYPE(SubstTemplateTypeParmPack, Type) diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 33f1594cac..f220157d52 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -233,6 +233,7 @@ def Borland : LangOpt<"Borland">; def CUDA : LangOpt<"CUDA">; def COnly : LangOpt<"CPlusPlus", 1>; def SPIRV : LangOpt<"SPIRV">; // SPIRV Change +def HLSL : LangOpt<"HLSL">; // HLSL Change // Defines targets for target-specific attributes. The list of strings should // specify architectures for which the target applies, based off the ArchType @@ -1228,6 +1229,13 @@ def HLSLUnboundedSparseNodes : InheritableParamAttr { let Documentation = [Undocumented]; } +def HLSLLinAlgMatrixAttributes : TypeAttr { + let Spellings = [CXX11<"", "__LinAlgMatrix_Attributes", 2015>]; + let LangOpts = [HLSL]; + let Args = [ExprArgument<"ComponentTy">, ExprArgument<"M">, ExprArgument<"N">, + ExprArgument<"Use">, ExprArgument<"Scope">]; + let Documentation = [Undocumented]; +} // HLSL Change Ends // SPIRV Change Starts diff --git a/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td b/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td index 47c545c32b..c80726ed74 100644 --- a/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -8042,6 +8042,19 @@ def err_hlsl_linalg_matrix_dim_must_be_greater_than_zero: Error< def err_hlsl_linalg_matrix_layout_invalid : Error< "matrix layout %0 is not valid, must be in the range [%1, %2]">; +// SM 6.10 Linear Algebra Operations +def err_hlsl_linalg_matrix_attribute_arg_not_int_or_enum + : Error<"argument is not an integer%select{| or enumeration}0">; +def err_hlsl_linalg_matrix_attribute_arg_not_constant_value + : Error<"matrix attributes argument %1 is not a constant value">; +def err_hlsl_linalg_matrix_invalid_enum_attribute_value + : Error<"matrix attribute %0 has invalid value %1, must be in the range " + "[%2, %3]">; +def err_hlsl_linalg_matrix_attribute_on_invalid_type + : Error<"matrix attributes can only be applied to %0">; +def err_hlsl_linalg_attributed_matrix_required + : Error<"argument type must be linear algebra matrix with attributes">; + def err_hlsl_linalg_mul_muladd_output_vector_size_not_equal_to_matrix_M : Error< "output vector length must be equal to Matrix M dimension in a linalg Mul/MulAdd operation">; def err_hlsl_linalg_mul_muladd_unpacked_input_vector_size_not_equal_to_matrix_K : Error< diff --git a/tools/clang/include/clang/Sema/SemaHLSL.h b/tools/clang/include/clang/Sema/SemaHLSL.h index 80ce8ddd7d..1917a9277a 100644 --- a/tools/clang/include/clang/Sema/SemaHLSL.h +++ b/tools/clang/include/clang/Sema/SemaHLSL.h @@ -14,8 +14,10 @@ #ifndef LLVM_CLANG_SEMA_SEMAHLSL_H #define LLVM_CLANG_SEMA_SEMAHLSL_H +#include "dxc/DXIL/DxilConstants.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Attr.h" +#include "clang/AST/Type.h" #include "clang/Sema/Initialization.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" @@ -263,6 +265,20 @@ clang::QualType CheckVectorConditional(clang::Sema *self, clang::ExprResult &LHS, clang::ExprResult &RHS, clang::SourceLocation QuestionLoc); + +bool HandleLinAlgMatrixAttributes(clang::Sema &S, clang::AttributeList &Attr, + clang::QualType &Type); + +bool CreateAttributedLinAlgMatrixType( + clang::Sema &S, clang::QualType WrappedTy, clang::Expr *ComponentTyExpr, + clang::Expr *RowsExpr, clang::Expr *ColsExpr, clang::Expr *UseExpr, + clang::Expr *ScopeExpr, clang::QualType &OutType); + +std::string +ConvertLinAlgMatrixComponentTypeToString(hlsl::DXIL::ComponentType CompType); +std::string ConvertLinAlgMatrixUseToString(hlsl::DXIL::MatrixUse Use); +std::string ConvertLinAlgMatrixScopeToString(hlsl::DXIL::MatrixScope Scope); + } // namespace hlsl bool IsTypeNumeric(clang::Sema *self, clang::QualType &type); diff --git a/tools/clang/include/clang/Serialization/ASTBitCodes.h b/tools/clang/include/clang/Serialization/ASTBitCodes.h index a3d3050077..9afa276efa 100644 --- a/tools/clang/include/clang/Serialization/ASTBitCodes.h +++ b/tools/clang/include/clang/Serialization/ASTBitCodes.h @@ -797,87 +797,89 @@ namespace clang { /// AST. enum TypeCode { /// \brief An ExtQualType record. - TYPE_EXT_QUAL = 1, + TYPE_EXT_QUAL = 1, /// \brief A ComplexType record. - TYPE_COMPLEX = 3, + TYPE_COMPLEX = 3, /// \brief A PointerType record. - TYPE_POINTER = 4, + TYPE_POINTER = 4, /// \brief A BlockPointerType record. - TYPE_BLOCK_POINTER = 5, + TYPE_BLOCK_POINTER = 5, /// \brief An LValueReferenceType record. - TYPE_LVALUE_REFERENCE = 6, + TYPE_LVALUE_REFERENCE = 6, /// \brief An RValueReferenceType record. - TYPE_RVALUE_REFERENCE = 7, + TYPE_RVALUE_REFERENCE = 7, /// \brief A MemberPointerType record. - TYPE_MEMBER_POINTER = 8, + TYPE_MEMBER_POINTER = 8, /// \brief A ConstantArrayType record. - TYPE_CONSTANT_ARRAY = 9, + TYPE_CONSTANT_ARRAY = 9, /// \brief An IncompleteArrayType record. - TYPE_INCOMPLETE_ARRAY = 10, + TYPE_INCOMPLETE_ARRAY = 10, /// \brief A VariableArrayType record. - TYPE_VARIABLE_ARRAY = 11, + TYPE_VARIABLE_ARRAY = 11, /// \brief A VectorType record. - TYPE_VECTOR = 12, + TYPE_VECTOR = 12, /// \brief An ExtVectorType record. - TYPE_EXT_VECTOR = 13, + TYPE_EXT_VECTOR = 13, /// \brief A FunctionNoProtoType record. - TYPE_FUNCTION_NO_PROTO = 14, + TYPE_FUNCTION_NO_PROTO = 14, /// \brief A FunctionProtoType record. - TYPE_FUNCTION_PROTO = 15, + TYPE_FUNCTION_PROTO = 15, /// \brief A TypedefType record. - TYPE_TYPEDEF = 16, + TYPE_TYPEDEF = 16, /// \brief A TypeOfExprType record. - TYPE_TYPEOF_EXPR = 17, + TYPE_TYPEOF_EXPR = 17, /// \brief A TypeOfType record. - TYPE_TYPEOF = 18, + TYPE_TYPEOF = 18, /// \brief A RecordType record. - TYPE_RECORD = 19, + TYPE_RECORD = 19, /// \brief An EnumType record. - TYPE_ENUM = 20, + TYPE_ENUM = 20, /// \brief An ObjCInterfaceType record. - TYPE_OBJC_INTERFACE = 21, + TYPE_OBJC_INTERFACE = 21, /// \brief An ObjCObjectPointerType record. - TYPE_OBJC_OBJECT_POINTER = 22, + TYPE_OBJC_OBJECT_POINTER = 22, /// \brief a DecltypeType record. - TYPE_DECLTYPE = 23, + TYPE_DECLTYPE = 23, /// \brief An ElaboratedType record. - TYPE_ELABORATED = 24, + TYPE_ELABORATED = 24, /// \brief A SubstTemplateTypeParmType record. TYPE_SUBST_TEMPLATE_TYPE_PARM = 25, /// \brief An UnresolvedUsingType record. - TYPE_UNRESOLVED_USING = 26, + TYPE_UNRESOLVED_USING = 26, /// \brief An InjectedClassNameType record. - TYPE_INJECTED_CLASS_NAME = 27, + TYPE_INJECTED_CLASS_NAME = 27, /// \brief An ObjCObjectType record. - TYPE_OBJC_OBJECT = 28, + TYPE_OBJC_OBJECT = 28, /// \brief An TemplateTypeParmType record. - TYPE_TEMPLATE_TYPE_PARM = 29, + TYPE_TEMPLATE_TYPE_PARM = 29, /// \brief An TemplateSpecializationType record. - TYPE_TEMPLATE_SPECIALIZATION = 30, + TYPE_TEMPLATE_SPECIALIZATION = 30, /// \brief A DependentNameType record. - TYPE_DEPENDENT_NAME = 31, + TYPE_DEPENDENT_NAME = 31, /// \brief A DependentTemplateSpecializationType record. TYPE_DEPENDENT_TEMPLATE_SPECIALIZATION = 32, /// \brief A DependentSizedArrayType record. - TYPE_DEPENDENT_SIZED_ARRAY = 33, + TYPE_DEPENDENT_SIZED_ARRAY = 33, /// \brief A ParenType record. - TYPE_PAREN = 34, + TYPE_PAREN = 34, /// \brief A PackExpansionType record. - TYPE_PACK_EXPANSION = 35, + TYPE_PACK_EXPANSION = 35, /// \brief An AttributedType record. - TYPE_ATTRIBUTED = 36, + TYPE_ATTRIBUTED = 36, /// \brief A SubstTemplateTypeParmPackType record. TYPE_SUBST_TEMPLATE_TYPE_PARM_PACK = 37, /// \brief A AutoType record. - TYPE_AUTO = 38, + TYPE_AUTO = 38, /// \brief A UnaryTransformType record. - TYPE_UNARY_TRANSFORM = 39, + TYPE_UNARY_TRANSFORM = 39, /// \brief An AtomicType record. - TYPE_ATOMIC = 40, + TYPE_ATOMIC = 40, /// \brief A DecayedType record. - TYPE_DECAYED = 41, + TYPE_DECAYED = 41, /// \brief An AdjustedType record. - TYPE_ADJUSTED = 42 + TYPE_ADJUSTED = 42, + /// \brief An AttributedLinAlgMatrixType record. + TYPE_ATTRIBUTED_LINALG_MATRIX = 43 }; /// \brief The type IDs for special types constructed by semantic diff --git a/tools/clang/lib/AST/ASTContext.cpp b/tools/clang/lib/AST/ASTContext.cpp index a3a362cfea..dd6ec2784d 100644 --- a/tools/clang/lib/AST/ASTContext.cpp +++ b/tools/clang/lib/AST/ASTContext.cpp @@ -13,6 +13,7 @@ #include "clang/AST/ASTContext.h" #include "CXXABI.h" +#include "dxc/DXIL/DxilConstants.h" #include "clang/AST/ASTMutationListener.h" #include "clang/AST/Attr.h" #include "clang/AST/CharUnits.h" @@ -29,6 +30,7 @@ #include "clang/AST/MangleNumberingContext.h" #include "clang/AST/RecordLayout.h" #include "clang/AST/RecursiveASTVisitor.h" +#include "clang/AST/Type.h" #include "clang/AST/TypeLoc.h" #include "clang/AST/VTableBuilder.h" #include "clang/Basic/Builtins.h" @@ -1867,6 +1869,10 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const { return getTypeInfo( cast(T)->getEquivalentType().getTypePtr()); + case Type::AttributedLinAlgMatrix: + return getTypeInfo( + cast(T)->getWrappedType().getTypePtr()); + case Type::Atomic: { // Start with the base type information. TypeInfo Info = getTypeInfo(cast(T)->getValueType()); @@ -3288,6 +3294,50 @@ QualType ASTContext::getAttributedType(AttributedType::Kind attrKind, return QualType(type, 0); } +// HLSL Change Start +QualType ASTContext::getAttributedLinAlgMatrixType( + QualType WrappedTy, hlsl::DXIL::ComponentType ComponentTy, size_t Rows, + size_t Cols, hlsl::DXIL::MatrixUse Use, hlsl::DXIL::MatrixScope Scope) { + + llvm::FoldingSetNodeID ID; + AttributedLinAlgMatrixType::Profile(ID, WrappedTy, ComponentTy, Rows, Cols, + Use, Scope); + void *InsertPos = nullptr; + AttributedLinAlgMatrixType *Ty = + AttrLinAlgMatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + if (Ty) + return QualType(Ty, 0); + + Ty = new (*this, TypeAlignment) AttributedLinAlgMatrixType( + WrappedTy, ComponentTy, Rows, Cols, Use, Scope); + Types.push_back(Ty); + AttrLinAlgMatrixTypes.InsertNode(Ty, InsertPos); + return QualType(Ty, 0); +} + +QualType ASTContext::getDependentAttributedLinAlgMatrixType( + QualType WrappedTy, Expr *ComponentTyExpr, Expr *RowsExpr, Expr *ColsExpr, + Expr *UseExpr, Expr *ScopeExpr) { + llvm::FoldingSetNodeID ID; + DependentAttributedLinAlgMatrixType::Profile(ID, *this, WrappedTy, + ComponentTyExpr, RowsExpr, + ColsExpr, UseExpr, ScopeExpr); + + void *InsertPos = nullptr; + DependentAttributedLinAlgMatrixType *Ty = + DepAttrLinAlgMatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + if (Ty) + return QualType(Ty, 0); + + Ty = new (*this, TypeAlignment) DependentAttributedLinAlgMatrixType( + *this, WrappedTy, ComponentTyExpr, RowsExpr, ColsExpr, UseExpr, + ScopeExpr); + + Types.push_back(Ty); + DepAttrLinAlgMatrixTypes.InsertNode(Ty, InsertPos); + return QualType(Ty, 0); +} +// HLSL Change Start /// \brief Retrieve a substitution-result type. QualType diff --git a/tools/clang/lib/AST/ASTDumper.cpp b/tools/clang/lib/AST/ASTDumper.cpp index 334542d7f1..2f47576bb0 100644 --- a/tools/clang/lib/AST/ASTDumper.cpp +++ b/tools/clang/lib/AST/ASTDumper.cpp @@ -23,6 +23,7 @@ #include "clang/AST/TypeVisitor.h" #include "clang/Basic/Module.h" #include "clang/Basic/SourceManager.h" +#include "clang/Sema/SemaHLSL.h" #include "llvm/Support/raw_ostream.h" using namespace clang; using namespace clang::comments; @@ -370,6 +371,26 @@ namespace { // FIXME: AttrKind dumpTypeAsChild(T->getModifiedType()); } + void VisitAttributedLinAlgMatrixType(const AttributedLinAlgMatrixType *T) { + dumpTypeAsChild(T->getWrappedType()); + OS << " [[__LinAlgMatrix_Attributes(" + << hlsl::ConvertLinAlgMatrixComponentTypeToString( + T->getComponentType()) + << ", " << T->getRows() << ", " << T->getCols() << ", " + << hlsl::ConvertLinAlgMatrixUseToString(T->getUse()) << ", " + << hlsl::ConvertLinAlgMatrixScopeToString(T->getScope()) << ")]]"; + } + void VisitDependentAttributedLinAlgMatrixType( + const DependentAttributedLinAlgMatrixType *T) { + OS << " "; + dumpTypeAsChild(T->getWrappedType()); + dumpStmt(T->getComponentTyExpr()); + dumpStmt(T->getRowsExpr()); + dumpStmt(T->getColsExpr()); + dumpStmt(T->getUseExpr()); + dumpStmt(T->getScopeExpr()); + } + void VisitTemplateTypeParmType(const TemplateTypeParmType *T) { OS << " depth " << T->getDepth() << " index " << T->getIndex(); if (T->isParameterPack()) OS << " pack"; diff --git a/tools/clang/lib/AST/ASTImporter.cpp b/tools/clang/lib/AST/ASTImporter.cpp index d1f4d832dd..e6bbc7ebb9 100644 --- a/tools/clang/lib/AST/ASTImporter.cpp +++ b/tools/clang/lib/AST/ASTImporter.cpp @@ -70,6 +70,10 @@ namespace clang { QualType VisitRecordType(const RecordType *T); QualType VisitEnumType(const EnumType *T); QualType VisitAttributedType(const AttributedType *T); + // HLSL Change Start + QualType + VisitAttributedLinAlgMatrixType(const AttributedLinAlgMatrixType *T); + // HLSL Change End // FIXME: TemplateTypeParmType // FIXME: SubstTemplateTypeParmType QualType VisitTemplateSpecializationType(const TemplateSpecializationType *T); @@ -644,7 +648,18 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context, cast(T2)->getEquivalentType())) return false; break; - + + // HLSL Change Start + case Type::AttributedLinAlgMatrix: + // FIXMME: Implement structural equivalence for matrix attributes. + llvm_unreachable("FIXME: Implement structural equivalence for HLSL " + "attributed linear algebra matrix type"); + + case Type::DependentAttributedLinAlgMatrix: + // FIXMME: Implement structural equivalence for dependent matrix attributes. + llvm_unreachable("FIXME: Implement structural equivalence for dependent " + "HLSL attributed linear algebra matrix type"); + // HLSL Change End case Type::Paren: if (!IsStructurallyEquivalent(Context, cast(T1)->getInnerType(), @@ -1792,6 +1807,24 @@ QualType ASTNodeImporter::VisitAttributedType(const AttributedType *T) { ToModifiedType, ToEquivalentType); } +// HLSL Change Start +QualType ASTNodeImporter::VisitAttributedLinAlgMatrixType( + const AttributedLinAlgMatrixType *T) { + QualType FromWrappedTy = T->getWrappedType(); + QualType ToWrappedTy; + + if (!FromWrappedTy.isNull()) { + ToWrappedTy = Importer.Import(FromWrappedTy); + if (ToWrappedTy.isNull()) + return QualType(); + } + + return Importer.getToContext().getAttributedLinAlgMatrixType( + ToWrappedTy, T->getComponentType(), T->getRows(), T->getCols(), + T->getUse(), T->getScope()); +} +// HLSL Change End + QualType ASTNodeImporter::VisitTemplateSpecializationType( const TemplateSpecializationType *T) { TemplateName ToTemplate = Importer.Import(T->getTemplateName()); diff --git a/tools/clang/lib/AST/ItaniumMangle.cpp b/tools/clang/lib/AST/ItaniumMangle.cpp index 32a146b633..d4f8f2e186 100644 --- a/tools/clang/lib/AST/ItaniumMangle.cpp +++ b/tools/clang/lib/AST/ItaniumMangle.cpp @@ -1509,6 +1509,10 @@ bool CXXNameMangler::mangleUnresolvedTypeOrSimpleId(QualType Ty, case Type::FunctionNoProto: case Type::Paren: case Type::Attributed: + // HLSL Change Start + case Type::AttributedLinAlgMatrix: + case Type::DependentAttributedLinAlgMatrix: + // HLSL Change End case Type::Auto: case Type::PackExpansion: case Type::ObjCObject: @@ -2576,6 +2580,16 @@ void CXXNameMangler::mangleType(const AtomicType *T) { mangleType(T->getValueType()); } +// HLSL Change Start +void CXXNameMangler::mangleType(const AttributedLinAlgMatrixType *) { + llvm_unreachable("DXC uses Microsoft name mangling"); +} + +void CXXNameMangler::mangleType(const DependentAttributedLinAlgMatrixType *) { + llvm_unreachable("DXC uses Microsoft name mangling"); +} +// HLSL Change End + void CXXNameMangler::mangleIntegerLiteral(QualType T, const llvm::APSInt &Value) { // ::= L E # integer literal diff --git a/tools/clang/lib/AST/MicrosoftMangle.cpp b/tools/clang/lib/AST/MicrosoftMangle.cpp index cf7f0bc2b3..138afb5f0f 100644 --- a/tools/clang/lib/AST/MicrosoftMangle.cpp +++ b/tools/clang/lib/AST/MicrosoftMangle.cpp @@ -2245,6 +2245,23 @@ void MicrosoftCXXNameMangler::mangleType(const AtomicType *T, Qualifiers, << Range; } +void MicrosoftCXXNameMangler::mangleType(const AttributedLinAlgMatrixType *T, + Qualifiers, SourceRange Range) { + Out << "$linalg_matrix"; + T->appendMangledAttributes(Out); + Out << "@"; +} + +void MicrosoftCXXNameMangler::mangleType( + const DependentAttributedLinAlgMatrixType *T, Qualifiers, + SourceRange Range) { + DiagnosticsEngine &Diags = Context.getDiags(); + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, "cannot mangle this dependent-sized HLSL " + "attributed linear algebra matrix type yet"); + Diags.Report(Range.getBegin(), DiagID) << Range; +} + void MicrosoftMangleContextImpl::mangleCXXName(const NamedDecl *D, raw_ostream &Out) { assert((isa(D) || isa(D)) && diff --git a/tools/clang/lib/AST/Type.cpp b/tools/clang/lib/AST/Type.cpp index cbfa49dfd8..a937edae17 100644 --- a/tools/clang/lib/AST/Type.cpp +++ b/tools/clang/lib/AST/Type.cpp @@ -923,6 +923,22 @@ struct SimpleTransformVisitor equivalentType); } + // HLSL Change Start + QualType + VisitAttributedLinAlgMatrixType(const AttributedLinAlgMatrixType *T) { + QualType wrappedTy = recurse(T->getWrappedType()); + if (wrappedTy.isNull()) + return QualType(); + + if (wrappedTy.getAsOpaquePtr() == T->getWrappedType().getAsOpaquePtr()) + return QualType(T, 0); + + return Ctx.getAttributedLinAlgMatrixType(wrappedTy, T->getComponentType(), + T->getRows(), T->getCols(), + T->getUse(), T->getScope()); + } + // HLSL Change End + QualType VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T) { QualType replacementType = recurse(T->getReplacementType()); if (replacementType.isNull()) @@ -1564,6 +1580,10 @@ namespace { AutoType *VisitAttributedType(const AttributedType *T) { return Visit(T->getModifiedType()); } + AutoType * + VisitAttributedLinAlgMatrixType(const AttributedLinAlgMatrixType *T) { + return Visit(T->getWrappedType()); + } AutoType *VisitAdjustedType(const AdjustedType *T) { return Visit(T->getOriginalType()); } @@ -2998,6 +3018,39 @@ bool AttributedType::isCallingConv() const { llvm_unreachable("invalid attr kind"); } +// HLSL Change Start +void AttributedLinAlgMatrixType::appendMangledAttributes( + llvm::raw_ostream &OS) const { + OS << "C" << static_cast(getComponentType()) << "M" + << static_cast(getRows()) << "N" + << static_cast(getCols()) << "U" + << static_cast(getUse()) << "S" + << static_cast(getScope()); +} + +DependentAttributedLinAlgMatrixType::DependentAttributedLinAlgMatrixType( + const ASTContext &Context, QualType WrappedType, Expr *ComponentTyExpr, + Expr *RowsExpr, Expr *ColsExpr, Expr *UseExpr, Expr *ScopeExpr) + : Type(DependentAttributedLinAlgMatrix, QualType(), /*Dependent=*/true, + /*InstantiationDependent=*/true, /*VariablyModified*/ false, + /*ContainsUnexpandedParameterPack=*/false), + Context(Context), WrappedType(WrappedType), + ComponentTyExpr(ComponentTyExpr), RowsExpr(RowsExpr), ColsExpr(ColsExpr), + UseExpr(UseExpr), ScopeExpr(ScopeExpr) {} + +void DependentAttributedLinAlgMatrixType::Profile( + llvm::FoldingSetNodeID &ID, const ASTContext &Context, QualType WrappedType, + Expr *ComponentTyExpr, Expr *RowsExpr, Expr *ColsExpr, Expr *UseExpr, + Expr *ScopeExpr) { + ID.AddPointer(WrappedType.getAsOpaquePtr()); + ComponentTyExpr->Profile(ID, Context, true); + RowsExpr->Profile(ID, Context, true); + ColsExpr->Profile(ID, Context, true); + UseExpr->Profile(ID, Context, true); + ScopeExpr->Profile(ID, Context, true); +} +// HLSL Change End + CXXRecordDecl *InjectedClassNameType::getDecl() const { return cast(getInterestingTagDecl(Decl)); } @@ -3302,6 +3355,8 @@ static CachedProperties computeCachedProperties(const Type *T) { return Cache::get(cast(T)->getPointeeType()); case Type::Atomic: return Cache::get(cast(T)->getValueType()); + case Type::AttributedLinAlgMatrix: + return Cache::get(cast(T)->getWrappedType()); } llvm_unreachable("unhandled type class"); @@ -3527,6 +3582,8 @@ bool Type::canHaveNullability() const { case Type::FunctionNoProto: case Type::Record: case Type::Enum: + case Type::AttributedLinAlgMatrix: + case Type::DependentAttributedLinAlgMatrix: case Type::InjectedClassName: case Type::PackExpansion: case Type::ObjCObject: diff --git a/tools/clang/lib/AST/TypePrinter.cpp b/tools/clang/lib/AST/TypePrinter.cpp index ca9e15bfd7..5ee24a7749 100644 --- a/tools/clang/lib/AST/TypePrinter.cpp +++ b/tools/clang/lib/AST/TypePrinter.cpp @@ -11,15 +11,16 @@ // //===----------------------------------------------------------------------===// -#include "clang/AST/PrettyPrinter.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" #include "clang/AST/DeclObjC.h" #include "clang/AST/DeclTemplate.h" #include "clang/AST/Expr.h" +#include "clang/AST/PrettyPrinter.h" #include "clang/AST/Type.h" #include "clang/Basic/LangOptions.h" #include "clang/Basic/SourceManager.h" +#include "clang/Sema/SemaHLSL.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/SaveAndRestore.h" @@ -244,6 +245,7 @@ bool TypePrinter::canPrefixQualifiers(const Type *T, case Type::FunctionNoProto: case Type::Paren: case Type::Attributed: + case Type::AttributedLinAlgMatrix: // HLSL Change case Type::PackExpansion: case Type::SubstTemplateTypeParm: CanPrefixQualifiers = false; @@ -1363,6 +1365,44 @@ void TypePrinter::printAttributedAfter(const AttributedType *T, OS << "))"; } +// HLSL Change Start +void TypePrinter::printAttributedLinAlgMatrixAfter( + const AttributedLinAlgMatrixType *T, raw_ostream &OS) { + OS << " [[__LinAlgMatrix_Attributes(" + << hlsl::ConvertLinAlgMatrixComponentTypeToString(T->getComponentType()) + << ", " << T->getRows() << ", " << T->getCols() << ", " + << hlsl::ConvertLinAlgMatrixUseToString(T->getUse()) << ", " + << hlsl::ConvertLinAlgMatrixScopeToString(T->getScope()) << ")]]"; + printAfter(T->getWrappedType(), OS); +} + +void TypePrinter::printAttributedLinAlgMatrixBefore( + const AttributedLinAlgMatrixType *T, raw_ostream &OS) { + printBefore(T->getWrappedType(), OS); +} + +void TypePrinter::printDependentAttributedLinAlgMatrixBefore( + const DependentAttributedLinAlgMatrixType *T, raw_ostream &OS) { + printBefore(T->getWrappedType(), OS); +} + +void TypePrinter::printDependentAttributedLinAlgMatrixAfter( + const DependentAttributedLinAlgMatrixType *T, raw_ostream &OS) { + OS << " [[__LinAlgMatrix_Attributes("; + T->getComponentTyExpr()->printPretty(OS, nullptr, Policy); + OS << ", "; + T->getRowsExpr()->printPretty(OS, nullptr, Policy); + OS << ", "; + T->getColsExpr()->printPretty(OS, nullptr, Policy); + OS << ", "; + T->getUseExpr()->printPretty(OS, nullptr, Policy); + OS << ", "; + T->getScopeExpr()->printPretty(OS, nullptr, Policy); + OS << ")]]"; + printAfter(T->getWrappedType(), OS); +} +// HLSL Change End + void TypePrinter::printObjCInterfaceBefore(const ObjCInterfaceType *T, raw_ostream &OS) { OS << T->getDecl()->getName(); diff --git a/tools/clang/lib/Headers/hlsl/dx/linalg.h b/tools/clang/lib/Headers/hlsl/dx/linalg.h index 4f5e62070d..d916401a45 100644 --- a/tools/clang/lib/Headers/hlsl/dx/linalg.h +++ b/tools/clang/lib/Headers/hlsl/dx/linalg.h @@ -196,3 +196,79 @@ void VectorAccumulate(vector InputVector, } // namespace dx #endif // SM 6.9 check and HV version check + +#if ((__SHADER_TARGET_MAJOR > 6) || \ + (__SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR >= 10)) && \ + (__HLSL_VERSION >= 2021) + +namespace hlsl { + +#ifdef __hlsl_dx_compiler +#define SIZE_TYPE int +#else +#define SIZE_TYPE uint +#endif + +} // namespace hlsl + +namespace dx { + +namespace linalg { + +struct ComponentType { + enum ComponentEnum { + Invalid = 0, + I1 = 1, + I16 = 2, + U16 = 3, + I32 = 4, + U32 = 5, + I64 = 6, + U64 = 7, + F16 = 8, + F32 = 9, + F64 = 10, + SNormF16 = 11, + UNormF16 = 12, + SNormF32 = 13, + UNormF32 = 14, + SNormF64 = 15, + UNormF64 = 16, + PackedS8x32 = 17, + PackedU8x32 = 18, + U8 = 19, + I8 = 20, + F8_E4M3 = 21, + F8_E5M2 = 22, + }; +}; +using ComponentEnum = ComponentType::ComponentEnum; + +struct MatrixUse { + enum MatrixUseEnum { A = 0, B = 1, Accumulator = 2 }; +}; +using MatrixUseEnum = MatrixUse::MatrixUseEnum; + +struct MatrixScope { + enum MatrixScopeEnum { + Thread = 0, + Wave = 1, + ThreadGroup = 2, + }; +}; +using MatrixScopeEnum = MatrixScope::MatrixScopeEnum; + +template +class Matrix { + int TestField; + using HandleT = __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(ComponentTy, M, N, Use, Scope)]]; + HandleT __handle; +}; + +} // namespace linalg + +} // namespace dx + +#endif // SM 6.10 check and HV version check diff --git a/tools/clang/lib/Parse/ParseDecl.cpp b/tools/clang/lib/Parse/ParseDecl.cpp index eed90d0a78..23d86923c2 100644 --- a/tools/clang/lib/Parse/ParseDecl.cpp +++ b/tools/clang/lib/Parse/ParseDecl.cpp @@ -3356,11 +3356,20 @@ void Parser::ParseDeclarationSpecifiers(DeclSpec &DS, if (!AttrsLastTime) ProhibitAttributes(attrs); else { - // Reject C++11 attributes that appertain to decl specifiers as - // we don't support any C++11 attributes that appertain to decl - // specifiers. This also conforms to what g++ 4.8 is doing. - ProhibitCXX11Attributes(attrs); - + // HLSL Change Start + // Reject attributes that aren't type attributes. Unknown attributes + // are diagnosed elsewhere. + AttributeList *Attr = attrs.getList(); + while (Attr) { + if (!Attr->isTypeAttr() && + Attr->getKind() != AttributeList::UnknownAttribute) { + Diag(Attr->getLoc(), diag::err_attribute_not_type_attr) + << Attr->getName(); + Attr->setInvalid(); + } + Attr = Attr->getNext(); + } + // HLSL Change End DS.takeAttributesFrom(attrs); } @@ -3371,7 +3380,7 @@ void Parser::ParseDeclarationSpecifiers(DeclSpec &DS, case tok::l_square: case tok::kw_alignas: - if (!getLangOpts().CPlusPlus11 || !isCXX11AttributeSpecifier()) + if (!isCXX11AttributeSpecifier()) // HLSL Change goto DoneWithDeclSpec; ProhibitAttributes(attrs); diff --git a/tools/clang/lib/Parse/ParseDeclCXX.cpp b/tools/clang/lib/Parse/ParseDeclCXX.cpp index de46dfdb43..8412987a39 100644 --- a/tools/clang/lib/Parse/ParseDeclCXX.cpp +++ b/tools/clang/lib/Parse/ParseDeclCXX.cpp @@ -3741,6 +3741,19 @@ static bool IsBuiltInOrStandardCXX11Attribute(IdentifierInfo *AttrName, } } +// HLSL Change Start +static bool hasCXXAttributeInHLSL(IdentifierInfo *ScopeName, + IdentifierInfo *AttrName) { + if (ScopeName && ScopeName->getName() != "") + return false; + + StringRef Name = AttrName->getName(); + return llvm::StringSwitch(Name) + .Case("__LinAlgMatrix_Attributes", true) + .Default(false); +} +// HLSL Change End + /// ParseCXX11AttributeArgs -- Parse a C++11 attribute-argument-clause. /// /// [C++11] attribute-argument-clause: @@ -3767,7 +3780,8 @@ bool Parser::ParseCXX11AttributeArgs(IdentifierInfo *AttrName, // If the attribute isn't known, we will not attempt to parse any // arguments. if (!hasAttribute(AttrSyntax::CXX, ScopeName, AttrName, - getTargetInfo().getTriple(), getLangOpts())) { + getTargetInfo().getTriple(), getLangOpts()) && + !hasCXXAttributeInHLSL(ScopeName, AttrName)) { // HLSL Change // Eat the left paren, then skip to the ending right paren. ConsumeParen(); SkipUntil(tok::r_paren); diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index d83632e897..5bb4b2a1db 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -31,6 +31,7 @@ #include "clang/AST/ExprCXX.h" #include "clang/AST/ExternalASTSource.h" #include "clang/AST/HlslTypes.h" +#include "clang/AST/Type.h" #include "clang/AST/TypeLoc.h" #include "clang/Basic/Diagnostic.h" #include "clang/Basic/Specifiers.h" @@ -254,6 +255,9 @@ enum ArBasicKind { // Shader Execution Reordering AR_OBJECT_HIT_OBJECT, + // Linear Algebra + AR_OBJECT_LINALG_MATRIX, + AR_BASIC_MAXIMUM_COUNT }; @@ -609,6 +613,9 @@ const UINT g_uBasicKindProps[] = { // Shader Execution Reordering LICOMPTYPE_HIT_OBJECT, // AR_OBJECT_HIT_OBJECT, + // Linear Algebra + LICOMPTYPE_LINALG_MATRIX, // AR_OBJECT_LINALG_MATRIX, + // AR_BASIC_MAXIMUM_COUNT }; @@ -671,18 +678,19 @@ enum ArTypeObjectKind { AR_TOBJ_INVALID, // Flag for an unassigned / unavailable object type. AR_TOBJ_VOID, // Represents the type for functions with not returned valued. AR_TOBJ_BASIC, // Represents a primitive type. - AR_TOBJ_COMPOUND, // Represents a struct or class. - AR_TOBJ_INTERFACE, // Represents an interface. - AR_TOBJ_POINTER, // Represents a pointer to another type. - AR_TOBJ_OBJECT, // Represents a built-in object. - AR_TOBJ_ARRAY, // Represents an array of other types. - AR_TOBJ_MATRIX, // Represents a matrix of basic types. - AR_TOBJ_VECTOR, // Represents a vector of basic types. - AR_TOBJ_QUALIFIER, // Represents another type plus an ArTypeQualifier. - AR_TOBJ_INNER_OBJ, // Represents a built-in inner object, such as an - // indexer object used to implement .mips[1]. - AR_TOBJ_STRING, // Represents a string - AR_TOBJ_DEPENDENT, // Dependent type for template. + AR_TOBJ_COMPOUND, // Represents a struct or class. + AR_TOBJ_INTERFACE, // Represents an interface. + AR_TOBJ_POINTER, // Represents a pointer to another type. + AR_TOBJ_OBJECT, // Represents a built-in object. + AR_TOBJ_ARRAY, // Represents an array of other types. + AR_TOBJ_MATRIX, // Represents a matrix of basic types. + AR_TOBJ_VECTOR, // Represents a vector of basic types. + AR_TOBJ_QUALIFIER, // Represents another type plus an ArTypeQualifier. + AR_TOBJ_INNER_OBJ, // Represents a built-in inner object, such as an + // indexer object used to implement .mips[1]. + AR_TOBJ_STRING, // Represents a string + AR_TOBJ_DEPENDENT, // Dependent type for template. + AR_TOBJ_LINALG_MATRIX // LinAlg Matric type }; enum TYPE_CONVERSION_FLAGS { @@ -1252,6 +1260,10 @@ static const ArBasicKind g_AnyOutputRecordCT[] = { static const ArBasicKind g_DxHitObjectCT[] = {AR_OBJECT_HIT_OBJECT, AR_BASIC_UNKNOWN}; +// Linear Algebra +static const ArBasicKind g_LinAlgMatrixCT[] = {AR_OBJECT_LINALG_MATRIX, + AR_BASIC_UNKNOWN}; + #ifdef ENABLE_SPIRV_CODEGEN static const ArBasicKind g_VKBufferPointerCT[] = {AR_OBJECT_VK_BUFFER_POINTER, AR_BASIC_UNKNOWN}; @@ -1313,6 +1325,7 @@ const ArBasicKind *g_LegalIntrinsicCompTypes[] = { g_ThreadNodeOutputRecordsCT, // LICOMPTYPE_THREAD_NODE_OUTPUT_RECORDS g_DxHitObjectCT, // LICOMPTYPE_HIT_OBJECT g_RayQueryCT, // LICOMPTYPE_RAY_QUERY + g_LinAlgMatrixCT, // LICOMPTYPE_LINALG_MATRIX g_LinAlgCT, // LICOMPTYPE_LINALG g_BuiltInTrianglePositionsCT, // LICOMPTYPE_BUILTIN_TRIANGLE_POSITIONS #ifdef ENABLE_SPIRV_CODEGEN @@ -1411,7 +1424,10 @@ static const ArBasicKind g_ArBasicKindsAsTypes[] = { AR_OBJECT_THREAD_NODE_OUTPUT_RECORDS, AR_OBJECT_GROUP_NODE_OUTPUT_RECORDS, // Shader Execution Reordering - AR_OBJECT_HIT_OBJECT}; + AR_OBJECT_HIT_OBJECT, + + // LinAlg Matrix + AR_OBJECT_LINALG_MATRIX}; // Count of template arguments for basic kind of objects that look like // templates (one or more type arguments). @@ -1532,6 +1548,9 @@ static const uint8_t g_ArBasicKindsTemplateCount[] = { // Shader Execution Reordering 0, // AR_OBJECT_HIT_OBJECT, + + // LinAlg Matrix + 0, // AR_OBJECT_LINALG_MATRIX, }; C_ASSERT(_countof(g_ArBasicKindsAsTypes) == @@ -1683,6 +1702,9 @@ static const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] = { // Shader Execution Reordering {0, MipsFalse, SampleFalse}, // AR_OBJECT_HIT_OBJECT, + + // LinAlg Matrix + {0, MipsFalse, SampleFalse}, // AR_OBJECT_LINALG_MATRIX }; C_ASSERT(_countof(g_ArBasicKindsAsTypes) == _countof(g_ArBasicKindsSubscripts)); @@ -1851,6 +1873,9 @@ static const char *g_ArBasicTypeNames[] = { // Shader Execution Reordering "HitObject", + + // LinAlg Matrix + "__builtin_LinAlgMatrix", }; C_ASSERT(_countof(g_ArBasicTypeNames) == AR_BASIC_MAXIMUM_COUNT); @@ -3640,6 +3665,9 @@ class HLSLExternalSource : public ExternalSemaSource { case LICOMPTYPE_HIT_OBJECT: paramTypes.push_back(GetBasicKindType(AR_OBJECT_HIT_OBJECT)); break; + case LICOMPTYPE_LINALG_MATRIX: + paramTypes.push_back(GetBasicKindType(AR_OBJECT_LINALG_MATRIX)); + break; case LICOMPTYPE_BUILTIN_TRIANGLE_POSITIONS: paramTypes.push_back( GetBasicKindType(AR_OBJECT_BUILTIN_TRIANGLE_POSITIONS)); @@ -4502,6 +4530,10 @@ class HLSLExternalSource : public ExternalSemaSource { if (type->isPointerType()) { return hlsl::IsPointerStringType(type) ? AR_TOBJ_STRING : AR_TOBJ_POINTER; } + if (type->isAttributedLinAlgMatrixType() || + type->isDependentAttributedLinAlgMatrixType()) { + return AR_TOBJ_LINALG_MATRIX; + } if (type->isDependentType()) { return AR_TOBJ_DEPENDENT; } @@ -4934,6 +4966,9 @@ class HLSLExternalSource : public ExternalSemaSource { return m_context->getTagDeclType(this->m_objectTypeDecls[index]); } + case AR_OBJECT_LINALG_MATRIX: + return m_context->LinAlgMatrixTy; + case AR_OBJECT_SAMPLER1D: case AR_OBJECT_SAMPLER2D: case AR_OBJECT_SAMPLER3D: @@ -6794,6 +6829,17 @@ bool HLSLExternalSource::MatchArguments( return false; } + if (pIntrinsicArg->uLegalComponentTypes == LICOMPTYPE_LINALG_MATRIX) { + if (TypeInfoShapeKind == AR_TOBJ_LINALG_MATRIX) { + ++iArg; + continue; + } + m_sema->Diag(pCallArg->getExprLoc(), + diag::err_hlsl_linalg_attributed_matrix_required); + badArgIdx = iArg; + return false; + } + ASTContext &actx = m_sema->getASTContext(); // Usage @@ -13663,7 +13709,8 @@ bool FlattenedTypeIterator::considerLeaf() { ArTypeObjectKind objectKind = m_source.GetTypeObjectKind(tracker.Type); if (objectKind != ArTypeObjectKind::AR_TOBJ_BASIC && objectKind != ArTypeObjectKind::AR_TOBJ_OBJECT && - objectKind != ArTypeObjectKind::AR_TOBJ_STRING) { + objectKind != ArTypeObjectKind::AR_TOBJ_STRING && + objectKind != ArTypeObjectKind::AR_TOBJ_LINALG_MATRIX) { if (pushTrackerForType(tracker.Type, tracker.CurrentExpr)) { result = considerLeaf(); } @@ -13885,6 +13932,11 @@ bool FlattenedTypeIterator::pushTrackerForType( type.getCanonicalType(), 1, expression)); return true; } + case ArTypeObjectKind::AR_TOBJ_LINALG_MATRIX: { + m_typeTrackers.push_back(FlattenedTypeIterator::FlattenedTypeTracker( + type.getCanonicalType(), 1, expression)); + return true; + } default: DXASSERT(false, "unreachable"); return false; @@ -17775,4 +17827,256 @@ void DiagnoseEntry(Sema &S, FunctionDecl *FD) { } } } + +// Returns false on error +static bool verifyLinAlgMatrixSizeArg(Sema &S, Expr *Arg, bool &IsDependent, + size_t &OutValue) { + QualType QT = Arg->getType(); + + // Check that the type is an integer type. + if (!QT->isIntegerType()) { + S.Diag(Arg->getExprLoc(), + diag::err_hlsl_linalg_matrix_attribute_arg_not_int_or_enum) + << 0 << Arg->getSourceRange(); + return false; + } + + // That's all we can do for dependent expressions. + if (Arg->isValueDependent()) { + IsDependent = true; + return true; + } + + // Check that it is a constant value. + llvm::APSInt APVal; + if (!Arg->isIntegerConstantExpr(APVal, S.Context)) { + S.Diag(Arg->getExprLoc(), + diag::err_hlsl_linalg_matrix_attribute_arg_not_constant_value) + << Arg << Arg->getSourceRange(); + return false; + } + + // Check that the value is a valid range. + int64_t Value = APVal.getLimitedValue(); + if (Value < 0) { + S.Diag(Arg->getExprLoc(), + diag::err_hlsl_linalg_matrix_dim_must_be_greater_than_zero) + << Arg->getSourceRange(); + return false; + } + + OutValue = (size_t)Value; + return true; +} + +// Returns false on error +template +static bool verifyLinAlgMatrixEnumArg(Sema &S, Expr *Arg, const char *EnumName, + unsigned MinValue, unsigned MaxValue, + bool &IsDependent, EnumT &OutValue) { + QualType QT = Arg->getType(); + + // Check that the type is an integer or enumeration type. + if (!QT->isIntegralOrEnumerationType()) { + S.Diag(Arg->getExprLoc(), + diag::err_hlsl_linalg_matrix_attribute_arg_not_int_or_enum) + << 1 << Arg->getSourceRange(); + return false; + } + + // That's all we can do for dependent expressions. + if (Arg->isValueDependent()) { + IsDependent = true; + return true; + } + + // Check that it is a constant value. + llvm::APSInt APVal; + if (!Arg->isIntegerConstantExpr(APVal, S.Context)) { + S.Diag(Arg->getExprLoc(), + diag::err_hlsl_linalg_matrix_attribute_arg_not_constant_value) + << Arg->getSourceRange(); + return false; + } + + // Check that the value is a valid range. + int64_t Value = APVal.getLimitedValue(); + if (Value < (int64_t)MinValue || Value > (int64_t)MaxValue) { + S.Diags.Report(Arg->getExprLoc(), + diag::err_hlsl_linalg_matrix_invalid_enum_attribute_value) + << EnumName << std::to_string(Value) << std::to_string(MinValue) + << std::to_string(MaxValue); + return false; + } + + OutValue = (EnumT)Value; + return true; +} + +// Returns false on error +bool CreateAttributedLinAlgMatrixType( + clang::Sema &S, clang::QualType WrappedTy, clang::Expr *ComponentTyExpr, + clang::Expr *RowsExpr, clang::Expr *ColsExpr, clang::Expr *UseExpr, + clang::Expr *ScopeExpr, clang::QualType &OutType) { + + bool IsDependent = false; + + // Verify component type argument. + hlsl::DXIL::ComponentType CompTyValue = hlsl::DXIL::ComponentType::Invalid; + if (!verifyLinAlgMatrixEnumArg( + S, ComponentTyExpr, "ComponentEnum", + static_cast(hlsl::DXIL::ComponentType::I1), + static_cast(hlsl::DXIL::ComponentType::LastEntry) - 1, + IsDependent, CompTyValue)) + return false; + + // Verify size arguments + size_t RowsValue = 0; + size_t ColsValue = 0; + if (!verifyLinAlgMatrixSizeArg(S, RowsExpr, IsDependent, RowsValue) || + !verifyLinAlgMatrixSizeArg(S, ColsExpr, IsDependent, ColsValue)) + return false; + + // Verify matrix Use argument. + hlsl::DXIL::MatrixUse UseValue = hlsl::DXIL::MatrixUse::A; + if (!verifyLinAlgMatrixEnumArg( + S, UseExpr, "MatrixUseEnum", + static_cast(hlsl::DXIL::MatrixUse::A), + static_cast(hlsl::DXIL::MatrixUse::Accumulator), + IsDependent, UseValue)) + return false; + + // Verify matrix Scope argument. + hlsl::DXIL::MatrixScope ScopeValue = hlsl::DXIL::MatrixScope::Thread; + if (!verifyLinAlgMatrixEnumArg( + S, ScopeExpr, "MatrixScopeEnum", + static_cast(hlsl::DXIL::MatrixScope::Thread), + static_cast(hlsl::DXIL::MatrixScope::ThreadGroup), + IsDependent, ScopeValue)) + return false; + + // Create one of tyhe two LinAlg Matrix attributed types based on whether + // it has dependent attributes or not. + if (IsDependent) + OutType = S.Context.getDependentAttributedLinAlgMatrixType( + WrappedTy, ComponentTyExpr, RowsExpr, ColsExpr, UseExpr, ScopeExpr); + else + OutType = S.Context.getAttributedLinAlgMatrixType( + WrappedTy, CompTyValue, RowsValue, ColsValue, UseValue, ScopeValue); + return true; +} + +// Returns true on error +bool HandleLinAlgMatrixAttributes(clang::Sema &S, clang::AttributeList &Attr, + clang::QualType &Type) { + + assert(Attr.getKind() == AttributeList::AT_HLSLLinAlgMatrixAttributes && + "unexpected attribute"); + + QualType CanonTy = Type.getCanonicalType(); + if (!CanonTy->isLinAlgMatrixType()) { + const auto *LinAlgMTy = cast(S.getASTContext().LinAlgMatrixTy); + PrintingPolicy PP(S.getLangOpts()); + S.Diag(Attr.getLoc(), + diag::err_hlsl_linalg_matrix_attribute_on_invalid_type) + << LinAlgMTy->getName(PP) << Attr.getLoc(); + return true; + } + + if (Attr.getNumArgs() != 5) { + S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments) + << Attr.getName() << 5; + Attr.setInvalid(); + return true; + } + + QualType ResultType; + if (!CreateAttributedLinAlgMatrixType( + S, CanonTy, Attr.getArgAsExpr(0), Attr.getArgAsExpr(1), + Attr.getArgAsExpr(2), Attr.getArgAsExpr(3), Attr.getArgAsExpr(4), + ResultType)) + return true; + + Type = ResultType; + return false; +} + +std::string +ConvertLinAlgMatrixComponentTypeToString(hlsl::DXIL::ComponentType CompType) { + switch (CompType) { + case DXIL::ComponentType::I1: + return "ComponentType::I1"; + case DXIL::ComponentType::I16: + return "ComponentType::I16"; + case DXIL::ComponentType::U16: + return "ComponentType::U16"; + case DXIL::ComponentType::I32: + return "ComponentType::I32"; + case DXIL::ComponentType::U32: + return "ComponentType::U32"; + case DXIL::ComponentType::I64: + return "ComponentType::I64"; + case DXIL::ComponentType::U64: + return "ComponentType::U64"; + case DXIL::ComponentType::F16: + return "ComponentType::F16"; + case DXIL::ComponentType::F32: + return "ComponentType::F32"; + case DXIL::ComponentType::F64: + return "ComponentType::F64"; + case DXIL::ComponentType::SNormF16: + return "ComponentType::SNormF16"; + case DXIL::ComponentType::UNormF16: + return "ComponentType::UNormF16"; + case DXIL::ComponentType::SNormF32: + return "ComponentType::SNormF32"; + case DXIL::ComponentType::UNormF32: + return "ComponentType::UNormF32"; + case DXIL::ComponentType::SNormF64: + return "ComponentType::SNormF64"; + case DXIL::ComponentType::UNormF64: + return "ComponentType::UNormF64"; + case DXIL::ComponentType::PackedS8x32: + return "ComponentType::PackedS8x32"; + case DXIL::ComponentType::PackedU8x32: + return "ComponentType::PackedU8x32"; + case DXIL::ComponentType::U8: + return "ComponentType::U8"; + case DXIL::ComponentType::I8: + return "ComponentType::I8"; + case DXIL::ComponentType::F8_E4M3: + return "ComponentType::F8_E4M3"; + case DXIL::ComponentType::F8_E5M2: + return "ComponentType::F8_E5M2"; + default: + llvm_unreachable("Unknown ComponentType"); + } +} + +std::string ConvertLinAlgMatrixUseToString(hlsl::DXIL::MatrixUse Use) { + switch (Use) { + case hlsl::DXIL::MatrixUse::A: + return "MatrixUse::A"; + case hlsl::DXIL::MatrixUse::B: + return "MatrixUse::B"; + case hlsl::DXIL::MatrixUse::Accumulator: + return "MatrixUse::Accumulator"; + default: + llvm_unreachable("Unknown MatrixUse"); + } +} + +std::string ConvertLinAlgMatrixScopeToString(hlsl::DXIL::MatrixScope Scope) { + switch (Scope) { + case hlsl::DXIL::MatrixScope::Thread: + return "MatrixScope::Thread"; + case hlsl::DXIL::MatrixScope::ThreadGroup: + return "MatrixScope::ThreadGroup"; + case hlsl::DXIL::MatrixScope::Wave: + return "MatrixScope::Wave"; + default: + llvm_unreachable("Unknown MatrixScope"); + } +} + } // namespace hlsl diff --git a/tools/clang/lib/Sema/SemaTemplate.cpp b/tools/clang/lib/Sema/SemaTemplate.cpp index 0edc1698a5..37b296aefd 100644 --- a/tools/clang/lib/Sema/SemaTemplate.cpp +++ b/tools/clang/lib/Sema/SemaTemplate.cpp @@ -4136,6 +4136,18 @@ bool UnnamedLocalNoLinkageFinder::VisitEnumType(const EnumType* T) { return VisitTagDecl(T->getDecl()); } +// HLSL Change Start +bool UnnamedLocalNoLinkageFinder::VisitAttributedLinAlgMatrixType( + const AttributedLinAlgMatrixType *) { + return false; +} + +bool UnnamedLocalNoLinkageFinder::VisitDependentAttributedLinAlgMatrixType( + const DependentAttributedLinAlgMatrixType *) { + return false; +} +// HLSL Change End + bool UnnamedLocalNoLinkageFinder::VisitTemplateTypeParmType( const TemplateTypeParmType*) { return false; diff --git a/tools/clang/lib/Sema/SemaType.cpp b/tools/clang/lib/Sema/SemaType.cpp index f08ae486b5..52c55df6ae 100644 --- a/tools/clang/lib/Sema/SemaType.cpp +++ b/tools/clang/lib/Sema/SemaType.cpp @@ -5774,6 +5774,7 @@ static bool isHLSLTypeAttr(AttributeList::Kind Kind) { case AttributeList::AT_HLSLUnorm: case AttributeList::AT_HLSLGloballyCoherent: case AttributeList::AT_HLSLReorderCoherent: + case AttributeList::AT_HLSLLinAlgMatrixAttributes: return true; default: // Only meant to catch attr handled by handleHLSLTypeAttr, ignore the rest @@ -5781,13 +5782,16 @@ static bool isHLSLTypeAttr(AttributeList::Kind Kind) { } } +// Return true on error static bool handleHLSLTypeAttr(TypeProcessingState &State, AttributeList &Attr, QualType &Type) { - // Return true on error + AttributeList::Kind Kind = Attr.getKind(); Sema &S = State.getSema(); - AttributeList::Kind Kind = Attr.getKind(); + // Handle LinAlg Matrix type attribute separatelly. + if (Kind == AttributeList::AT_HLSLLinAlgMatrixAttributes) + return hlsl::HandleLinAlgMatrixAttributes(S, Attr, Type); // Check for attributes on incorrect types if ((Kind == AttributeList::AT_HLSLColumnMajor || diff --git a/tools/clang/lib/Sema/TreeTransform.h b/tools/clang/lib/Sema/TreeTransform.h index 7b23823f72..2d2e692cd4 100644 --- a/tools/clang/lib/Sema/TreeTransform.h +++ b/tools/clang/lib/Sema/TreeTransform.h @@ -5501,6 +5501,81 @@ QualType TreeTransform::TransformAttributedType( return result; } +// HLSL Change Start +template +QualType TreeTransform::TransformAttributedLinAlgMatrixType( + TypeLocBuilder &TLB, AttributedLinAlgMatrixTypeLoc TL) { + const AttributedLinAlgMatrixType *OldTy = TL.getTypePtr(); + QualType ModifiedTy = getDerived().TransformType(OldTy->getWrappedType()); + if (ModifiedTy.isNull()) + return QualType(); + + QualType Result = TL.getType(); + if (getDerived().AlwaysRebuild() || ModifiedTy != OldTy->getWrappedType()) { + Result = SemaRef.Context.getAttributedLinAlgMatrixType( + ModifiedTy, OldTy->getComponentType(), OldTy->getRows(), + OldTy->getCols(), OldTy->getUse(), OldTy->getScope()); + } + + AttributedLinAlgMatrixTypeLoc NewTL = + TLB.push(Result); + NewTL.setNameLoc(TL.getNameLoc()); + + return Result; +} + +template +QualType TreeTransform::TransformDependentAttributedLinAlgMatrixType( + TypeLocBuilder &TLB, DependentAttributedLinAlgMatrixTypeLoc TL) { + const DependentAttributedLinAlgMatrixType *T = TL.getTypePtr(); + + QualType NewWrappedType = getDerived().TransformType(T->getWrappedType()); + if (NewWrappedType.isNull()) + return QualType(); + + ExprResult CompTy = getDerived().TransformExpr(T->getComponentTyExpr()); + ExprResult Rows = getDerived().TransformExpr(T->getRowsExpr()); + ExprResult Cols = getDerived().TransformExpr(T->getColsExpr()); + ExprResult Use = getDerived().TransformExpr(T->getUseExpr()); + ExprResult Scope = getDerived().TransformExpr(T->getScopeExpr()); + + if (CompTy.isInvalid() || Rows.isInvalid() || Cols.isInvalid() || + Use.isInvalid() || Scope.isInvalid()) + return QualType(); + + Expr *NewCompTyExpr = CompTy.get(); + Expr *NewRowsExpr = Rows.get(); + Expr *NewColsExpr = Cols.get(); + Expr *NewUseExpr = Use.get(); + Expr *NewScopeExpr = Scope.get(); + + QualType Result = TL.getType(); + if (getDerived().AlwaysRebuild() || NewWrappedType != T->getWrappedType() || + NewCompTyExpr != T->getComponentTyExpr() || + NewRowsExpr != T->getRowsExpr() || NewColsExpr != T->getColsExpr() || + NewUseExpr != T->getUseExpr() || NewScopeExpr != T->getScopeExpr()) { + if (!hlsl::CreateAttributedLinAlgMatrixType( + SemaRef, NewWrappedType, NewCompTyExpr, NewRowsExpr, NewColsExpr, + NewUseExpr, NewScopeExpr, Result)) + return QualType(); + } + + // Result might be dependent or not. + if (isa(Result)) { + DependentAttributedLinAlgMatrixTypeLoc NewTL = + TLB.push(Result); + NewTL.setNameLoc(TL.getNameLoc()); + } else { + AttributedLinAlgMatrixTypeLoc NewTL = + TLB.push(Result); + NewTL.setNameLoc(TL.getNameLoc()); + } + + return Result; +} + +// HLSL Change End + template QualType TreeTransform::TransformParenType(TypeLocBuilder &TLB, diff --git a/tools/clang/lib/Serialization/ASTReader.cpp b/tools/clang/lib/Serialization/ASTReader.cpp index 68367fc582..717d1db0ed 100644 --- a/tools/clang/lib/Serialization/ASTReader.cpp +++ b/tools/clang/lib/Serialization/ASTReader.cpp @@ -5609,6 +5609,12 @@ void TypeLocReader::VisitAttributedTypeLoc(AttributedTypeLoc TL) { } else if (TL.hasAttrEnumOperand()) TL.setAttrEnumOperandLoc(ReadSourceLocation(Record, Idx)); } +// HLSL Change Start +void TypeLocReader::VisitAttributedLinAlgMatrixTypeLoc( + AttributedLinAlgMatrixTypeLoc TL) { + TL.setSourceLocation(ReadSourceLocation(Record, Idx)); +} +// HLSL Change End void TypeLocReader::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) { TL.setNameLoc(ReadSourceLocation(Record, Idx)); } diff --git a/tools/clang/lib/Serialization/ASTWriter.cpp b/tools/clang/lib/Serialization/ASTWriter.cpp index 7c58e81a23..cfa4bfa6d2 100644 --- a/tools/clang/lib/Serialization/ASTWriter.cpp +++ b/tools/clang/lib/Serialization/ASTWriter.cpp @@ -307,6 +307,25 @@ void ASTTypeWriter::VisitAttributedType(const AttributedType *T) { Code = TYPE_ATTRIBUTED; } +// HLSL Change Start +void ASTTypeWriter::VisitAttributedLinAlgMatrixType( + const AttributedLinAlgMatrixType *T) { + Writer.AddTypeRef(T->getWrappedType(), Record); + Record.push_back(static_cast(T->getComponentType())); + Record.push_back(T->getRows()); + Record.push_back(T->getColumns()); + Record.push_back(static_cast(T->getUse())); + Record.push_back(static_cast(T->getScope())); + Code = TYPE_HLSL_ATTRIBUTED_LINALG_MATRIX; +} + +void ASTTypeWriter::VisitDependentAttributedLinAlgMatrixType( + const DependentAttributedLinAlgMatrixType *T) { + // FIXME: Serialize this type (C++ only) + llvm_unreachable( + "Cannot serialize dependent HLSL attributed linear algebra matrix types"); +} +// HLSL Change End void ASTTypeWriter::VisitSubstTemplateTypeParmType( const SubstTemplateTypeParmType *T) { @@ -595,6 +614,10 @@ void TypeLocWriter::VisitAttributedTypeLoc(AttributedTypeLoc TL) { Writer.AddSourceLocation(TL.getAttrEnumOperandLoc(), Record); } } +void TypeLocWriter::VisitAttributedLinAlgMatrixTypeLoc( + AttributedLinAlgMatrixTypeLoc TL) { + Writer.AddSourceLocation(TL.getSourceLocation(), Record); +} void TypeLocWriter::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) { Writer.AddSourceLocation(TL.getNameLoc(), Record); } diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/linalg-matrix-ast.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/linalg-matrix-ast.hlsl new file mode 100644 index 0000000000..bb69dda956 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/linalg-matrix-ast.hlsl @@ -0,0 +1,57 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -I %hlsl_headers -T lib_6_10 -ast-dump %s FileCheck %s + +#include +using namespace dx::linalg; + +void f() { + Matrix mat1; + Matrix mat2; +} + +// CHECK: ClassTemplateDecl {{.*}} Matrix{{$}} +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'ComponentEnum':'dx::linalg::ComponentType::ComponentEnum' ComponentTy +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'int' M +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'int' N +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'MatrixUseEnum':'dx::linalg::MatrixUse::MatrixUseEnum' Use +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'MatrixScopeEnum':'dx::linalg::MatrixScope::MatrixScopeEnum' Scope +// CHECK-NEXT: CXXRecordDecl {{.*}} class Matrix definition +// CHECK-NEXT: CXXRecordDecl {{.*}} Matrix +// CHECK-NEXT:TypeAliasDecl {{.*}} HandleT '__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentTy, M, N, Use, Scope)]]' +// CHECK-NEXT: FieldDecl {{.*}} __handle 'HandleT':'__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentTy, M, N, Use, Scope)]]' + +// CHECK-NEXT: ClassTemplateSpecializationDecl {{.*}} class Matrix definition +// CHECK-NEXT: TemplateArgument integral 4 +// CHECK-NEXT: TemplateArgument integral 4 +// CHECK-NEXT: TemplateArgument integral 5 +// CHECK-NEXT: TemplateArgument integral 1 +// CHECK-NEXT: TemplateArgument integral 2 +// CHECK-NEXT: CXXRecordDecl {{.*}} implicit class Matrix +// CHECK-NEXT: TypeAliasDecl {{.*}} HandleT '__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]]' +// CHECK-NEXT: FieldDecl {{.*}} __handle 'HandleT':'__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]]' + +// CHECK-NEXT: ClassTemplateSpecializationDecl {{.*}} class Matrix definition +// CHECK-NEXT: TemplateArgument integral 17 +// CHECK-NEXT: TemplateArgument integral 100 +// CHECK-NEXT: TemplateArgument integral 100 +// CHECK-NEXT: TemplateArgument integral 0 +// CHECK-NEXT: TemplateArgument integral 1 +// CHECK-NEXT: CXXRecordDecl {{.*}} implicit class Matrix +// CHECK-NEXT: TypeAliasDecl {{.*}} HandleT '__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::PackedS8x32,100, 100, MatrixUse::A, MatrixScope::Wave)]]' +// CHECK-NEXT: FieldDecl {{.*}} __handle 'HandleT':'__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::PackedS8x32, 100, 100, MatrixUse::A, MatrixScope::Wave)]]' + +// CHECK: FunctionDecl {{.*}} f 'void ()' + +// CHECK: VarDecl {{.*}} mat1 'Matrix': +// CHECK-SAME: 'dx::linalg::Matrix' + +// CHECK: VarDecl {{.*}} mat2 'Matrix': +// CHECK-SAME: 'dx::linalg::Matrix' diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/linalg-matrix-error.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/linalg-matrix-error.hlsl new file mode 100644 index 0000000000..9fafa16bd4 --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/linalg-matrix-error.hlsl @@ -0,0 +1,35 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -I %hlsl_headers -T lib_6_10 -verify %s + +#include +using namespace dx::linalg; + +// CHECK:: foo +void f() { + // expected-error@+1{{too few template arguments for class template 'Matrix'}} + Matrix mat1; + + // expected-error@+1{{non-type template argument of type 'literal string' must have an integral or enumeration type}} + Matrix mat2; + + Matrix mat3; + + Matrix mat4; + + // expected-error@+1{{cannot convert from 'Matrix' to 'Matrix'}} + mat3 = mat4; + + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]] mat8; + + // expected-error@+1 {{cannot initialize a variable of type '__builtin_LinAlgMatrix' with an lvalue of type '__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]]'}} + __builtin_LinAlgMatrix naked_mat = mat8; + + // ok + Matrix same_as_mat4 = mat4; + + // expected-error@+1{{cannot initialize a variable of type 'Matrix' with an lvalue of type 'Matrix'}} + Matrix different_mat = mat4; +} + +// expected-note@dx/linalg.h:*{{template is declared here}} +// expected-note@dx/linalg.h:*{{template parameter is declared here}} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/matrix-attributed-type-ast.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/matrix-attributed-type-ast.hlsl new file mode 100644 index 0000000000..c1dfb8502a --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/matrix-attributed-type-ast.hlsl @@ -0,0 +1,65 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -I %hlsl_headers -T lib_6_10 -enable-16bit-types -ast-dump %s FileCheck %s + +#include +using namespace dx::linalg; + +// CHECK: FunctionDecl {{.*}} f1 'void ()' +// CHECK: VarDecl {{.*}} mat1 '__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]]' + +void f1() { + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]] mat1; +} + +// CHECK: FunctionDecl {{.*}} f2 'void (__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]])' +// CHECK-NEXT: ParmVarDecl {{.*}} mat2 '__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]]' + +void f2(__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]] mat2) { +} + +// CHECK: TypedefDecl {{.*}} Mat10by20 '__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]]' +typedef __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]] Mat10by20; + +// CHECK: FunctionDecl {{.*}} f3 'Mat10by20 ()' +// CHECK: VarDecl {{.*}} mat3 'Mat10by20':'__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]]' +// CHECK: ReturnStmt +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'Mat10by20':'__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]]' +// CHECK-NEXT: DeclRefExpr {{.*}} 'Mat10by20':'__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]]' lvalue Var {{.*}} 'mat3' 'Mat10by20':'__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Wave)]]' + +Mat10by20 f3() { + Mat10by20 mat3; + return mat3; +} + +// CHECK: FunctionTemplateDecl {{.*}} fTemplate +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'uint':'unsigned int' M +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'uint':'unsigned int' N +// CHECK-NEXT: FunctionDecl {{.*}} fTemplate 'void ()' +// CHECK: VarDecl {{.*}} mat4 '__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::I16, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup)]]' +template +void fTemplate() { + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I16, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup)]] mat4; +} + +// CHECK: FunctionDecl {{.*}} fTemplate 'void ()' +// CHECK-NEXT: TemplateArgument integral 3 +// CHECK-NEXT: TemplateArgument integral 4 +// CHECK: VarDecl {{.*}} mat4 '__builtin_LinAlgMatrix +// CHECK-SAME{LITERAL}: [[__LinAlgMatrix_Attributes(ComponentType::I16, 3, 4, MatrixUse::Accumulator, MatrixScope::ThreadGroup)]]' + +// CHECK: FunctionDecl {{.*}} f4 'void ()' +// CHECK: CallExpr {{.*}} +// CHECK: ImplicitCastExpr {{.*}} 'void (*)()' +// CHECK: DeclRefExpr {{.*}} 'void ()' lvalue Function {{.*}} 'fTemplate' 'void ()' (FunctionTemplate {{.*}} 'fTemplate') +void f4() { + fTemplate<3, 4>(); +} diff --git a/tools/clang/test/SemaHLSL/hlsl/linalg/matrix-attributed-type-error.hlsl b/tools/clang/test/SemaHLSL/hlsl/linalg/matrix-attributed-type-error.hlsl new file mode 100644 index 0000000000..3bd487f66f --- /dev/null +++ b/tools/clang/test/SemaHLSL/hlsl/linalg/matrix-attributed-type-error.hlsl @@ -0,0 +1,31 @@ +// REQUIRES: dxil-1-10 +// RUN: %dxc -I %hlsl_headers -T lib_6_10 -enable-16bit-types -verify %s + +#include +using namespace dx::linalg; + +void f() { + // expected-error@+1 {{matrix attributes can only be applied to __builtin_LinAlgMatrix}} + int [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]] mat1; + + // expected-error@+1 {{'__LinAlgMatrix_Attributes' attribute requires exactly 5 arguments}} + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(10)]] mat2; + // expected-error@+1 {{'__LinAlgMatrix_Attributes' attribute requires exactly 5 arguments}} + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 2, 3, 4, 5, 6)]] mat3; + + // ok + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(10, 4, 5, 2, 1)]] mat4; + + // expected-error@+1 {{argument is not an integer or enumeration}} + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(10.56, 4, 5, 2, 1)]] mat5; + + // expected-error@+1 {{argument is not an integer}} + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, "str", 5, 2, 1)]] mat6; + // expected-error@+3 {{cannot convert from '__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]]' to '__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::A, MatrixScope::ThreadGroup)]]'}} + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::A, MatrixScope::ThreadGroup)]] mat7; + __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]] mat8; + mat7 = mat8; + + // expected-error@+1 {{cannot initialize a variable of type '__builtin_LinAlgMatrix' with an lvalue of type '__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]]'}} + __builtin_LinAlgMatrix naked_mat = mat8; +} diff --git a/tools/clang/tools/libclang/CIndex.cpp b/tools/clang/tools/libclang/CIndex.cpp index f308848c98..45d24125f7 100644 --- a/tools/clang/tools/libclang/CIndex.cpp +++ b/tools/clang/tools/libclang/CIndex.cpp @@ -1701,6 +1701,8 @@ DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType) DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType) DEFAULT_TYPELOC_IMPL(Record, TagType) DEFAULT_TYPELOC_IMPL(Enum, TagType) +DEFAULT_TYPELOC_IMPL(AttributedLinAlgMatrix, Type) +DEFAULT_TYPELOC_IMPL(DependentAttributedLinAlgMatrix, Type) DEFAULT_TYPELOC_IMPL(SubstTemplateTypeParm, Type) DEFAULT_TYPELOC_IMPL(SubstTemplateTypeParmPack, Type) DEFAULT_TYPELOC_IMPL(Auto, Type) diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 61b13c72bb..23aaa92bda 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -9602,6 +9602,7 @@ def __init__(self, intrinsic_defs, opcode_data): "GroupNodeOutputRecords": "LICOMPTYPE_GROUP_NODE_OUTPUT_RECORDS", "ThreadNodeOutputRecords": "LICOMPTYPE_THREAD_NODE_OUTPUT_RECORDS", "DxHitObject": "LICOMPTYPE_HIT_OBJECT", + "LinAlgMatrix": "LICOMPTYPE_LINALG_MATRIX", "VkBufferPointer": "LICOMPTYPE_VK_BUFFER_POINTER", "RayQuery": "LICOMPTYPE_RAY_QUERY", "LinAlg": "LICOMPTYPE_LINALG", @@ -9667,7 +9668,7 @@ def load_intrinsics(self, intrinsic_defs): acceleration_struct | ray_desc | RayQuery | DxHitObject | Node\w* | RWNode\w* | EmptyNode\w* | AnyNodeOutput\w* | NodeOutputRecord\w* | GroupShared\w* | - VkBufferPointer + VkBufferPointer | LinAlgMatrix $)""", flags=re.VERBOSE, )