diff --git a/Cargo.toml b/Cargo.toml index 3b4bdb4..a9ed4e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "luau0-src" -version = "0.5.4+luau567" +version = "0.5.5+luau571" authors = ["Aleksandr Orlenko "] edition = "2021" repository = "https://github.com/khvzak/luau-src-rs" diff --git a/luau/Ast/include/Luau/Parser.h b/luau/Ast/include/Luau/Parser.h index 8b7eb73..0b9d8c4 100644 --- a/luau/Ast/include/Luau/Parser.h +++ b/luau/Ast/include/Luau/Parser.h @@ -123,13 +123,13 @@ private: // return [explist] AstStat* parseReturn(); - // type Name `=' typeannotation + // type Name `=' Type AstStat* parseTypeAlias(const Location& start, bool exported); AstDeclaredClassProp parseDeclaredClassMethod(); - // `declare global' Name: typeannotation | - // `declare function' Name`(' [parlist] `)' [`:` TypeAnnotation] + // `declare global' Name: Type | + // `declare function' Name`(' [parlist] `)' [`:` Type] AstStat* parseDeclaration(const Location& start); // varlist `=' explist @@ -140,7 +140,7 @@ private: std::pair> prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args); - // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` TypeAnnotation] + // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); @@ -148,21 +148,21 @@ private: // explist ::= {exp `,'} exp void parseExprList(TempVector& result); - // binding ::= Name [`:` TypeAnnotation] + // binding ::= Name [`:` Type] Binding parseBinding(); // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); - AstType* parseOptionalTypeAnnotation(); + AstType* parseOptionalType(); - // TypeList ::= TypeAnnotation [`,' TypeList] - // ReturnType ::= TypeAnnotation | `(' TypeList `)' - // TableProp ::= Name `:' TypeAnnotation - // TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation + // TypeList ::= Type [`,' TypeList] + // ReturnType ::= Type | `(' TypeList `)' + // TableProp ::= Name `:' Type + // TableIndexer ::= `[' Type `]' `:' Type // PropList ::= (TableProp | TableIndexer) [`,' PropList] - // TypeAnnotation + // Type // ::= Name // | `nil` // | `{' [PropList] `}' @@ -171,24 +171,25 @@ private: // Returns the variadic annotation, if it exists. AstTypePack* parseTypeList(TempVector& result, TempVector>& resultNames); - std::optional parseOptionalReturnTypeAnnotation(); - std::pair parseReturnTypeAnnotation(); + std::optional parseOptionalReturnType(); + std::pair parseReturnType(); - AstTableIndexer* parseTableIndexerAnnotation(); + AstTableIndexer* parseTableIndexer(); - AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); + AstTypeOrPack parseFunctionType(bool allowPack); + AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation); - AstType* parseTableTypeAnnotation(); - AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstType* parseTableType(); + AstTypeOrPack parseSimpleType(bool allowPack); - AstTypeOrPack parseTypeOrPackAnnotation(); - AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); - AstType* parseTypeAnnotation(); + AstTypeOrPack parseTypeOrPack(); + AstType* parseType(); - AstTypePack* parseTypePackAnnotation(); - AstTypePack* parseVariadicArgumentAnnotation(); + AstTypePack* parseTypePack(); + AstTypePack* parseVariadicArgumentTypePack(); + + AstType* parseTypeSuffix(AstType* type, const Location& begin); static std::optional parseUnaryOp(const Lexeme& l); static std::optional parseBinaryOp(const Lexeme& l); @@ -215,7 +216,7 @@ private: // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } AstExpr* parsePrimaryExpr(bool asStatement); - // asexp -> simpleexp [`::' typeAnnotation] + // asexp -> simpleexp [`::' Type] AstExpr* parseAssertionExpr(); // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp @@ -244,7 +245,7 @@ private: // `<' namelist `>' std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); - // `<' typeAnnotation[, ...] `>' + // `<' Type[, ...] `>' AstArray parseTypeParams(); std::optional> parseCharArray(); @@ -302,13 +303,12 @@ private: AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); - AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) - LUAU_PRINTF_ATTR(4, 5); + AstTypeError* reportTypeError(const Location& location, const AstArray& types, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); // `parseErrorLocation` is associated with the parser error // `astErrorLocation` is associated with the AstTypeError created // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely // define the location (possibly of zero size) where a type annotation is expected. - AstTypeError* reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) + AstTypeError* reportMissingTypeError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstExpr* reportFunctionArgsError(AstExpr* func, bool self); @@ -401,8 +401,8 @@ private: std::vector scratchBinding; std::vector scratchLocal; std::vector scratchTableTypeProps; - std::vector scratchAnnotation; - std::vector scratchTypeOrPackAnnotation; + std::vector scratchType; + std::vector scratchTypeOrPack; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; diff --git a/luau/Ast/src/Lexer.cpp b/luau/Ast/src/Lexer.cpp index dac3b95..75b4fe3 100644 --- a/luau/Ast/src/Lexer.cpp +++ b/luau/Ast/src/Lexer.cpp @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauFixInterpStringMid, false) - namespace Luau { @@ -642,9 +640,7 @@ Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatT } consume(); - Lexeme lexemeOutput(Location(start, position()), FFlag::LuauFixInterpStringMid ? formatType : Lexeme::InterpStringBegin, - &buffer[startOffset], offset - startOffset - 1); - return lexemeOutput; + return Lexeme(Location(start, position()), formatType, &buffer[startOffset], offset - startOffset - 1); } default: diff --git a/luau/Ast/src/Parser.cpp b/luau/Ast/src/Parser.cpp index 4c34771..40fa754 100644 --- a/luau/Ast/src/Parser.cpp +++ b/luau/Ast/src/Parser.cpp @@ -130,7 +130,7 @@ void TempVector::push_back(const T& item) size_++; } -static bool shouldParseTypePackAnnotation(Lexer& lexer) +static bool shouldParseTypePack(Lexer& lexer) { if (lexer.current().type == Lexeme::Dot3) return true; @@ -330,11 +330,12 @@ AstStat* Parser::parseStat() if (options.allowTypeAnnotations) { if (ident == "type") - return parseTypeAlias(expr->location, /* exported =*/false); + return parseTypeAlias(expr->location, /* exported= */ false); + if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") { nextLexeme(); - return parseTypeAlias(expr->location, /* exported =*/true); + return parseTypeAlias(expr->location, /* exported= */ true); } } @@ -742,7 +743,7 @@ AstStat* Parser::parseReturn() return allocator.alloc(Location(start, end), copy(list)); } -// type Name [`<' varlist `>'] `=' typeannotation +// type Name [`<' varlist `>'] `=' Type AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // note: `type` token is already parsed for us, so we just need to parse the rest @@ -757,7 +758,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) expectAndConsume('=', "type alias"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); return allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); } @@ -789,16 +790,16 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() expectMatchAndConsume(')', matchParen); - AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0), nullptr}); + AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0), nullptr}); Location end = lexer.current().location; - TempVector vars(scratchAnnotation); + TempVector vars(scratchType); TempVector> varNames(scratchOptArgName); if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { return AstDeclaredClassProp{ - fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; + fnName.name, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -809,7 +810,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args[i].annotation) vars.push_back(args[i].annotation); else - vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); + vars.push_back(reportTypeError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); } if (vararg && !varargAnnotation) @@ -846,10 +847,10 @@ AstStat* Parser::parseDeclaration(const Location& start) expectMatchAndConsume(')', matchParen); - AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0)}); + AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0)}); Location end = lexer.current().location; - TempVector vars(scratchAnnotation); + TempVector vars(scratchType); TempVector varNames(scratchArgName); for (size_t i = 0; i < args.size(); ++i) @@ -898,7 +899,7 @@ AstStat* Parser::parseDeclaration(const Location& start) expectMatchAndConsume(']', begin); expectAndConsume(':', "property type annotation"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // TODO: since AstName conains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); @@ -912,7 +913,7 @@ AstStat* Parser::parseDeclaration(const Location& start) { Name propName = parseName("property name"); expectAndConsume(':', "property type annotation"); - AstType* propType = parseTypeAnnotation(); + AstType* propType = parseType(); props.push_back(AstDeclaredClassProp{propName.name, propType, false}); } } @@ -926,7 +927,7 @@ AstStat* Parser::parseDeclaration(const Location& start) { expectAndConsume(':', "global variable declaration"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); return allocator.alloc(Location(start, type->location), globalName->name, type); } else @@ -1027,7 +1028,7 @@ std::pair Parser::parseFunctionBody( expectMatchAndConsume(')', matchParen, true); - std::optional typelist = parseOptionalReturnTypeAnnotation(); + std::optional typelist = parseOptionalReturnType(); AstLocal* funLocal = nullptr; @@ -1085,7 +1086,7 @@ Parser::Binding Parser::parseBinding() if (!name) name = Name(nameError, lexer.current().location); - AstType* annotation = parseOptionalTypeAnnotation(); + AstType* annotation = parseOptionalType(); return Binding(*name, annotation); } @@ -1104,7 +1105,7 @@ std::tuple Parser::parseBindingList(TempVector Parser::parseBindingList(TempVector& result, TempVector>& resultNames) { while (true) { - if (shouldParseTypePackAnnotation(lexer)) - return parseTypePackAnnotation(); + if (shouldParseTypePack(lexer)) + return parseTypePack(); if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':') { @@ -1156,7 +1157,7 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() +std::optional Parser::parseOptionalReturnType() { if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) { @@ -1183,7 +1184,7 @@ std::optional Parser::parseOptionalReturnTypeAnnotation() unsigned int oldRecursionCount = recursionCounter; - auto [_location, result] = parseReturnTypeAnnotation(); + auto [_location, result] = parseReturnType(); // At this point, if we find a , character, it indicates that there are multiple return types // in this type annotation, but the list wasn't wrapped in parentheses. @@ -1202,27 +1203,27 @@ std::optional Parser::parseOptionalReturnTypeAnnotation() return std::nullopt; } -// ReturnType ::= TypeAnnotation | `(' TypeList `)' -std::pair Parser::parseReturnTypeAnnotation() +// ReturnType ::= Type | `(' TypeList `)' +std::pair Parser::parseReturnType() { incrementRecursionCounter("type annotation"); - TempVector result(scratchAnnotation); - TempVector> resultNames(scratchOptArgName); - AstTypePack* varargAnnotation = nullptr; - Lexeme begin = lexer.current(); if (lexer.current().type != '(') { - if (shouldParseTypePackAnnotation(lexer)) - varargAnnotation = parseTypePackAnnotation(); + if (shouldParseTypePack(lexer)) + { + AstTypePack* typePack = parseTypePack(); + + return {typePack->location, AstTypeList{{}, typePack}}; + } else - result.push_back(parseTypeAnnotation()); + { + AstType* type = parseType(); - Location resultLocation = result.size() == 0 ? varargAnnotation->location : result[0]->location; - - return {resultLocation, AstTypeList{copy(result), varargAnnotation}}; + return {type->location, AstTypeList{copy(&type, 1), nullptr}}; + } } nextLexeme(); @@ -1231,6 +1232,10 @@ std::pair Parser::parseReturnTypeAnnotation() matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; + TempVector result(scratchType); + TempVector> resultNames(scratchOptArgName); + AstTypePack* varargAnnotation = nullptr; + // possibly () -> ReturnType if (lexer.current().type != ')') varargAnnotation = parseTypeList(result, resultNames); @@ -1246,9 +1251,9 @@ std::pair Parser::parseReturnTypeAnnotation() // If it turns out that it's just '(A)', it's possible that there are unions/intersections to follow, so fold over it. if (result.size() == 1) { - AstType* returnType = parseTypeAnnotation(result, innerBegin); + AstType* returnType = parseTypeSuffix(result[0], innerBegin); - // If parseTypeAnnotation parses nothing, then returnType->location.end only points at the last non-type-pack + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack // type to successfully parse. We need the span of the whole annotation. Position endPos = result.size() == 1 ? location.end : returnType->location.end; @@ -1258,39 +1263,33 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; - AstArray types = copy(result); - AstArray> names = copy(resultNames); + AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); - TempVector fallbackReturnTypes(scratchAnnotation); - fallbackReturnTypes.push_back(parseFunctionTypeAnnotationTail(begin, generics, genericPacks, types, names, varargAnnotation)); - - return {Location{location, fallbackReturnTypes[0]->location}, AstTypeList{copy(fallbackReturnTypes), varargAnnotation}}; + return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } -// TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation -AstTableIndexer* Parser::parseTableIndexerAnnotation() +// TableIndexer ::= `[' Type `]' `:' Type +AstTableIndexer* Parser::parseTableIndexer() { const Lexeme begin = lexer.current(); nextLexeme(); // [ - AstType* index = parseTypeAnnotation(); + AstType* index = parseType(); expectMatchAndConsume(']', begin); expectAndConsume(':', "table field"); - AstType* result = parseTypeAnnotation(); + AstType* result = parseType(); return allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location)}); } -// TableProp ::= Name `:' TypeAnnotation +// TableProp ::= Name `:' Type // TablePropOrIndexer ::= TableProp | TableIndexer // PropList ::= TablePropOrIndexer {fieldsep TablePropOrIndexer} [fieldsep] -// TableTypeAnnotation ::= `{' PropList `}' -AstType* Parser::parseTableTypeAnnotation() +// TableType ::= `{' PropList `}' +AstType* Parser::parseTableType() { incrementRecursionCounter("type annotation"); @@ -1313,7 +1312,7 @@ AstType* Parser::parseTableTypeAnnotation() expectMatchAndConsume(']', begin); expectAndConsume(':', "table field"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // TODO: since AstName conains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); @@ -1329,19 +1328,19 @@ AstType* Parser::parseTableTypeAnnotation() { // maybe we don't need to parse the entire badIndexer... // however, we either have { or [ to lint, not the entire table type or the bad indexer. - AstTableIndexer* badIndexer = parseTableIndexerAnnotation(); + AstTableIndexer* badIndexer = parseTableIndexer(); // we lose all additional indexer expressions from the AST after error recovery here report(badIndexer->location, "Cannot have more than one table indexer"); } else { - indexer = parseTableIndexerAnnotation(); + indexer = parseTableIndexer(); } } else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) { - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // array-like table type: {T} desugars into {[number]: T} AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber); @@ -1358,7 +1357,7 @@ AstType* Parser::parseTableTypeAnnotation() expectAndConsume(':', "table field"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); props.push_back({name->name, name->location, type}); } @@ -1382,9 +1381,9 @@ AstType* Parser::parseTableTypeAnnotation() return allocator.alloc(Location(start, end), copy(props), indexer); } -// ReturnType ::= TypeAnnotation | `(' TypeList `)' -// FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) +// ReturnType ::= Type | `(' TypeList `)' +// FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType +AstTypeOrPack Parser::parseFunctionType(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1400,7 +1399,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; - TempVector params(scratchAnnotation); + TempVector params(scratchType); TempVector> names(scratchOptArgName); AstTypePack* varargAnnotation = nullptr; @@ -1432,12 +1431,11 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray> paramNames = copy(names); - return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; + return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) - +AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation) { incrementRecursionCounter("type annotation"); @@ -1458,21 +1456,22 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); } -// typeannotation ::= +// Type ::= // nil | // Name[`.' Name] [`<' namelist `>'] | // `{' [PropList] `}' | // `(' [TypeList] `)' `->` ReturnType -// `typeof` typeannotation -AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location& begin) +// `typeof` Type +AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) { - LUAU_ASSERT(!parts.empty()); + TempVector parts(scratchType); + parts.push_back(type); incrementRecursionCounter("type annotation"); @@ -1487,7 +1486,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + parts.push_back(parseSimpleType(/* allowPack= */ false).type); isUnion = true; } else if (c == '?') @@ -1500,7 +1499,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + parts.push_back(parseSimpleType(/* allowPack= */ false).type); isIntersection = true; } else if (c == Lexeme::Dot3) @@ -1513,11 +1512,11 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location } if (parts.size() == 1) - return parts[0]; + return type; if (isUnion && isIntersection) { - return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), + return reportTypeError(Location(begin, parts.back()->location), copy(parts), "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } @@ -1533,16 +1532,14 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } -AstTypeOrPack Parser::parseTypeOrPackAnnotation() +AstTypeOrPack Parser::parseTypeOrPack() { unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("type annotation"); Location begin = lexer.current().location; - TempVector parts(scratchAnnotation); - - auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); + auto [type, typePack] = parseSimpleType(/* allowPack= */ true); if (typePack) { @@ -1550,31 +1547,28 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation() return {{}, typePack}; } - parts.push_back(type); - recursionCounter = oldRecursionCount; - return {parseTypeAnnotation(parts, begin), {}}; + return {parseTypeSuffix(type, begin), {}}; } -AstType* Parser::parseTypeAnnotation() +AstType* Parser::parseType() { unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("type annotation"); Location begin = lexer.current().location; - TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + AstType* type = parseSimpleType(/* allowPack= */ false).type; recursionCounter = oldRecursionCount; - return parseTypeAnnotation(parts, begin); + return parseTypeSuffix(type, begin); } -// typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' +// Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) +AstTypeOrPack Parser::parseSimpleType(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1603,18 +1597,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(start, svalue)}; } else - return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")}; + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { parseInterpString(); - return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")}; + return {reportTypeError(start, {}, "Interpolated string literals cannot be used as types")}; } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, "Malformed string")}; + return {reportTypeError(start, {}, "Malformed string")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1663,17 +1657,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) } else if (lexer.current().type == '{') { - return {parseTableTypeAnnotation(), {}}; + return {parseTableType(), {}}; } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionTypeAnnotation(allowPack); + return parseFunctionType(allowPack); } else if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, + return {reportTypeError(start, {}, "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "...any'"), {}}; @@ -1685,12 +1679,11 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. Location parseErrorLocation(lexer.previousLocation().end, start.end); - return { - reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; + return {reportMissingTypeError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } -AstTypePack* Parser::parseVariadicArgumentAnnotation() +AstTypePack* Parser::parseVariadicArgumentTypePack() { // Generic: a... if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == Lexeme::Dot3) @@ -1705,19 +1698,19 @@ AstTypePack* Parser::parseVariadicArgumentAnnotation() // Variadic: T else { - AstType* variadicAnnotation = parseTypeAnnotation(); + AstType* variadicAnnotation = parseType(); return allocator.alloc(variadicAnnotation->location, variadicAnnotation); } } -AstTypePack* Parser::parseTypePackAnnotation() +AstTypePack* Parser::parseTypePack() { // Variadic: ...T if (lexer.current().type == Lexeme::Dot3) { Location start = lexer.current().location; nextLexeme(); - AstType* varargTy = parseTypeAnnotation(); + AstType* varargTy = parseType(); return allocator.alloc(Location(start, varargTy->location), varargTy); } // Generic: a... @@ -2054,7 +2047,7 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) return expr; } -// asexp -> simpleexp [`::' typeannotation] +// asexp -> simpleexp [`::' Type] AstExpr* Parser::parseAssertionExpr() { Location start = lexer.current().location; @@ -2063,7 +2056,7 @@ AstExpr* Parser::parseAssertionExpr() if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) { nextLexeme(); - AstType* annotation = parseTypeAnnotation(); + AstType* annotation = parseType(); return allocator.alloc(Location(start, annotation->location), expr, annotation); } else @@ -2455,15 +2448,15 @@ std::pair, AstArray> Parser::parseG Lexeme packBegin = lexer.current(); - if (shouldParseTypePackAnnotation(lexer)) + if (shouldParseTypePack(lexer)) { - AstTypePack* typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePack(); namePacks.push_back({name, nameLocation, typePack}); } else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (type) report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); @@ -2472,7 +2465,7 @@ std::pair, AstArray> Parser::parseG } else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (type) report(type->location, "Expected type pack after '=', got type"); @@ -2495,7 +2488,7 @@ std::pair, AstArray> Parser::parseG seenDefault = true; nextLexeme(); - AstType* defaultType = parseTypeAnnotation(); + AstType* defaultType = parseType(); names.push_back({name, nameLocation, defaultType}); } @@ -2532,7 +2525,7 @@ std::pair, AstArray> Parser::parseG AstArray Parser::parseTypeParams() { - TempVector parameters{scratchTypeOrPackAnnotation}; + TempVector parameters{scratchTypeOrPack}; if (lexer.current().type == '<') { @@ -2541,15 +2534,15 @@ AstArray Parser::parseTypeParams() while (true) { - if (shouldParseTypePackAnnotation(lexer)) + if (shouldParseTypePack(lexer)) { - AstTypePack* typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePack(); parameters.push_back({{}, typePack}); } else if (lexer.current().type == '(') { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (typePack) parameters.push_back({{}, typePack}); @@ -2562,7 +2555,7 @@ AstArray Parser::parseTypeParams() } else { - parameters.push_back({parseTypeAnnotation(), {}}); + parameters.push_back({parseType(), {}}); } if (lexer.current().type == ',') @@ -3018,7 +3011,7 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray(location, expressions, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) +AstTypeError* Parser::reportTypeError(const Location& location, const AstArray& types, const char* format, ...) { va_list args; va_start(args, format); @@ -3028,7 +3021,7 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const return allocator.alloc(location, types, false, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) +AstTypeError* Parser::reportMissingTypeError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) { va_list args; va_start(args, format); diff --git a/luau/Common/include/Luau/Bytecode.h b/luau/Common/include/Luau/Bytecode.h index 67f9fbe..82bf6e5 100644 --- a/luau/Common/include/Luau/Bytecode.h +++ b/luau/Common/include/Luau/Bytecode.h @@ -25,7 +25,7 @@ // Additionally, in some specific instructions such as ANDK, the limit on the encoded value is smaller; this means that if a value is larger, a different instruction must be selected. // // Registers: 0-254. Registers refer to the values on the function's stack frame, including arguments. -// Upvalues: 0-254. Upvalues refer to the values stored in the closure object. +// Upvalues: 0-199. Upvalues refer to the values stored in the closure object. // Constants: 0-2^23-1. Constants are stored in a table allocated with each proto; to allow for future bytecode tweaks the encodable value is limited to 23 bits. // Closures: 0-2^15-1. Closures are created from child protos via a child index; the limit is for the number of closures immediately referenced in each function. // Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more. Note that for jump instructions with AUX, the AUX word is included as part of the jump offset. @@ -93,12 +93,12 @@ enum LuauOpcode // GETUPVAL: load upvalue from the upvalue table for the current function // A: target register - // B: upvalue index (0..255) + // B: upvalue index LOP_GETUPVAL, // SETUPVAL: store value into the upvalue table for the current function // A: target register - // B: upvalue index (0..255) + // B: upvalue index LOP_SETUPVAL, // CLOSEUPVALS: close (migrate to heap) all upvalues that were captured for registers >= target diff --git a/luau/Common/include/Luau/ExperimentalFlags.h b/luau/Common/include/Luau/ExperimentalFlags.h index 35e11ca..8eca105 100644 --- a/luau/Common/include/Luau/ExperimentalFlags.h +++ b/luau/Common/include/Luau/ExperimentalFlags.h @@ -11,8 +11,9 @@ inline bool isFlagExperimental(const char* flag) // Flags in this list are disabled by default in various command-line tools. They may have behavior that is not fully final, // or critical bugs that are found after the code has been submitted. static const char* const kList[] = { - "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code - "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) + "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code + "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) + "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins // makes sure we always have at least one entry nullptr, }; diff --git a/luau/Compiler/src/Compiler.cpp b/luau/Compiler/src/Compiler.cpp index 78896d3..9478404 100644 --- a/luau/Compiler/src/Compiler.cpp +++ b/luau/Compiler/src/Compiler.cpp @@ -25,9 +25,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileTerminateBC, false) -LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinArity, false) - namespace Luau { @@ -143,7 +140,7 @@ struct Compiler return stat->body.size > 0 && alwaysTerminates(stat->body.data[stat->body.size - 1]); else if (node->is()) return true; - else if (FFlag::LuauCompileTerminateBC && (node->is() || node->is())) + else if (node->is() || node->is()) return true; else if (AstStatIf* stat = node->as()) return stat->elsebody && alwaysTerminates(stat->thenbody) && alwaysTerminates(stat->elsebody); @@ -296,7 +293,7 @@ struct Compiler // handles builtin calls that can't be constant-folded but are known to return one value // note: optimizationLevel check is technically redundant but it's important that we never optimize based on builtins in O1 - if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2) + if (options.optimizationLevel >= 2) if (int* bfid = builtins.find(expr)) return getBuiltinInfo(*bfid).results != 1; @@ -767,7 +764,7 @@ struct Compiler { if (!isExprMultRet(expr->args.data[expr->args.size - 1])) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); - else if (FFlag::LuauCompileBuiltinArity && options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) + else if (options.optimizationLevel >= 2 && int(expr->args.size) == getBuiltinInfo(bfid).params) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); } diff --git a/luau/VM/include/lua.h b/luau/VM/include/lua.h index 783e60a..3f5e99f 100644 --- a/luau/VM/include/lua.h +++ b/luau/VM/include/lua.h @@ -29,7 +29,7 @@ enum lua_Status LUA_OK = 0, LUA_YIELD, LUA_ERRRUN, - LUA_ERRSYNTAX, + LUA_ERRSYNTAX, // legacy error code, preserved for compatibility LUA_ERRMEM, LUA_ERRERR, LUA_BREAK, // yielded for a debug breakpoint diff --git a/luau/VM/src/lbuiltins.cpp b/luau/VM/src/lbuiltins.cpp index 3c669bf..e0dc8a3 100644 --- a/luau/VM/src/lbuiltins.cpp +++ b/luau/VM/src/lbuiltins.cpp @@ -23,8 +23,6 @@ #endif #endif -LUAU_FASTFLAGVARIABLE(LuauBuiltinSSE41, false) - // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -105,9 +103,7 @@ static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -170,9 +166,7 @@ static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -949,9 +943,7 @@ static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } -// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN -LUAU_NOINLINE static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -1271,9 +1263,6 @@ LUAU_TARGET_SSE41 inline double roundsd_sse41(double v) LUAU_TARGET_SSE41 static int luauF_floor_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_floor(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); @@ -1286,9 +1275,6 @@ LUAU_TARGET_SSE41 static int luauF_floor_sse41(lua_State* L, StkId res, TValue* LUAU_TARGET_SSE41 static int luauF_ceil_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_ceil(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); @@ -1301,9 +1287,6 @@ LUAU_TARGET_SSE41 static int luauF_ceil_sse41(lua_State* L, StkId res, TValue* a LUAU_TARGET_SSE41 static int luauF_round_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (!FFlag::LuauBuiltinSSE41) - return luauF_round(L, res, arg0, nresults, args, nparams); - if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) { double a1 = nvalue(arg0); diff --git a/luau/VM/src/ldo.cpp b/luau/VM/src/ldo.cpp index ff8105b..264388b 100644 --- a/luau/VM/src/ldo.cpp +++ b/luau/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBetterOOMHandling, false) + /* ** {====================================================== ** Error-recovery functions @@ -79,22 +81,17 @@ public: const char* what() const throw() override { - // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. - if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) - { - // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. + // LUA_ERRRUN passes error object on the stack + if (status == LUA_ERRRUN || (status == LUA_ERRSYNTAX && !FFlag::LuauBetterOOMHandling)) if (const char* str = lua_tostring(L, -1)) - { return str; - } - } switch (status) { case LUA_ERRRUN: - return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; + return "lua_exception: runtime error"; case LUA_ERRSYNTAX: - return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; + return "lua_exception: syntax error"; case LUA_ERRMEM: return "lua_exception: " LUA_MEMERRMSG; case LUA_ERRERR: @@ -550,19 +547,42 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e int status = luaD_rawrunprotected(L, func, u); if (status != 0) { + int errstatus = status; + // call user-defined error function (used in xpcall) if (ef) { - // if errfunc fails, we fail with "error in error handling" - if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) - status = LUA_ERRERR; + if (FFlag::LuauBetterOOMHandling) + { + // push error object to stack top if it's not already there + if (status != LUA_ERRRUN) + seterrorobj(L, status, L->top); + + // if errfunc fails, we fail with "error in error handling" or "not enough memory" + int err = luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)); + + // in general we preserve the status, except for cases when the error handler fails + // out of memory is treated specially because it's common for it to be cascading, in which case we preserve the code + if (err == 0) + errstatus = LUA_ERRRUN; + else if (status == LUA_ERRMEM && err == LUA_ERRMEM) + errstatus = LUA_ERRMEM; + else + errstatus = status = LUA_ERRERR; + } + else + { + // if errfunc fails, we fail with "error in error handling" + if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) + status = LUA_ERRERR; + } } // since the call failed with an error, we might have to reset the 'active' thread state if (!oldactive) L->isactive = false; - // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + // restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback @@ -577,7 +597,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); // close eventual pending closures - seterrorobj(L, status, oldtop); + seterrorobj(L, FFlag::LuauBetterOOMHandling ? errstatus : status, oldtop); L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/luau/VM/src/lstate.h b/luau/VM/src/lstate.h index 1d32489..32a240b 100644 --- a/luau/VM/src/lstate.h +++ b/luau/VM/src/lstate.h @@ -208,14 +208,14 @@ typedef struct global_State uint64_t rngstate; // PCG random number generator state uint64_t ptrenckey[4]; // pointer encoding key for display - void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory - lua_Callbacks cb; #if LUA_CUSTOM_EXECUTION lua_ExecutionCallbacks ecb; #endif + void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory + GCStats gcstats; #ifdef LUAI_GCMETRICS diff --git a/luau/VM/src/ltable.cpp b/luau/VM/src/ltable.cpp index 8d59ecb..5eceea7 100644 --- a/luau/VM/src/ltable.cpp +++ b/luau/VM/src/ltable.cpp @@ -33,6 +33,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauArrBoundResizeFix, false) + // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -454,15 +456,43 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) int nasize = numusearray(t, nums); // count keys in array part int totaluse = nasize; // all those keys are integer keys totaluse += numusehash(t, nums, &nasize); // count keys in hash part + // count extra key if (ttisnumber(ek)) nasize += countint(nvalue(ek), nums); totaluse++; + // compute new size for array part int na = computesizes(nums, &nasize); int nh = totaluse - na; - // enforce the boundary invariant; for performance, only do hash lookups if we must - nasize = adjustasize(t, nasize, ek); + + if (FFlag::LuauArrBoundResizeFix) + { + // enforce the boundary invariant; for performance, only do hash lookups if we must + int nadjusted = adjustasize(t, nasize, ek); + + // count how many extra elements belong to array part instead of hash part + int aextra = nadjusted - nasize; + + if (aextra != 0) + { + // we no longer need to store those extra array elements in hash part + nh -= aextra; + + // because hash nodes are twice as large as array nodes, the memory we saved for hash parts can be used by array part + // this follows the general sparse array part optimization where array is allocated when 50% occupation is reached + nasize = nadjusted + aextra; + + // since the size was changed, it's again important to enforce the boundary invariant at the new size + nasize = adjustasize(t, nasize, ek); + } + } + else + { + // enforce the boundary invariant; for performance, only do hash lookups if we must + nasize = adjustasize(t, nasize, ek); + } + // resize the table to new computed sizes resize(L, t, nasize, nh); } diff --git a/luau/VM/src/ltablib.cpp b/luau/VM/src/ltablib.cpp index ddee3a7..4443be3 100644 --- a/luau/VM/src/ltablib.cpp +++ b/luau/VM/src/ltablib.cpp @@ -10,7 +10,7 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauOptimizedSort, false) +LUAU_FASTFLAGVARIABLE(LuauIntrosort, false) static int foreachi(lua_State* L) { @@ -298,120 +298,6 @@ static int tunpack(lua_State* L) return (int)n; } -/* -** {====================================================== -** Quicksort -** (based on `Algorithms in MODULA-3', Robert Sedgewick; -** Addison-Wesley, 1993.) -*/ - -static void set2(lua_State* L, int i, int j) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - lua_rawseti(L, 1, i); - lua_rawseti(L, 1, j); -} - -static int sort_comp(lua_State* L, int a, int b) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - if (!lua_isnil(L, 2)) - { // function? - int res; - lua_pushvalue(L, 2); - lua_pushvalue(L, a - 1); // -1 to compensate function - lua_pushvalue(L, b - 2); // -2 to compensate function and `a' - lua_call(L, 2, 1); - res = lua_toboolean(L, -1); - lua_pop(L, 1); - return res; - } - else // a < b? - return lua_lessthan(L, a, b); -} - -static void auxsort(lua_State* L, int l, int u) -{ - LUAU_ASSERT(!FFlag::LuauOptimizedSort); - while (l < u) - { // for tail recursion - int i, j; - // sort elements a[l], a[(l+u)/2] and a[u] - lua_rawgeti(L, 1, l); - lua_rawgeti(L, 1, u); - if (sort_comp(L, -1, -2)) // a[u] < a[l]? - set2(L, l, u); // swap a[l] - a[u] - else - lua_pop(L, 2); - if (u - l == 1) - break; // only 2 elements - i = (l + u) / 2; - lua_rawgeti(L, 1, i); - lua_rawgeti(L, 1, l); - if (sort_comp(L, -2, -1)) // a[i]= P - while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) - { - if (i >= u) - luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); // remove a[i] - } - // repeat --j until a[j] <= P - while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) - { - if (j <= l) - luaL_error(L, "invalid order function for sorting"); - lua_pop(L, 1); // remove a[j] - } - if (j < i) - { - lua_pop(L, 3); // pop pivot, a[i], a[j] - break; - } - set2(L, i, j); - } - lua_rawgeti(L, 1, u - 1); - lua_rawgeti(L, 1, i); - set2(L, u - 1, i); // swap pivot (a[u-1]) with a[i] - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) - { - j = l; - i = i - 1; - l = i + 2; - } - else - { - j = i + 1; - i = u; - u = j - 2; - } - auxsort(L, j, i); // call recursively the smaller one - } // repeat the routine for the larger one -} - typedef int (*SortPredicate)(lua_State* L, const TValue* l, const TValue* r); static int sort_func(lua_State* L, const TValue* l, const TValue* r) @@ -456,30 +342,77 @@ inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) return res; } -static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) +static void sort_siftheap(lua_State* L, Table* t, int l, int u, SortPredicate pred, int root) +{ + LUAU_ASSERT(l <= u); + int count = u - l + 1; + + // process all elements with two children + while (root * 2 + 2 < count) + { + int left = root * 2 + 1, right = root * 2 + 2; + int next = root; + next = sort_less(L, t, l + next, l + left, pred) ? left : next; + next = sort_less(L, t, l + next, l + right, pred) ? right : next; + + if (next == root) + break; + + sort_swap(L, t, l + root, l + next); + root = next; + } + + // process last element if it has just one child + int lastleft = root * 2 + 1; + if (lastleft == count - 1 && sort_less(L, t, l + root, l + lastleft, pred)) + sort_swap(L, t, l + root, l + lastleft); +} + +static void sort_heap(lua_State* L, Table* t, int l, int u, SortPredicate pred) +{ + LUAU_ASSERT(l <= u); + int count = u - l + 1; + + for (int i = count / 2 - 1; i >= 0; --i) + sort_siftheap(L, t, l, u, pred, i); + + for (int i = count - 1; i > 0; --i) + { + sort_swap(L, t, l, l + i); + sort_siftheap(L, t, l, l + i - 1, pred, 0); + } +} + +static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredicate pred) { // sort range [l..u] (inclusive, 0-based) while (l < u) { - int i, j; + // if the limit has been reached, quick sort is going over the permitted nlogn complexity, so we fall back to heap sort + if (FFlag::LuauIntrosort && limit == 0) + return sort_heap(L, t, l, u, pred); + // sort elements a[l], a[(l+u)/2] and a[u] + // note: this simultaneously acts as a small sort and a median selector if (sort_less(L, t, u, l, pred)) // a[u] < a[l]? sort_swap(L, t, u, l); // swap a[l] - a[u] if (u - l == 1) break; // only 2 elements - i = l + ((u - l) >> 1); // midpoint - if (sort_less(L, t, i, l, pred)) // a[i]> 1); // midpoint + if (sort_less(L, t, m, l, pred)) // a[m]= P @@ -498,63 +431,72 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) break; sort_swap(L, t, i, j); } - // swap pivot (a[u-1]) with a[i], which is the new midpoint - sort_swap(L, t, u - 1, i); - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) + + // swap pivot a[p] with a[i], which is the new midpoint + sort_swap(L, t, p, i); + + if (FFlag::LuauIntrosort) { - j = l; - i = i - 1; - l = i + 2; + // adjust limit to allow 1.5 log2N recursive steps + limit = (limit >> 1) + (limit >> 2); + + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // sort smaller half recursively; the larger half is sorted in the next loop iteration + if (i - l < u - i) + { + sort_rec(L, t, l, i - 1, limit, pred); + l = i + 1; + } + else + { + sort_rec(L, t, i + 1, u, limit, pred); + u = i - 1; + } } else { - j = i + 1; - i = u; - u = j - 2; + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // adjust so that smaller half is in [j..i] and larger one in [l..u] + if (i - l < u - i) + { + j = l; + i = i - 1; + l = i + 2; + } + else + { + j = i + 1; + i = u; + u = j - 2; + } + + // sort smaller half recursively; the larger half is sorted in the next loop iteration + sort_rec(L, t, j, i, limit, pred); } - sort_rec(L, t, j, i, pred); // call recursively the smaller one - } // repeat the routine for the larger one + } } static int tsort(lua_State* L) { - if (FFlag::LuauOptimizedSort) - { - luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); - int n = luaH_getn(t); - if (t->readonly) - luaG_readonlyerror(L); + luaL_checktype(L, 1, LUA_TTABLE); + Table* t = hvalue(L->base); + int n = luaH_getn(t); + if (t->readonly) + luaG_readonlyerror(L); - SortPredicate pred = luaV_lessthan; - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - { - luaL_checktype(L, 2, LUA_TFUNCTION); - pred = sort_func; - } - lua_settop(L, 2); // make sure there are two arguments - - if (n > 0) - sort_rec(L, t, 0, n - 1, pred); - return 0; - } - else + SortPredicate pred = luaV_lessthan; + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? { - luaL_checktype(L, 1, LUA_TTABLE); - int n = lua_objlen(L, 1); - luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - luaL_checktype(L, 2, LUA_TFUNCTION); - lua_settop(L, 2); // make sure there is two arguments - auxsort(L, 1, n); - return 0; + luaL_checktype(L, 2, LUA_TFUNCTION); + pred = sort_func; } + lua_settop(L, 2); // make sure there are two arguments + + if (n > 0) + sort_rec(L, t, 0, n - 1, n, pred); + return 0; } -// }====================================================== - static int tcreate(lua_State* L) { int size = luaL_checkinteger(L, 1);