diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index a964c78..f2943c4 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -752,6 +752,7 @@ struct AstJsonEncoder : public AstVisitor if (node->superName) write("superName", *node->superName); PROP(props); + PROP(indexer); }); } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 821f6c2..07dba92 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -22,6 +22,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauParseDeclareClassIndexer); namespace Luau { @@ -1157,6 +1158,23 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareC scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + if (FFlag::LuauParseDeclareClassIndexer && declaredClass->indexer) + { + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(declaredClass->indexer->location); + } + else + { + ctv->indexer = TableIndexer{ + resolveType(scope, declaredClass->indexer->indexType, /* inTypeArguments */ false), + resolveType(scope, declaredClass->indexer->resultType, /* inTypeArguments */ false), + }; + } + } + for (const AstDeclaredClassProp& prop : declaredClass->props) { Name propName(prop.name.value); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index dba9547..3a1217b 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAG(LuauParseDeclareClassIndexer); + static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { char* result = (char*)allocator.allocate(contents.size() + 1); @@ -227,7 +229,17 @@ public: idx++; } - return allocator->alloc(Location(), props); + AstTableIndexer* indexer = nullptr; + if (FFlag::LuauParseDeclareClassIndexer && ctv.indexer) + { + RecursionCounter counter(&count); + + indexer = allocator->alloc(); + indexer->indexType = Luau::visit(*this, ctv.indexer->indexType->ty); + indexer->resultType = Luau::visit(*this, ctv.indexer->indexResultType->ty); + } + + return allocator->alloc(Location(), props, indexer); } AstType* operator()(const FunctionType& ftv) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 5127feb..a3d9170 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -41,6 +41,7 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckClassTypeIndexers, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) +LUAU_FASTFLAG(LuauParseDeclareClassIndexer) namespace Luau { @@ -1757,6 +1758,9 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& if (!ctv->metatable) ice("No metatable for declared class"); + if (const auto& indexer = declaredClass.indexer; FFlag::LuauParseDeclareClassIndexer && indexer) + ctv->indexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + TableType* metatable = getMutable(*ctv->metatable); for (const AstDeclaredClassProp& prop : declaredClass.props) { diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index a486ad0..f9f9ab4 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -801,12 +801,20 @@ struct AstDeclaredClassProp bool isMethod = false; }; +struct AstTableIndexer +{ + AstType* indexType; + AstType* resultType; + Location location; +}; + class AstStatDeclareClass : public AstStat { public: LUAU_RTTI(AstStatDeclareClass) - AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props); + AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props, + AstTableIndexer* indexer = nullptr); void visit(AstVisitor* visitor) override; @@ -814,6 +822,7 @@ public: std::optional superName; AstArray props; + AstTableIndexer* indexer; }; class AstType : public AstNode @@ -862,13 +871,6 @@ struct AstTableProp AstType* type; }; -struct AstTableIndexer -{ - AstType* indexType; - AstType* resultType; - Location location; -}; - class AstTypeTable : public AstType { public: diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index d2c552a..3c87e36 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -714,12 +714,13 @@ void AstStatDeclareFunction::visit(AstVisitor* visitor) } } -AstStatDeclareClass::AstStatDeclareClass( - const Location& location, const AstName& name, std::optional superName, const AstArray& props) +AstStatDeclareClass::AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, + const AstArray& props, AstTableIndexer* indexer) : AstStat(ClassIndex(), location) , name(name) , superName(superName) , props(props) + , indexer(indexer) { } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 7cae609..cc5d7b3 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -13,6 +13,7 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauParseDeclareClassIndexer, false) #define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" @@ -877,6 +878,7 @@ AstStat* Parser::parseDeclaration(const Location& start) } TempVector props(scratchDeclaredClassProps); + AstTableIndexer* indexer = nullptr; while (lexer.current().type != Lexeme::ReservedEnd) { @@ -885,7 +887,8 @@ AstStat* Parser::parseDeclaration(const Location& start) { props.push_back(parseDeclaredClassMethod()); } - else if (lexer.current().type == '[') + else if (lexer.current().type == '[' && (!FFlag::LuauParseDeclareClassIndexer || lexer.lookahead().type == Lexeme::RawString || + lexer.lookahead().type == Lexeme::QuotedString)) { const Lexeme begin = lexer.current(); nextLexeme(); // [ @@ -904,6 +907,22 @@ AstStat* Parser::parseDeclaration(const Location& start) else report(begin.location, "String literal contains malformed escape sequence"); } + else if (lexer.current().type == '[' && FFlag::LuauParseDeclareClassIndexer) + { + if (indexer) + { + // maybe we don't need to parse the entire badIndexer... + // however, we either have { or [ to lint, not the entire table type or the bad indexer. + AstTableIndexer* badIndexer = parseTableIndexer(); + + // we lose all additional indexer expressions from the AST after error recovery here + report(badIndexer->location, "Cannot have more than one class indexer"); + } + else + { + indexer = parseTableIndexer(); + } + } else { Name propName = parseName("property name"); @@ -916,7 +935,7 @@ AstStat* Parser::parseDeclaration(const Location& start) Location classEnd = lexer.current().location; nextLexeme(); // skip past `end` - return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props)); + return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props), indexer); } else if (std::optional globalName = parseNameOpt("global variable name")) { diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 82577be..a264d0e 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -432,11 +432,11 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") REQUIRE(2 == root->body.size); std::string_view expected1 = - R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}}}]})"; + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]}},{"name":"method","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}}}],"indexer":null})"; CHECK(toJson(root->body.data[0]) == expected1); std::string_view expected2 = - R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]}}]})"; + R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","nameLocation":"7,19 - 7,25","parameters":[]}}],"indexer":null})"; CHECK(toJson(root->body.data[1]) == expected2); } diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 1335b6f..a8738ac 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -54,7 +54,8 @@ TEST_SUITE_BEGIN("AllocatorTests"); TEST_CASE("allocator_can_be_moved") { Counter* c = nullptr; - auto inner = [&]() { + auto inner = [&]() + { Luau::Allocator allocator; c = allocator.alloc(); Luau::Allocator moved{std::move(allocator)}; @@ -921,7 +922,8 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_mid") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") { - auto columnOfEndBraceError = [this](const char* code) { + auto columnOfEndBraceError = [this](const char* code) + { try { parse(code); @@ -1882,6 +1884,44 @@ TEST_CASE_FIXTURE(Fixture, "class_method_properties") CHECK_EQ(2, klass2->props.size); } +TEST_CASE_FIXTURE(Fixture, "class_indexer") +{ + ScopedFastFlag LuauParseDeclareClassIndexer("LuauParseDeclareClassIndexer", true); + + AstStatBlock* stat = parseEx(R"( + declare class Foo + prop: boolean + [string]: number + end + )") + .root; + + REQUIRE_EQ(stat->body.size, 1); + + AstStatDeclareClass* declaredClass = stat->body.data[0]->as(); + REQUIRE(declaredClass); + REQUIRE(declaredClass->indexer); + REQUIRE(declaredClass->indexer->indexType->is()); + CHECK(declaredClass->indexer->indexType->as()->name == "string"); + REQUIRE(declaredClass->indexer->resultType->is()); + CHECK(declaredClass->indexer->resultType->as()->name == "number"); + + const ParseResult p1 = matchParseError(R"( + declare class Foo + [string]: number + -- can only have one indexer + [number]: number + end + )", + "Cannot have more than one class indexer"); + + REQUIRE_EQ(1, p1.root->body.size); + + AstStatDeclareClass* klass = p1.root->body.data[0]->as(); + REQUIRE(klass != nullptr); + CHECK(klass->indexer); +} + TEST_CASE_FIXTURE(Fixture, "parse_variadics") { //clang-format off @@ -2347,7 +2387,8 @@ public: TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") { - auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) { + auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) + { try { parse(codeWithErrors); @@ -2367,7 +2408,8 @@ TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") CHECK_EQ(counterWithErrors.count, counter.count); }; - auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) { + auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) + { try { parse(codeWithErrors); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index d679975..0ca9bd7 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -394,6 +394,36 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") CHECK_EQ(toString(requireType("y")), "string"); } + +TEST_CASE_FIXTURE(Fixture, "class_definition_indexer") +{ + ScopedFastFlag LuauParseDeclareClassIndexer("LuauParseDeclareClassIndexer", true); + ScopedFastFlag LuauTypecheckClassTypeIndexers("LuauTypecheckClassTypeIndexers", true); + + loadDefinition(R"( + declare class Foo + [number]: string + end + )"); + + CheckResult result = check(R"( + local x: Foo + local y = x[1] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const ClassType* ctv = get(requireType("x")); + REQUIRE(ctv != nullptr); + + REQUIRE(bool(ctv->indexer)); + + CHECK_EQ(*ctv->indexer->indexType, *builtinTypes->numberType); + CHECK_EQ(*ctv->indexer->indexResultType, *builtinTypes->stringType); + + CHECK_EQ(toString(requireType("y")), "string"); +} + TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { unfreeze(frontend.globals.globalTypes);