Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/dxc/DXIL/DxilConstants.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,18 @@ enum class ComponentType : uint32_t {
LastEntry
};

enum class MatrixUse : uint32_t {
A = 0,
B = 1,
Accumulator = 2,
};

enum class MatrixScope : uint32_t {
Thread = 0,
Wave = 1,
ThreadGroup = 2,
};

// Must match D3D_INTERPOLATION_MODE
enum class InterpolationMode : uint8_t {
Undefined = 0,
Expand Down
12 changes: 7 additions & 5 deletions include/dxc/dxcapi.internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,19 @@ enum LEGAL_INTRINSIC_COMPTYPES {
LICOMPTYPE_HIT_OBJECT = 51,
LICOMPTYPE_RAY_QUERY = 52,

LICOMPTYPE_LINALG = 53, // f32, partial-precision-f32, f16,
LICOMPTYPE_LINALG_MATRIX = 53,

LICOMPTYPE_LINALG = 54, // f32, partial-precision-f32, f16,
// i32, i16, u32, u16,
// int8_4packed, uint8_4packed

LICOMPTYPE_BUILTIN_TRIANGLE_POSITIONS = 54,
LICOMPTYPE_BUILTIN_TRIANGLE_POSITIONS = 55,

#ifdef ENABLE_SPIRV_CODEGEN
LICOMPTYPE_VK_BUFFER_POINTER = 55,
LICOMPTYPE_COUNT = 56
LICOMPTYPE_VK_BUFFER_POINTER = 56,
LICOMPTYPE_COUNT = 57
#else
LICOMPTYPE_COUNT = 55
LICOMPTYPE_COUNT = 56
#endif
};

Expand Down
20 changes: 20 additions & 0 deletions tools/clang/include/clang/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef LLVM_CLANG_AST_ASTCONTEXT_H
#define LLVM_CLANG_AST_ASTCONTEXT_H

#include "dxc/DXIL/DxilConstants.h"
#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/CanonicalType.h"
#include "clang/AST/CommentCommandTraits.h"
Expand Down Expand Up @@ -130,6 +131,12 @@ class ASTContext : public RefCountedBase<ASTContext> {
mutable llvm::FoldingSet<AtomicType> AtomicTypes;
llvm::FoldingSet<AttributedType> AttributedTypes;

// HLSL Change Start
llvm::FoldingSet<AttributedLinAlgMatrixType> AttrLinAlgMatrixTypes;
llvm::FoldingSet<DependentAttributedLinAlgMatrixType>
DepAttrLinAlgMatrixTypes;
// HLSL Change End

mutable llvm::FoldingSet<QualifiedTemplateName> QualifiedTemplateNames;
mutable llvm::FoldingSet<DependentTemplateName> DependentTemplateNames;
mutable llvm::FoldingSet<SubstTemplateTemplateParmStorage>
Expand Down Expand Up @@ -1156,6 +1163,19 @@ class ASTContext : public RefCountedBase<ASTContext> {
QualType modifiedType,
QualType equivalentType);

// HLSL Change Start
QualType getAttributedLinAlgMatrixType(QualType WrappedTy,
hlsl::DXIL::ComponentType ComponentTy,
size_t Rows, size_t Cols,
hlsl::DXIL::MatrixUse Use,
hlsl::DXIL::MatrixScope Scope);

QualType getDependentAttributedLinAlgMatrixType(QualType WrappedTy,
Expr *ComponentTyExpr,
Expr *RowsExpr,
Expr *ColsExpr, Expr *UseExpr,
Expr *ScopeExpr);
// HLSL Change End
QualType getSubstTemplateTypeParmType(const TemplateTypeParmType *Replaced,
QualType Replacement) const;
QualType getSubstTemplateTypeParmPackType(
Expand Down
19 changes: 19 additions & 0 deletions tools/clang/include/clang/AST/DataRecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,21 @@ DEF_TRAVERSE_TYPE(AutoType, { TRY_TO(TraverseType(T->getDeducedType())); })

DEF_TRAVERSE_TYPE(RecordType, {})
DEF_TRAVERSE_TYPE(EnumType, {})

// HLSL Change Start
DEF_TRAVERSE_TYPE(AttributedLinAlgMatrixType,
{ TRY_TO(TraverseType(T->getWrappedType())); })

DEF_TRAVERSE_TYPE(DependentAttributedLinAlgMatrixType, {
TRY_TO(TraverseType(T->getWrappedType()));
TRY_TO(TraverseStmt(T->getComponentTyExpr()));
TRY_TO(TraverseStmt(T->getRowsExpr()));
TRY_TO(TraverseStmt(T->getColsExpr()));
TRY_TO(TraverseStmt(T->getUseExpr()));
TRY_TO(TraverseStmt(T->getScopeExpr()));
})
// HLSL Change End

DEF_TRAVERSE_TYPE(TemplateTypeParmType, {})
DEF_TRAVERSE_TYPE(SubstTemplateTypeParmType, {})
DEF_TRAVERSE_TYPE(SubstTemplateTypeParmPackType, {})
Expand Down Expand Up @@ -1119,6 +1134,10 @@ DEF_TRAVERSE_TYPELOC(AutoType, {

DEF_TRAVERSE_TYPELOC(RecordType, {})
DEF_TRAVERSE_TYPELOC(EnumType, {})
// HLSL Change Start
DEF_TRAVERSE_TYPELOC(AttributedLinAlgMatrixType, {})
DEF_TRAVERSE_TYPELOC(DependentAttributedLinAlgMatrixType, {})
// HLSL Change End
DEF_TRAVERSE_TYPELOC(TemplateTypeParmType, {})
DEF_TRAVERSE_TYPELOC(SubstTemplateTypeParmType, {})
DEF_TRAVERSE_TYPELOC(SubstTemplateTypeParmPackType, {})
Expand Down
19 changes: 19 additions & 0 deletions tools/clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,20 @@ DEF_TRAVERSE_TYPE(InjectedClassNameType, {})
DEF_TRAVERSE_TYPE(AttributedType,
{ TRY_TO(TraverseType(T->getModifiedType())); })

// HLSL Change Start
DEF_TRAVERSE_TYPE(AttributedLinAlgMatrixType,
{ TRY_TO(TraverseType(T->getWrappedType())); })

DEF_TRAVERSE_TYPE(DependentAttributedLinAlgMatrixType, {
TRY_TO(TraverseType(T->getWrappedType()));
TRY_TO(TraverseStmt(T->getComponentTyExpr()));
TRY_TO(TraverseStmt(T->getRowsExpr()));
TRY_TO(TraverseStmt(T->getColsExpr()));
TRY_TO(TraverseStmt(T->getUseExpr()));
TRY_TO(TraverseStmt(T->getScopeExpr()));
})
// HLSL Change End

DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); })

DEF_TRAVERSE_TYPE(ElaboratedType, {
Expand Down Expand Up @@ -1206,6 +1220,11 @@ DEF_TRAVERSE_TYPELOC(ParenType, { TRY_TO(TraverseTypeLoc(TL.getInnerLoc())); })
DEF_TRAVERSE_TYPELOC(AttributedType,
{ TRY_TO(TraverseTypeLoc(TL.getModifiedLoc())); })

// HLSL Change Start
DEF_TRAVERSE_TYPELOC(AttributedLinAlgMatrixType, {})
DEF_TRAVERSE_TYPELOC(DependentAttributedLinAlgMatrixType, {})
// HLSL Change End

DEF_TRAVERSE_TYPELOC(ElaboratedType, {
if (TL.getQualifierLoc()) {
TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc()));
Expand Down
120 changes: 119 additions & 1 deletion tools/clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef LLVM_CLANG_AST_TYPE_H
#define LLVM_CLANG_AST_TYPE_H

#include "dxc/DXIL/DxilConstants.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/TemplateName.h"
#include "clang/Basic/AddressSpaces.h"
Expand Down Expand Up @@ -1699,8 +1700,15 @@ class Type : public ExtQualsTypeCommonBase {

bool isOpenCLSpecificType() const; // Any OpenCL specific type

// HLSL Change Start
bool isLinAlgMatrixType() const; // HLSL __builtin_LinAlgMatrix

bool isAttributedLinAlgMatrixType()
const; // HLSL attributed __builtin_LinAlgMatrix
bool isDependentAttributedLinAlgMatrixType()
const; // HLSL attributed __builtin_LinAlgMatrix with dependent
// parameters
// HLSL Change End

/// Determines if this type, which must satisfy
/// isObjCLifetimeType(), is implicitly __unsafe_unretained rather
/// than implicitly __strong.
Expand Down Expand Up @@ -3736,6 +3744,108 @@ class AttributedType : public Type, public llvm::FoldingSetNode {
}
};

// HLSL Change Start

class AttributedLinAlgMatrixType : public Type, public llvm::FoldingSetNode {
friend class ASTContext; // ASTContext creates these

QualType WrappedType; // should be __builtin_LinAlgMatrix
hlsl::DXIL::ComponentType ComponentTy;
size_t Rows, Cols;
hlsl::DXIL::MatrixUse Use;
hlsl::DXIL::MatrixScope Scope;

AttributedLinAlgMatrixType(QualType WrappedTy,
hlsl::DXIL::ComponentType ComponentTy, size_t Rows,
size_t Cols, hlsl::DXIL::MatrixUse Use,
hlsl::DXIL::MatrixScope Scope)
: Type(AttributedLinAlgMatrix, QualType(), /*Dependent*/ false,
/*InstantiationDependent*/ false, /*VariablyModified*/ false,
/*ContainsUnexpandedParameterPack*/ false),
WrappedType(WrappedTy), ComponentTy(ComponentTy), Rows(Rows),
Cols(Cols), Use(Use), Scope(Scope) {}

public:
QualType getWrappedType() const { return WrappedType; }

hlsl::DXIL::ComponentType getComponentType() const { return ComponentTy; }
size_t getRows() const { return Rows; }
size_t getCols() const { return Cols; }
hlsl::DXIL::MatrixUse getUse() const { return Use; }
hlsl::DXIL::MatrixScope getScope() const { return Scope; }

void appendMangledAttributes(llvm::raw_ostream &OS) const;

bool isSugared() const { return false; }
QualType desugar() const { return QualType(this, 0); }

void Profile(llvm::FoldingSetNodeID &ID) {
Profile(ID, WrappedType, ComponentTy, Rows, Cols, Use, Scope);
}

static void Profile(llvm::FoldingSetNodeID &ID, QualType WrappedTy,
hlsl::DXIL::ComponentType ComponentTy, size_t Rows,
size_t Cols, hlsl::DXIL::MatrixUse Use,
hlsl::DXIL::MatrixScope Scope) {
ID.AddPointer(WrappedTy.getAsOpaquePtr());
ID.AddInteger(static_cast<uint32_t>(ComponentTy));
ID.AddInteger(static_cast<uint32_t>(Rows));
ID.AddInteger(static_cast<uint32_t>(Cols));
ID.AddInteger(static_cast<uint32_t>(Use));
ID.AddInteger(static_cast<uint32_t>(Scope));
}

static bool classof(const Type *T) {
return T->getTypeClass() == AttributedLinAlgMatrix;
}
};

class DependentAttributedLinAlgMatrixType : public Type,
public llvm::FoldingSetNode {
const ASTContext &Context;
QualType WrappedType; // should be __builtin_LinAlgMatrix
Expr *ComponentTyExpr;
Expr *RowsExpr;
Expr *ColsExpr;
Expr *UseExpr;
Expr *ScopeExpr;

DependentAttributedLinAlgMatrixType(const ASTContext &Context,
QualType WrappedType,
Expr *ComponentTyExpr, Expr *RowsExpr,
Expr *ColsExpr, Expr *UseExpr,
Expr *ScopeExpr);

friend class ASTContext;

public:
QualType getWrappedType() const { return WrappedType; }
Expr *getComponentTyExpr() const { return ComponentTyExpr; }
Expr *getRowsExpr() const { return RowsExpr; }
Expr *getColsExpr() const { return ColsExpr; }
Expr *getUseExpr() const { return UseExpr; }
Expr *getScopeExpr() const { return ScopeExpr; }

bool isSugared() const { return false; }
QualType desugar() const { return QualType(this, 0); }

static bool classof(const Type *T) {
return T->getTypeClass() == DependentAttributedLinAlgMatrix;
}

void Profile(llvm::FoldingSetNodeID &ID) {
Profile(ID, Context, getWrappedType(), getComponentTyExpr(), getRowsExpr(),
getColsExpr(), getUseExpr(), getScopeExpr());
}

static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
QualType WrappedType, Expr *ComponentTyExpr,
Expr *RowsExpr, Expr *ColsExpr, Expr *UseExpr,
Expr *ScopeExpr);
};

// HLSL Change End

class TemplateTypeParmType : public Type, public llvm::FoldingSetNode {
// Helper data collector for canonical types.
struct CanonicalTTPTInfo {
Expand Down Expand Up @@ -5426,6 +5536,14 @@ inline bool Type::isEventT() const {
inline bool Type::isLinAlgMatrixType() const {
return isSpecificBuiltinType(BuiltinType::LinAlgMatrix);
}

inline bool Type::isAttributedLinAlgMatrixType() const {
return isa<AttributedLinAlgMatrixType>(this);
}

inline bool Type::isDependentAttributedLinAlgMatrixType() const {
return isa<DependentAttributedLinAlgMatrixType>(this);
}
// HLSL Change Ends

inline bool Type::isImageType() const {
Expand Down
11 changes: 11 additions & 0 deletions tools/clang/include/clang/AST/TypeLoc.h
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,17 @@ class AttributedTypeLoc : public ConcreteTypeLoc<UnqualTypeLoc,
}
};

// HLSL Change Start
class AttributedLinAlgMatrixTypeLoc
: public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
AttributedLinAlgMatrixTypeLoc,
AttributedLinAlgMatrixType> {};

class DependentAttributedLinAlgMatrixTypeLoc
: public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
DependentAttributedLinAlgMatrixTypeLoc,
DependentAttributedLinAlgMatrixType> {};
// HLSL Change End

struct ObjCObjectTypeLocInfo {
SourceLocation TypeArgsLAngleLoc;
Expand Down
4 changes: 4 additions & 0 deletions tools/clang/include/clang/AST/TypeNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ TYPE(Record, TagType)
TYPE(Enum, TagType)
NON_CANONICAL_TYPE(Elaborated, Type)
NON_CANONICAL_TYPE(Attributed, Type)
// HLSL Change Start
TYPE(AttributedLinAlgMatrix, Type)
DEPENDENT_TYPE(DependentAttributedLinAlgMatrix, Type)
// HLSL Change End
DEPENDENT_TYPE(TemplateTypeParm, Type)
NON_CANONICAL_TYPE(SubstTemplateTypeParm, Type)
DEPENDENT_TYPE(SubstTemplateTypeParmPack, Type)
Expand Down
8 changes: 8 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def Borland : LangOpt<"Borland">;
def CUDA : LangOpt<"CUDA">;
def COnly : LangOpt<"CPlusPlus", 1>;
def SPIRV : LangOpt<"SPIRV">; // SPIRV Change
def HLSL : LangOpt<"HLSL">; // HLSL Change

// Defines targets for target-specific attributes. The list of strings should
// specify architectures for which the target applies, based off the ArchType
Expand Down Expand Up @@ -1228,6 +1229,13 @@ def HLSLUnboundedSparseNodes : InheritableParamAttr {
let Documentation = [Undocumented];
}

def HLSLLinAlgMatrixAttributes : TypeAttr {
let Spellings = [CXX11<"", "__LinAlgMatrix_Attributes", 2015>];
let LangOpts = [HLSL];
let Args = [ExprArgument<"ComponentTy">, ExprArgument<"M">, ExprArgument<"N">,
ExprArgument<"Use">, ExprArgument<"Scope">];
let Documentation = [Undocumented];
}
// HLSL Change Ends

// SPIRV Change Starts
Expand Down
13 changes: 13 additions & 0 deletions tools/clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -8042,6 +8042,19 @@ def err_hlsl_linalg_matrix_dim_must_be_greater_than_zero: Error<
def err_hlsl_linalg_matrix_layout_invalid : Error<
"matrix layout %0 is not valid, must be in the range [%1, %2]">;

// SM 6.10 Linear Algebra Operations
def err_hlsl_linalg_matrix_attribute_arg_not_int_or_enum
: Error<"argument is not an integer%select{| or enumeration}0">;
def err_hlsl_linalg_matrix_attribute_arg_not_constant_value
: Error<"matrix attributes argument %1 is not a constant value">;
def err_hlsl_linalg_matrix_invalid_enum_attribute_value
: Error<"matrix attribute %0 has invalid value %1, must be in the range "
"[%2, %3]">;
def err_hlsl_linalg_matrix_attribute_on_invalid_type
: Error<"matrix attributes can only be applied to %0">;
def err_hlsl_linalg_attributed_matrix_required
: Error<"argument type must be linear algebra matrix with attributes">;

def err_hlsl_linalg_mul_muladd_output_vector_size_not_equal_to_matrix_M : Error<
"output vector length must be equal to Matrix M dimension in a linalg Mul/MulAdd operation">;
def err_hlsl_linalg_mul_muladd_unpacked_input_vector_size_not_equal_to_matrix_K : Error<
Expand Down
Loading
Loading