From d42a5dbe48bdc597d0a8a2c79e52f3a6d908169d Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 7 Apr 2022 13:53:47 -0700 Subject: [PATCH 01/19] Sync to upstream/release/522 --- Analysis/include/Luau/Clone.h | 25 ++ Analysis/include/Luau/Frontend.h | 23 +- Analysis/include/Luau/Module.h | 29 +- Analysis/include/Luau/TypeInfer.h | 10 + Analysis/include/Luau/TypeVar.h | 2 + Analysis/include/Luau/Variant.h | 36 +- Analysis/src/Autocomplete.cpp | 117 +++++- Analysis/src/Clone.cpp | 371 ++++++++++++++++++ Analysis/src/Error.cpp | 1 + Analysis/src/Frontend.cpp | 144 +++++-- Analysis/src/IostreamHelpers.cpp | 419 +++++++++------------ Analysis/src/Module.cpp | 359 +----------------- Analysis/src/TxnLog.cpp | 31 +- Analysis/src/TypeInfer.cpp | 156 +++++--- Analysis/src/TypeVar.cpp | 4 + Ast/include/Luau/TimeTrace.h | 11 +- Ast/src/Parser.cpp | 11 + Ast/src/TimeTrace.cpp | 19 +- Compiler/src/ConstantFolding.cpp | 3 +- Sources.cmake | 2 + VM/src/lfunc.h | 2 +- VM/src/ltablib.cpp | 2 +- tests/Autocomplete.test.cpp | 139 ++++++- tests/Compiler.test.cpp | 33 ++ tests/Fixture.cpp | 5 +- tests/Fixture.h | 2 +- tests/Frontend.test.cpp | 64 ++++ tests/Module.test.cpp | 1 + tests/Parser.test.cpp | 21 ++ tests/TypeInfer.functions.test.cpp | 39 +- tests/TypeInfer.intersectionTypes.test.cpp | 39 +- tests/TypeInfer.primitives.test.cpp | 2 + tests/TypeInfer.tables.test.cpp | 15 + tests/TypeInfer.tryUnify.test.cpp | 26 ++ tests/TypeInfer.unionTypes.test.cpp | 25 ++ tools/lldb_formatters.py | 4 +- 36 files changed, 1407 insertions(+), 785 deletions(-) create mode 100644 Analysis/include/Luau/Clone.h create mode 100644 Analysis/src/Clone.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h new file mode 100644 index 0000000..917ef80 --- /dev/null +++ b/Analysis/include/Luau/Clone.h @@ -0,0 +1,25 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" + +#include + +namespace Luau +{ + +// Only exposed so they can be unit tested. +using SeenTypes = std::unordered_map; +using SeenTypePacks = std::unordered_map; + +struct CloneState +{ + int recursionCount = 0; + bool encounteredFreeType = false; +}; + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 0bf8f36..2266f54 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAG(LuauSeparateTypechecks) + namespace Luau { @@ -55,10 +57,19 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa struct SourceNode { + bool isDirty(bool forAutocomplete) const + { + if (FFlag::LuauSeparateTypechecks) + return forAutocomplete ? dirtyAutocomplete : dirty; + else + return dirty; + } + ModuleName name; std::unordered_set requires; std::vector> requireLocations; bool dirty = true; + bool dirtyAutocomplete = true; }; struct FrontendOptions @@ -71,12 +82,16 @@ struct FrontendOptions // When true, we run typechecking twice, once in the regular mode, and once in strict mode // in order to get more precise type information (e.g. for autocomplete). - bool typecheckTwice = false; + bool typecheckTwice_DEPRECATED = false; + + // Run typechecking only in mode required for autocomplete (strict mode in order to get more precise type information) + bool forAutocomplete = false; }; struct CheckResult { std::vector errors; + std::vector timeoutHits; }; struct FrontendModuleResolver : ModuleResolver @@ -123,7 +138,7 @@ struct Frontend CheckResult check(const SourceModule& module); // OLD. TODO KILL LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); - bool isDirty(const ModuleName& name) const; + bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); /** Borrow a pointer into the SourceModule cache. @@ -147,10 +162,10 @@ struct Frontend void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); private: - std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); - bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root); + bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); static LintResult classifyLints(const std::vector& warnings, const Config& config); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 6c689b7..9a32f61 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -29,8 +29,8 @@ struct SourceModule std::optional environmentName; bool cyclic = false; - std::unique_ptr allocator; - std::unique_ptr names; + std::shared_ptr allocator; + std::shared_ptr names; std::vector parseErrors; AstStatBlock* root = nullptr; @@ -48,6 +48,12 @@ struct SourceModule bool isWithinComment(const SourceModule& sourceModule, Position pos); +struct RequireCycle +{ + Location location; + std::vector path; // one of the paths for a require() to go all the way back to the originating module +}; + struct TypeArena { TypedAllocator typeVars; @@ -77,20 +83,6 @@ struct TypeArena void freeze(TypeArena& arena); void unfreeze(TypeArena& arena); -// Only exposed so they can be unit tested. -using SeenTypes = std::unordered_map; -using SeenTypePacks = std::unordered_map; - -struct CloneState -{ - int recursionCount = 0; - bool encounteredFreeType = false; -}; - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); - struct Module { ~Module(); @@ -98,6 +90,10 @@ struct Module TypeArena interfaceTypes; TypeArena internalTypes; + // Scopes and AST types refer to parse data, so we need to keep that alive + std::shared_ptr allocator; + std::shared_ptr names; + std::vector> scopes; // never empty DenseHashMap astTypes{nullptr}; @@ -109,6 +105,7 @@ struct Module ErrorVec errors; Mode mode; SourceCode::Type type; + bool timeout = false; ScopePtr getModuleScope() const; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 839043c..215da67 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -124,6 +124,12 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; +class TimeLimitError : public std::exception +{ +public: + virtual const char* what() const throw(); +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -413,6 +419,10 @@ public: UnifierSharedState unifierState; + std::vector requireCycles; + + std::optional finishTime; + public: const TypeId nilType; const TypeId numberType; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b8c4b36..f61e404 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -513,6 +513,8 @@ struct SingletonTypes const TypeId stringType; const TypeId booleanType; const TypeId threadType; + const TypeId trueType; + const TypeId falseType; const TypeId anyType; const TypeId optionalNumberType; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index 63d5a65..5efe89e 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -2,45 +2,14 @@ #pragma once #include "Luau/Common.h" - -#ifndef LUAU_USE_STD_VARIANT -#define LUAU_USE_STD_VARIANT 0 -#endif - -#if LUAU_USE_STD_VARIANT -#include -#else #include #include #include #include -#endif namespace Luau { -#if LUAU_USE_STD_VARIANT -template -using Variant = std::variant; - -template -auto visit(Visitor&& vis, Variant&& var) -{ - // This change resolves the ABI issues with std::variant on libc++; std::visit normally throws bad_variant_access - // but it requires an update to libc++.dylib which ships with macOS 10.14. To work around this, we assert on valueless - // variants since we will never generate them and call into a libc++ function that doesn't throw. - LUAU_ASSERT(!var.valueless_by_exception()); - -#ifdef __APPLE__ - // See https://stackoverflow.com/a/53868971/503215 - return std::__variant_detail::__visitation::__variant::__visit_value(vis, var); -#else - return std::visit(vis, var); -#endif -} - -using std::get_if; -#else template class Variant { @@ -248,6 +217,8 @@ static void fnVisitV(Visitor& vis, std::conditional_t, const template auto visit(Visitor&& vis, const Variant& var) { + static_assert(std::conjunction_v...>, "visitor must accept every alternative as an argument"); + using Result = std::invoke_result_t::first_alternative>; static_assert(std::conjunction_v>...>, "visitor result type must be consistent between alternatives"); @@ -273,6 +244,8 @@ auto visit(Visitor&& vis, const Variant& var) template auto visit(Visitor&& vis, Variant& var) { + static_assert(std::conjunction_v...>, "visitor must accept every alternative as an argument"); + using Result = std::invoke_result_t::first_alternative&>; static_assert(std::conjunction_v>...>, "visitor result type must be consistent between alternatives"); @@ -294,7 +267,6 @@ auto visit(Visitor&& vis, Variant& var) return res; } } -#endif template inline constexpr bool always_false_v = false; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 492edf2..b7201ab 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteSingletonTypes, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) static const std::unordered_set kStatementStartingKeywords = { @@ -625,6 +626,31 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi return result; } +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result) +{ + auto formatKey = [addQuotes](const std::string& key) { + if (addQuotes) + return "\"" + escape(key) + "\""; + + return escape(key); + }; + + ty = follow(ty); + + if (auto ss = get(get(ty))) + { + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + else if (auto uty = get(ty)) + { + for (auto el : uty) + { + if (auto ss = get(get(el))) + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + } +}; + static bool canSuggestInferredType(ScopePtr scope, TypeId ty) { ty = follow(ty); @@ -1309,17 +1335,38 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + if (FFlag::LuauAutocompleteSingletonTypes) + { + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, result); + } + else + { + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; + result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; + result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + } } } @@ -1625,17 +1672,33 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (node->is()) { + AutocompleteEntryMap result; + + if (FFlag::LuauAutocompleteSingletonTypes) + { + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, result); + } + if (finder.ancestry.size() >= 2) { if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) + autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); + } + else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as(); + binExpr && FFlag::LuauAutocompleteSingletonTypes) + { + if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { - return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry}; + if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) + autocompleteStringSingleton(*it, false, result); } } } - return {}; + + return {result, finder.ancestry}; } if (node->is()) @@ -1653,18 +1716,31 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. - frontend.check(moduleName); + if (FFlag::LuauSeparateTypechecks) + { + // FIXME: We can improve performance here by parsing without checking. + // The old type graph is probably fine. (famous last words!) + FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(moduleName, opts); + } + else + { + // FIXME: We can improve performance here by parsing without checking. + // The old type graph is probably fine. (famous last words!) + // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. + frontend.check(moduleName); + } const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; - TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = + (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + ModulePtr module = + (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1692,7 +1768,8 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + TypeChecker& typeChecker = + (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp new file mode 100644 index 0000000..ac9705a --- /dev/null +++ b/Analysis/src/Clone.cpp @@ -0,0 +1,371 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Clone.h" +#include "Luau/Module.h" +#include "Luau/RecursionCounter.h" +#include "Luau/TypePack.h" +#include "Luau/Unifiable.h" + +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) + +namespace Luau +{ + +namespace +{ + +struct TypePackCloner; + +/* + * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. + * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. + */ + +struct TypeCloner +{ + TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + : dest(dest) + , typeId(typeId) + , seenTypes(seenTypes) + , seenTypePacks(seenTypePacks) + , cloneState(cloneState) + { + } + + TypeArena& dest; + TypeId typeId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + CloneState& cloneState; + + template + void defaultClone(const T& t); + + void operator()(const Unifiable::Free& t); + void operator()(const Unifiable::Generic& t); + void operator()(const Unifiable::Bound& t); + void operator()(const Unifiable::Error& t); + void operator()(const PrimitiveTypeVar& t); + void operator()(const SingletonTypeVar& t); + void operator()(const FunctionTypeVar& t); + void operator()(const TableTypeVar& t); + void operator()(const MetatableTypeVar& t); + void operator()(const ClassTypeVar& t); + void operator()(const AnyTypeVar& t); + void operator()(const UnionTypeVar& t); + void operator()(const IntersectionTypeVar& t); + void operator()(const LazyTypeVar& t); +}; + +struct TypePackCloner +{ + TypeArena& dest; + TypePackId typePackId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + CloneState& cloneState; + + TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + : dest(dest) + , typePackId(typePackId) + , seenTypes(seenTypes) + , seenTypePacks(seenTypePacks) + , cloneState(cloneState) + { + } + + template + void defaultClone(const T& t) + { + TypePackId cloned = dest.addTypePack(TypePackVar{t}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const Unifiable::Free& t) + { + cloneState.encounteredFreeType = true; + + TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const Unifiable::Generic& t) + { + defaultClone(t); + } + void operator()(const Unifiable::Error& t) + { + defaultClone(t); + } + + // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. + // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. + void operator()(const Unifiable::Bound& t) + { + TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const VariadicTypePack& t) + { + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const TypePack& t) + { + TypePackId cloned = dest.addTypePack(TypePack{}); + TypePack* destTp = getMutable(cloned); + LUAU_ASSERT(destTp != nullptr); + seenTypePacks[typePackId] = cloned; + + for (TypeId ty : t.head) + destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + + if (t.tail) + destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); + } +}; + +template +void TypeCloner::defaultClone(const T& t) +{ + TypeId cloned = dest.addType(t); + seenTypes[typeId] = cloned; +} + +void TypeCloner::operator()(const Unifiable::Free& t) +{ + cloneState.encounteredFreeType = true; + TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; +} + +void TypeCloner::operator()(const Unifiable::Generic& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const Unifiable::Bound& t) +{ + TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + seenTypes[typeId] = boundTo; +} + +void TypeCloner::operator()(const Unifiable::Error& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const PrimitiveTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const FunctionTypeVar& t) +{ + TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + FunctionTypeVar* ftv = getMutable(result); + LUAU_ASSERT(ftv != nullptr); + + seenTypes[typeId] = result; + + for (TypeId generic : t.generics) + ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); + + for (TypePackId genericPack : t.genericPacks) + ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); + + ftv->tags = t.tags; + ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); + ftv->argNames = t.argNames; + ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); +} + +void TypeCloner::operator()(const TableTypeVar& t) +{ + // If table is now bound to another one, we ignore the content of the original + if (t.boundTo) + { + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + seenTypes[typeId] = boundTo; + return; + } + + TypeId result = dest.addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(result); + LUAU_ASSERT(ttv != nullptr); + + *ttv = t; + + seenTypes[typeId] = result; + + ttv->level = TypeLevel{0, 0}; + + for (const auto& [name, prop] : t.props) + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + + if (t.indexer) + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), + clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; + + for (TypeId& arg : ttv->instantiatedTypeParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + + if (ttv->state == TableState::Free) + { + cloneState.encounteredFreeType = true; + + ttv->state = TableState::Sealed; + } + + ttv->definitionModuleName = t.definitionModuleName; + ttv->methodDefinitionLocations = t.methodDefinitionLocations; + ttv->tags = t.tags; +} + +void TypeCloner::operator()(const MetatableTypeVar& t) +{ + TypeId result = dest.addType(MetatableTypeVar{}); + MetatableTypeVar* mtv = getMutable(result); + seenTypes[typeId] = result; + + mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); + mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); +} + +void TypeCloner::operator()(const ClassTypeVar& t) +{ + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + ClassTypeVar* ctv = getMutable(result); + + seenTypes[typeId] = result; + + for (const auto& [name, prop] : t.props) + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + + if (t.parent) + ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); + + if (t.metatable) + ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); +} + +void TypeCloner::operator()(const AnyTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const UnionTypeVar& t) +{ + std::vector options; + options.reserve(t.options.size()); + + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; +} + +void TypeCloner::operator()(const IntersectionTypeVar& t) +{ + TypeId result = dest.addType(IntersectionTypeVar{}); + seenTypes[typeId] = result; + + IntersectionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.parts) + option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); +} + +void TypeCloner::operator()(const LazyTypeVar& t) +{ + defaultClone(t); +} + +} // anonymous namespace + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +{ + if (tp->persistent) + return tp; + + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + + TypePackId& res = seenTypePacks[tp]; + + if (res == nullptr) + { + TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; + Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + } + + return res; +} + +TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +{ + if (typeId->persistent) + return typeId; + + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + + TypeId& res = seenTypes[typeId]; + + if (res == nullptr) + { + TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; + Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } + + return res; +} + +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +{ + TypeFun result; + + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } + + result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); + + return result; +} + +} // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 210c019..5eb2ea2 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Error.h" +#include "Luau/Clone.h" #include "Luau/Module.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index d8906f6..000769f 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -2,6 +2,7 @@ #include "Luau/Frontend.h" #include "Luau/Common.h" +#include "Luau/Clone.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" #include "Luau/Parser.h" @@ -16,8 +17,11 @@ #include #include +LUAU_FASTFLAG(LuauCyclicModuleTypeSurface) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) +LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) +LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) namespace Luau { @@ -234,12 +238,6 @@ ErrorVec accumulateErrors( return result; } -struct RequireCycle -{ - Location location; - std::vector path; // one of the paths for a require() to go all the way back to the originating module -}; - // Given a source node (start), find all requires that start a transitive dependency path that ends back at start // For each such path, record the full path and the location of the require in the starting module. // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) @@ -356,33 +354,55 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.dirty) + if (it != sourceNodes.end() && !it->second.isDirty(frontendOptions.forAutocomplete)) { // No recheck required. - auto it2 = moduleResolver.modules.find(name); - if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); + if (FFlag::LuauSeparateTypechecks) + { + if (frontendOptions.forAutocomplete) + { + auto it2 = moduleResolverForAutocomplete.modules.find(name); + if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + } + else + { + auto it2 = moduleResolver.modules.find(name); + if (it2 == moduleResolver.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + } - return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; + return CheckResult{accumulateErrors( + sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; + } + else + { + auto it2 = moduleResolver.modules.find(name); + if (it2 == moduleResolver.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + + return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; + } } std::vector buildQueue; - bool cycleDetected = parseGraph(buildQueue, checkResult, name); - - FrontendOptions frontendOptions = optionOverride.value_or(options); + bool cycleDetected = parseGraph(buildQueue, checkResult, name, frontendOptions.forAutocomplete); // Keep track of which AST nodes we've reported cycles in std::unordered_set reportedCycles; + double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + for (const ModuleName& moduleName : buildQueue) { LUAU_ASSERT(sourceNodes.count(moduleName)); SourceNode& sourceNode = sourceNodes[moduleName]; - if (!sourceNode.dirty) + if (!sourceNode.isDirty(frontendOptions.forAutocomplete)) continue; LUAU_ASSERT(sourceModules.count(moduleName)); @@ -408,13 +428,44 @@ CheckResult Frontend::check(const ModuleName& name, std::optionaltimeout) + checkResult.timeoutHits.push_back(moduleName); + + stats.timeCheck += getTimestamp() - timestamp; + stats.filesStrict += 1; + + sourceNode.dirtyAutocomplete = false; + continue; + } + + if (FFlag::LuauCyclicModuleTypeSurface) + typeChecker.requireCycles = requireCycles; + ModulePtr module = typeChecker.check(sourceModule, mode, environmentScope); // If we're typechecking twice, we do so. // The second typecheck is always in strict mode with DM awareness // to provide better typen information for IDE features. - if (frontendOptions.typecheckTwice) + if (!FFlag::LuauSeparateTypechecks && frontendOptions.typecheckTwice_DEPRECATED) { + if (FFlag::LuauCyclicModuleTypeSurface) + typeCheckerForAutocomplete.requireCycles = requireCycles; + ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; } @@ -467,7 +518,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& buildQueue, CheckResult& checkResult, const ModuleName& root) +bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete) { LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); @@ -486,7 +537,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec bool cyclic = false; { - auto [sourceNode, _] = getSourceNode(checkResult, root); + auto [sourceNode, _] = getSourceNode(checkResult, root, forAutocomplete); if (sourceNode) stack.push_back(sourceNode); } @@ -538,7 +589,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.dirty) + if (!it->second.isDirty(forAutocomplete)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization @@ -550,7 +601,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec } } - auto [sourceNode, _] = getSourceNode(checkResult, dep); + auto [sourceNode, _] = getSourceNode(checkResult, dep, forAutocomplete); if (sourceNode) { stack.push_back(sourceNode); @@ -594,7 +645,7 @@ LintResult Frontend::lint(const ModuleName& name, std::optionalsecond.dirty; + return it == sourceNodes.end() || it->second.isDirty(forAutocomplete); } /* @@ -699,8 +750,16 @@ bool Frontend::isDirty(const ModuleName& name) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.modules.count(name)) - return; + if (FFlag::LuauSeparateTypechecks) + { + if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) + return; + } + else + { + if (!moduleResolver.modules.count(name)) + return; + } std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) @@ -722,10 +781,21 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked if (markedDirty) markedDirty->push_back(next); - if (sourceNode.dirty) - continue; + if (FFlag::LuauSeparateTypechecks) + { + if (sourceNode.dirty && sourceNode.dirtyAutocomplete) + continue; - sourceNode.dirty = true; + sourceNode.dirty = true; + sourceNode.dirtyAutocomplete = true; + } + else + { + if (sourceNode.dirty) + continue; + + sourceNode.dirty = true; + } if (0 == reverseDeps.count(name)) continue; @@ -752,13 +822,13 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. -std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) +std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete) { LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.dirty) + if (it != sourceNodes.end() && !it->second.isDirty(forAutocomplete)) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) @@ -801,7 +871,19 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.name = name; sourceNode.requires.clear(); sourceNode.requireLocations.clear(); - sourceNode.dirty = true; + + if (FFlag::LuauSeparateTypechecks) + { + if (it == sourceNodes.end()) + { + sourceNode.dirty = true; + sourceNode.dirtyAutocomplete = true; + } + } + else + { + sourceNode.dirty = true; + } for (const auto& [moduleName, location] : requireTrace.requires) sourceNode.requires.insert(moduleName); diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 19c2dda..a8f6758 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -23,9 +23,178 @@ std::ostream& operator<<(std::ostream& stream, const AstName& name) return stream << ""; } -std::ostream& operator<<(std::ostream& stream, const TypeMismatch& tm) +template +static void errorToString(std::ostream& stream, const T& err) { - return stream << "TypeMismatch { " << toString(tm.wantedType) << ", " << toString(tm.givenType) << " }"; + if constexpr (false) + { + } + else if constexpr (std::is_same_v) + stream << "TypeMismatch { " << toString(err.wantedType) << ", " << toString(err.givenType) << " }"; + else if constexpr (std::is_same_v) + stream << "UnknownSymbol { " << err.name << " , context " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "UnknownProperty { " << toString(err.table) << ", key = " << err.key << " }"; + else if constexpr (std::is_same_v) + stream << "NotATable { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "CannotExtendTable { " << toString(err.tableType) << ", context " << err.context << ", prop \"" << err.prop << "\" }"; + else if constexpr (std::is_same_v) + stream << "OnlyTablesCanHaveMethods { " << toString(err.tableType) << " }"; + else if constexpr (std::is_same_v) + stream << "DuplicateTypeDefinition { " << err.name << " }"; + else if constexpr (std::is_same_v) + stream << "CountMismatch { expected " << err.expected << ", got " << err.actual << ", context " << err.context << " }"; + else if constexpr (std::is_same_v) + stream << "FunctionDoesNotTakeSelf { }"; + else if constexpr (std::is_same_v) + stream << "FunctionRequiresSelf { extraNils " << err.requiredExtraNils << " }"; + else if constexpr (std::is_same_v) + stream << "OccursCheckFailed { }"; + else if constexpr (std::is_same_v) + stream << "UnknownRequire { " << err.modulePath << " }"; + else if constexpr (std::is_same_v) + { + stream << "IncorrectGenericParameterCount { name = " << err.name; + + if (!err.typeFun.typeParams.empty() || !err.typeFun.typePackParams.empty()) + { + stream << "<"; + bool first = true; + for (auto param : err.typeFun.typeParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(param.ty); + } + + for (auto param : err.typeFun.typePackParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(param.tp); + } + + stream << ">"; + } + + stream << ", typeFun = " << toString(err.typeFun.type) << ", actualCount = " << err.actualParameters << " }"; + } + else if constexpr (std::is_same_v) + stream << "SyntaxError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "CodeTooComplex {}"; + else if constexpr (std::is_same_v) + stream << "UnificationTooComplex {}"; + else if constexpr (std::is_same_v) + { + stream << "UnknownPropButFoundLikeProp { key = '" << err.key << "', suggested = { "; + + bool first = true; + for (Name name : err.candidates) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + stream << " }, table = " << toString(err.table) << " } "; + } + else if constexpr (std::is_same_v) + stream << "GenericError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; + else if constexpr (std::is_same_v) + stream << "ExtraInformation { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "DeprecatedApiUsed { " << err.symbol << ", useInstead = " << err.useInstead << " }"; + else if constexpr (std::is_same_v) + { + stream << "ModuleHasCyclicDependency {"; + + bool first = true; + for (const ModuleName& name : err.cycle) + { + if (first) + first = false; + else + stream << ", "; + + stream << name; + } + + stream << "}"; + } + else if constexpr (std::is_same_v) + stream << "IllegalRequire { " << err.moduleName << ", reason = " << err.reason << " }"; + else if constexpr (std::is_same_v) + stream << "FunctionExitsWithoutReturning {" << toString(err.expectedReturnType) << "}"; + else if constexpr (std::is_same_v) + stream << "DuplicateGenericParameter { " + err.parameterName + " }"; + else if constexpr (std::is_same_v) + stream << "CannotInferBinaryOperation { op = " + toString(err.op) + ", suggested = '" + + (err.suggestedToAnnotate ? *err.suggestedToAnnotate : "") + "', kind " + << err.kind << "}"; + else if constexpr (std::is_same_v) + { + stream << "MissingProperties { superType = '" << toString(err.superType) << "', subType = '" << toString(err.subType) << "', properties = { "; + + bool first = true; + for (Name name : err.properties) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + stream << " }, context " << err.context << " } "; + } + else if constexpr (std::is_same_v) + stream << "SwappedGenericTypeParameter { name = '" + err.name + "', kind = " + std::to_string(err.kind) + " }"; + else if constexpr (std::is_same_v) + stream << "OptionalValueAccess { optional = '" + toString(err.optional) + "' }"; + else if constexpr (std::is_same_v) + { + stream << "MissingUnionProperty { type = '" + toString(err.type) + "', missing = { "; + + bool first = true; + for (auto ty : err.missing) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << toString(ty) << "'"; + } + + stream << " }, key = '" + err.key + "' }"; + } + else if constexpr (std::is_same_v) + stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; + else + static_assert(always_false_v, "Non-exhaustive type switch"); +} + +std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) +{ + auto cb = [&](const auto& e) { + return errorToString(stream, e); + }; + visit(cb, data); + return stream; } std::ostream& operator<<(std::ostream& stream, const TypeError& error) @@ -33,241 +202,6 @@ std::ostream& operator<<(std::ostream& stream, const TypeError& error) return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }"; } -std::ostream& operator<<(std::ostream& stream, const UnknownSymbol& error) -{ - return stream << "UnknownSymbol { " << error.name << " , context " << error.context << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownProperty& error) -{ - return stream << "UnknownProperty { " << toString(error.table) << ", key = " << error.key << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const NotATable& ge) -{ - return stream << "NotATable { " << toString(ge.ty) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotExtendTable& error) -{ - return stream << "CannotExtendTable { " << toString(error.tableType) << ", context " << error.context << ", prop \"" << error.prop << "\" }"; -} - -std::ostream& operator<<(std::ostream& stream, const OnlyTablesCanHaveMethods& error) -{ - return stream << "OnlyTablesCanHaveMethods { " << toString(error.tableType) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const DuplicateTypeDefinition& error) -{ - return stream << "DuplicateTypeDefinition { " << error.name << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CountMismatch& error) -{ - return stream << "CountMismatch { expected " << error.expected << ", got " << error.actual << ", context " << error.context << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionDoesNotTakeSelf&) -{ - return stream << "FunctionDoesNotTakeSelf { }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionRequiresSelf& error) -{ - return stream << "FunctionRequiresSelf { extraNils " << error.requiredExtraNils << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const OccursCheckFailed&) -{ - return stream << "OccursCheckFailed { }"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownRequire& error) -{ - return stream << "UnknownRequire { " << error.modulePath << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCount& error) -{ - stream << "IncorrectGenericParameterCount { name = " << error.name; - - if (!error.typeFun.typeParams.empty() || !error.typeFun.typePackParams.empty()) - { - stream << "<"; - bool first = true; - for (auto param : error.typeFun.typeParams) - { - if (first) - first = false; - else - stream << ", "; - - stream << toString(param.ty); - } - - for (auto param : error.typeFun.typePackParams) - { - if (first) - first = false; - else - stream << ", "; - - stream << toString(param.tp); - } - - stream << ">"; - } - - stream << ", typeFun = " << toString(error.typeFun.type) << ", actualCount = " << error.actualParameters << " }"; - return stream; -} - -std::ostream& operator<<(std::ostream& stream, const SyntaxError& ge) -{ - return stream << "SyntaxError { " << ge.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CodeTooComplex&) -{ - return stream << "CodeTooComplex {}"; -} - -std::ostream& operator<<(std::ostream& stream, const UnificationTooComplex&) -{ - return stream << "UnificationTooComplex {}"; -} - -std::ostream& operator<<(std::ostream& stream, const UnknownPropButFoundLikeProp& e) -{ - stream << "UnknownPropButFoundLikeProp { key = '" << e.key << "', suggested = { "; - - bool first = true; - for (Name name : e.candidates) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << name << "'"; - } - - return stream << " }, table = " << toString(e.table) << " } "; -} - -std::ostream& operator<<(std::ostream& stream, const GenericError& ge) -{ - return stream << "GenericError { " << ge.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotCallNonFunction& e) -{ - return stream << "CannotCallNonFunction { " << toString(e.ty) << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const FunctionExitsWithoutReturning& error) -{ - return stream << "FunctionExitsWithoutReturning {" << toString(error.expectedReturnType) << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const ExtraInformation& e) -{ - return stream << "ExtraInformation { " << e.message << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const DeprecatedApiUsed& e) -{ - return stream << "DeprecatedApiUsed { " << e.symbol << ", useInstead = " << e.useInstead << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const ModuleHasCyclicDependency& e) -{ - stream << "ModuleHasCyclicDependency {"; - - bool first = true; - for (const ModuleName& name : e.cycle) - { - if (first) - first = false; - else - stream << ", "; - - stream << name; - } - - return stream << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const IllegalRequire& e) -{ - return stream << "IllegalRequire { " << e.moduleName << ", reason = " << e.reason << " }"; -} - -std::ostream& operator<<(std::ostream& stream, const MissingProperties& e) -{ - stream << "MissingProperties { superType = '" << toString(e.superType) << "', subType = '" << toString(e.subType) << "', properties = { "; - - bool first = true; - for (Name name : e.properties) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << name << "'"; - } - - return stream << " }, context " << e.context << " } "; -} - -std::ostream& operator<<(std::ostream& stream, const DuplicateGenericParameter& error) -{ - return stream << "DuplicateGenericParameter { " + error.parameterName + " }"; -} - -std::ostream& operator<<(std::ostream& stream, const CannotInferBinaryOperation& error) -{ - return stream << "CannotInferBinaryOperation { op = " + toString(error.op) + ", suggested = '" + - (error.suggestedToAnnotate ? *error.suggestedToAnnotate : "") + "', kind " - << error.kind << "}"; -} - -std::ostream& operator<<(std::ostream& stream, const SwappedGenericTypeParameter& error) -{ - return stream << "SwappedGenericTypeParameter { name = '" + error.name + "', kind = " + std::to_string(error.kind) + " }"; -} - -std::ostream& operator<<(std::ostream& stream, const OptionalValueAccess& error) -{ - return stream << "OptionalValueAccess { optional = '" + toString(error.optional) + "' }"; -} - -std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error) -{ - stream << "MissingUnionProperty { type = '" + toString(error.type) + "', missing = { "; - - bool first = true; - for (auto ty : error.missing) - { - if (first) - first = false; - else - stream << ", "; - - stream << "'" << toString(ty) << "'"; - } - - return stream << " }, key = '" + error.key + "' }"; -} - -std::ostream& operator<<(std::ostream& stream, const TypesAreUnrelated& error) -{ - stream << "TypesAreUnrelated { left = '" + toString(error.left) + "', right = '" + toString(error.right) + "' }"; - return stream; -} - std::ostream& operator<<(std::ostream& stream, const TableState& tv) { return stream << static_cast::type>(tv); @@ -283,15 +217,4 @@ std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv) return stream << toString(tv); } -std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted) -{ - Luau::visit( - [&](const auto& a) { - lhs << a; - }, - ted); - - return lhs; -} - } // namespace Luau diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 0787d3a..6bb4524 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -2,6 +2,7 @@ #include "Luau/Module.h" #include "Luau/Common.h" +#include "Luau/Clone.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -12,7 +13,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) namespace Luau @@ -113,363 +113,6 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) return allocated; } -namespace -{ - -struct TypePackCloner; - -/* - * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. - * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. - */ - -struct TypeCloner -{ - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) - : dest(dest) - , typeId(typeId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) - , cloneState(cloneState) - { - } - - TypeArena& dest; - TypeId typeId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - template - void defaultClone(const T& t); - - void operator()(const Unifiable::Free& t); - void operator()(const Unifiable::Generic& t); - void operator()(const Unifiable::Bound& t); - void operator()(const Unifiable::Error& t); - void operator()(const PrimitiveTypeVar& t); - void operator()(const SingletonTypeVar& t); - void operator()(const FunctionTypeVar& t); - void operator()(const TableTypeVar& t); - void operator()(const MetatableTypeVar& t); - void operator()(const ClassTypeVar& t); - void operator()(const AnyTypeVar& t); - void operator()(const UnionTypeVar& t); - void operator()(const IntersectionTypeVar& t); - void operator()(const LazyTypeVar& t); -}; - -struct TypePackCloner -{ - TypeArena& dest; - TypePackId typePackId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) - : dest(dest) - , typePackId(typePackId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) - , cloneState(cloneState) - { - } - - template - void defaultClone(const T& t) - { - TypePackId cloned = dest.addTypePack(TypePackVar{t}); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const Unifiable::Free& t) - { - cloneState.encounteredFreeType = true; - - TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); - TypePackId cloned = dest.addTypePack(*err); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const Unifiable::Generic& t) - { - defaultClone(t); - } - void operator()(const Unifiable::Error& t) - { - defaultClone(t); - } - - // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. - // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. - void operator()(const Unifiable::Bound& t) - { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const VariadicTypePack& t) - { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const TypePack& t) - { - TypePackId cloned = dest.addTypePack(TypePack{}); - TypePack* destTp = getMutable(cloned); - LUAU_ASSERT(destTp != nullptr); - seenTypePacks[typePackId] = cloned; - - for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - - if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); - } -}; - -template -void TypeCloner::defaultClone(const T& t) -{ - TypeId cloned = dest.addType(t); - seenTypes[typeId] = cloned; -} - -void TypeCloner::operator()(const Unifiable::Free& t) -{ - cloneState.encounteredFreeType = true; - TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); - TypeId cloned = dest.addType(*err); - seenTypes[typeId] = cloned; -} - -void TypeCloner::operator()(const Unifiable::Generic& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const Unifiable::Bound& t) -{ - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypes[typeId] = boundTo; -} - -void TypeCloner::operator()(const Unifiable::Error& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const PrimitiveTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const SingletonTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const FunctionTypeVar& t) -{ - TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); - FunctionTypeVar* ftv = getMutable(result); - LUAU_ASSERT(ftv != nullptr); - - seenTypes[typeId] = result; - - for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); - - for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); - - ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); - ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const TableTypeVar& t) -{ - // If table is now bound to another one, we ignore the content of the original - if (t.boundTo) - { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); - seenTypes[typeId] = boundTo; - return; - } - - TypeId result = dest.addType(TableTypeVar{}); - TableTypeVar* ttv = getMutable(result); - LUAU_ASSERT(ttv != nullptr); - - *ttv = t; - - seenTypes[typeId] = result; - - ttv->level = TypeLevel{0, 0}; - - for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; - - if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; - - for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - - if (ttv->state == TableState::Free) - { - cloneState.encounteredFreeType = true; - - ttv->state = TableState::Sealed; - } - - ttv->definitionModuleName = t.definitionModuleName; - ttv->methodDefinitionLocations = t.methodDefinitionLocations; - ttv->tags = t.tags; -} - -void TypeCloner::operator()(const MetatableTypeVar& t) -{ - TypeId result = dest.addType(MetatableTypeVar{}); - MetatableTypeVar* mtv = getMutable(result); - seenTypes[typeId] = result; - - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const ClassTypeVar& t) -{ - TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); - ClassTypeVar* ctv = getMutable(result); - - seenTypes[typeId] = result; - - for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; - - if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); - - if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); -} - -void TypeCloner::operator()(const AnyTypeVar& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const UnionTypeVar& t) -{ - std::vector options; - options.reserve(t.options.size()); - - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - - TypeId result = dest.addType(UnionTypeVar{std::move(options)}); - seenTypes[typeId] = result; -} - -void TypeCloner::operator()(const IntersectionTypeVar& t) -{ - TypeId result = dest.addType(IntersectionTypeVar{}); - seenTypes[typeId] = result; - - IntersectionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); - - for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); -} - -void TypeCloner::operator()(const LazyTypeVar& t) -{ - defaultClone(t); -} - -} // anonymous namespace - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - if (tp->persistent) - return tp; - - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypePackId& res = seenTypePacks[tp]; - - if (res == nullptr) - { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; - Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. - } - - return res; -} - -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - if (typeId->persistent) - return typeId; - - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypeId& res = seenTypes[typeId]; - - if (res == nullptr) - { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; - Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } - - return res; -} - -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) -{ - TypeFun result; - - for (auto param : typeFun.typeParams) - { - TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); - std::optional defaultValue; - - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); - - result.typeParams.push_back({ty, defaultValue}); - } - - for (auto param : typeFun.typePackParams) - { - TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); - std::optional defaultValue; - - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); - - result.typePackParams.push_back({tp, defaultValue}); - } - - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); - - return result; -} - ScopePtr Module::getModuleScope() const { LUAU_ASSERT(!scopes.empty()); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 876f5f0..5fbb596 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false) + namespace Luau { @@ -78,11 +80,32 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { - for (auto& [ty, rep] : typeVarChanges) - *asMutable(ty) = rep.get()->pending; + if (FFlag::LuauTxnLogPreserveOwner) + { + for (auto& [ty, rep] : typeVarChanges) + { + TypeArena* owningArena = ty->owningArena; + TypeVar* mtv = asMutable(ty); + *mtv = rep.get()->pending; + mtv->owningArena = owningArena; + } - for (auto& [tp, rep] : typePackChanges) - *asMutable(tp) = rep.get()->pending; + for (auto& [tp, rep] : typePackChanges) + { + TypeArena* owningArena = tp->owningArena; + TypePackVar* mpv = asMutable(tp); + *mpv = rep.get()->pending; + mpv->owningArena = owningArena; + } + } + else + { + for (auto& [ty, rep] : typeVarChanges) + *asMutable(ty) = rep.get()->pending; + + for (auto& [tp, rep] : typePackChanges) + *asMutable(tp) = rep.get()->pending; + } clear(); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6df6bff..1093024 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -22,6 +22,9 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) +LUAU_FASTFLAG(LuauSeparateTypechecks) +LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) +LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) @@ -35,7 +38,7 @@ LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify3, false) +LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -46,6 +49,7 @@ LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) +LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) @@ -53,6 +57,11 @@ LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) namespace Luau { +const char* TimeLimitError::what() const throw() +{ + return "Typeinfer failed to complete in allotted time"; +} + static bool typeCouldHaveMetatable(TypeId ty) { return get(follow(ty)) || get(follow(ty)) || get(follow(ty)); @@ -251,6 +260,12 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona currentModule.reset(new Module()); currentModule->type = module.type; + if (FFlag::LuauSeparateTypechecks) + { + currentModule->allocator = module.allocator; + currentModule->names = module.names; + } + iceHandler->moduleName = module.name; ScopePtr parentScope = environmentScope.value_or(globalScope); @@ -271,7 +286,21 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona if (prepareModuleScope) prepareModuleScope(module.name, currentModule->getModuleScope()); - checkBlock(moduleScope, *module.root); + if (FFlag::LuauSeparateTypechecks) + { + try + { + checkBlock(moduleScope, *module.root); + } + catch (const TimeLimitError&) + { + currentModule->timeout = true; + } + } + else + { + checkBlock(moduleScope, *module.root); + } if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); @@ -366,6 +395,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) } else ice("Unknown AstStat"); + + if (FFlag::LuauSeparateTypechecks && finishTime && TimeTrace::getClock() > *finishTime) + throw TimeLimitError(); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. @@ -1115,22 +1147,18 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify3) + else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify4) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); - if (!ttv) + + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) { - if (isTableIntersection(exprTy)) + if (ttv || isTableIntersection(exprTy)) reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); - else if (!get(exprTy) && !get(exprTy)) + else reportError(TypeError{function.location, OnlyTablesCanHaveMethods{exprTy}}); } - else if (ttv->state == TableState::Sealed) - { - if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) - reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); - } ty = follow(ty); @@ -1153,7 +1181,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify3) + else if (FFlag::LuauStatFunctionSimplify4) { LUAU_ASSERT(function.name->is()); @@ -1163,7 +1191,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else if (function.func->self) { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify3); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify4); AstExprIndexName* indexName = function.name->as(); if (!indexName) @@ -1202,7 +1230,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify3); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify4); TypeId leftType = checkLValueBinding(scope, *function.name); @@ -2030,7 +2058,11 @@ TypeId TypeChecker::checkExprTable( indexer = expectedTable->indexer; if (indexer) + { + if (FFlag::LuauCheckImplicitNumbericKeys) + unify(numberType, indexer->indexType, value->location); unify(valueType, indexer->indexResultType, value->location); + } else indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; } @@ -2984,35 +3016,33 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T else if (auto indexName = funName.as()) { TypeId lhsType = checkExpr(scope, *indexName->expr).type; - if (get(lhsType) || get(lhsType)) + + if (!FFlag::LuauStatFunctionSimplify4 && (get(lhsType) || get(lhsType))) return lhsType; TableTypeVar* ttv = getMutableTableType(lhsType); - if (!ttv) + + if (FFlag::LuauStatFunctionSimplify4) { - if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) - // This error now gets reported when we check the function body. - reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - - return errorRecoveryType(scope); - } - - if (FFlag::LuauStatFunctionSimplify3) - { - if (lhsType->persistent) - return errorRecoveryType(scope); - - // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check - if (ttv->state == TableState::Sealed) + if (!ttv || ttv->state == TableState::Sealed) { - if (ttv->indexer && isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) - return ttv->indexer->indexResultType; - else - return errorRecoveryType(scope); + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) + return *ty; + + return errorRecoveryType(scope); } } else { + if (!ttv) + { + if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) + // This error now gets reported when we check the function body. + reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); + + return errorRecoveryType(scope); + } + if (lhsType->persistent || ttv->state == TableState::Sealed) return errorRecoveryType(scope); } @@ -3020,7 +3050,12 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T Name name = indexName->index.value; if (ttv->props.count(name)) - return errorRecoveryType(scope); + { + if (FFlag::LuauStatFunctionSimplify4) + return ttv->props[name].type; + else + return errorRecoveryType(scope); + } Property& property = ttv->props[name]; @@ -4155,6 +4190,20 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return anyType; } + // Types of requires that transitively refer to current module have to be replaced with 'any' + std::string humanReadableName; + + if (FFlag::LuauCyclicModuleTypeSurface) + { + humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + + for (const auto& [location, path] : requireCycles) + { + if (!path.empty() && path.front() == humanReadableName) + return anyType; + } + } + ModulePtr module = resolver->getModule(moduleInfo.name); if (!module) { @@ -4163,8 +4212,15 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // we will already have reported the error. if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) { - std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(TypeError{location, UnknownRequire{reportedModulePath}}); + if (FFlag::LuauCyclicModuleTypeSurface) + { + reportError(TypeError{location, UnknownRequire{humanReadableName}}); + } + else + { + std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(TypeError{location, UnknownRequire{reportedModulePath}}); + } } return errorRecoveryType(scope); @@ -4172,8 +4228,15 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (module->type != SourceCode::Module) { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + if (FFlag::LuauCyclicModuleTypeSurface) + { + reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + } + else + { + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + } return errorRecoveryType(scope); } @@ -4185,8 +4248,15 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module std::optional moduleType = first(modulePack); if (!moduleType) { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + if (FFlag::LuauCyclicModuleTypeSurface) + { + reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + } + else + { + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + } return errorRecoveryType(scope); } @@ -4629,7 +4699,9 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { - // TODO: cache singleton types + if (FFlag::LuauAutocompleteSingletonTypes) + return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value}))); } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 36545ad..dbc412f 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -652,6 +652,8 @@ static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persist static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}; static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}; static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; +static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; +static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; static TypeVar anyType_{AnyTypeVar{}}; static TypeVar errorType_{ErrorTypeVar{}}; static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; @@ -665,6 +667,8 @@ SingletonTypes::SingletonTypes() , stringType(&stringType_) , booleanType(&booleanType_) , threadType(&threadType_) + , trueType(&trueType_) + , falseType(&falseType_) , anyType(&anyType_) , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 5018456..9f7b2bd 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -9,14 +9,21 @@ LUAU_FASTFLAG(DebugLuauTimeTracing) +namespace Luau +{ +namespace TimeTrace +{ +double getClock(); +uint32_t getClockMicroseconds(); +} // namespace TimeTrace +} // namespace Luau + #if defined(LUAU_ENABLE_TIME_TRACE) namespace Luau { namespace TimeTrace { -uint32_t getClockMicroseconds(); - struct Token { const char* name; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f6dfd90..f9d3217 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,6 +10,7 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false) namespace Luau { @@ -1420,6 +1421,11 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } + else if (FFlag::LuauParseRecoverUnexpectedPack && c == Lexeme::Dot3) + { + report(lexer.current().location, "Unexpected '...' after type annotation"); + nextLexeme(); + } else break; } @@ -1536,6 +1542,11 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) prefix = name.name; name = parseIndexName("field name", pointPosition); } + else if (FFlag::LuauParseRecoverUnexpectedPack && lexer.current().type == Lexeme::Dot3) + { + report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context"); + nextLexeme(); + } else if (name.name == "typeof") { Lexeme typeofBegin = lexer.current(); diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index 19564f0..e380768 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -26,9 +26,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) - -#if defined(LUAU_ENABLE_TIME_TRACE) - namespace Luau { namespace TimeTrace @@ -67,6 +64,14 @@ static double getClockTimestamp() #endif } +double getClock() +{ + static double period = getClockPeriod(); + static double start = getClockTimestamp(); + + return (getClockTimestamp() - start) * period; +} + uint32_t getClockMicroseconds() { static double period = getClockPeriod() * 1e6; @@ -74,7 +79,15 @@ uint32_t getClockMicroseconds() return uint32_t((getClockTimestamp() - start) * period); } +} // namespace TimeTrace +} // namespace Luau +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ struct GlobalContext { GlobalContext() = default; diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 60a7c16..35ea0bf 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -290,7 +290,8 @@ struct ConstantVisitor : AstVisitor Constant la = analyze(expr->left); Constant ra = analyze(expr->right); - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + // note: ra doesn't need to be constant to fold and/or + if (la.type != Constant::Type_Unknown) foldBinary(result, expr->op, la, ra); } else if (AstExprTypeAssertion* expr = node->as()) diff --git a/Sources.cmake b/Sources.cmake index 59b3849..6f110f1 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -47,6 +47,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Autocomplete.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Config.h + Analysis/include/Luau/Clone.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h @@ -85,6 +86,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Autocomplete.cpp Analysis/src/BuiltinDefinitions.cpp Analysis/src/Config.cpp + Analysis/src/Clone.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp Analysis/src/IostreamHelpers.cpp diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 8047ceb..a260d00 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -14,6 +14,6 @@ LUAI_FUNC UpVal* luaF_findupval(lua_State* L, StkId level); LUAI_FUNC void luaF_close(lua_State* L, StkId level); LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f, struct lua_Page* page); LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page); -void luaF_unlinkupval(UpVal* uv); +LUAI_FUNC void luaF_unlinkupval(UpVal* uv); LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page); LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 241a99e..41887f4 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -201,7 +201,7 @@ static int tmove(lua_State* L) void (*telemetrycb)(lua_State* L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; - if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb) + if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f) { int nf = lua_objlen(L, 1); int nt = lua_objlen(L, tt); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 4e8a1d5..2e7902f 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) +LUAU_FASTFLAG(LuauSeparateTypechecks) using namespace Luau; @@ -25,6 +26,11 @@ static std::optional nullCallback(std::string tag, std::op template struct ACFixtureImpl : BaseType { + ACFixtureImpl() + : Fixture(true, true) + { + } + AutocompleteResult autocomplete(unsigned row, unsigned column) { return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); @@ -72,7 +78,25 @@ struct ACFixtureImpl : BaseType } LUAU_ASSERT("Digit expected after @ symbol" && prevChar != '@'); - return Fixture::check(filteredSource); + return BaseType::check(filteredSource); + } + + LoadDefinitionFileResult loadDefinition(const std::string& source) + { + if (FFlag::LuauSeparateTypechecks) + { + TypeChecker& typeChecker = this->frontend.typeCheckerForAutocomplete; + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); + return result; + } + else + { + return BaseType::loadDefinition(source); + } } const Position& getPosition(char marker) const @@ -2496,7 +2520,7 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(ACFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { @@ -2504,13 +2528,11 @@ TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") } )"); - fileResolver.source["Module/A"] = R"( - local a = y. - )"; + check(R"( + local a = y.@1 + )"); - frontend.check("Module/A"); - - auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); @@ -2736,6 +2758,107 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") CHECK(ac.entryMap.count("format")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") +{ + ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; + ScopedFastFlag luauExpectedTypesOfProperties{"LuauExpectedTypesOfProperties", true}; + + check(R"( + type tag = "cat" | "dog" + local function f(a: tag) end + f("@1") + f(@2) + local x: tag = "@3" + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("\"cat\"")); + CHECK(ac.entryMap.count("\"dog\"")); + + ac = autocomplete('3'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + check(R"( + type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} + local x: tagged = {tag="@4"} + )"); + + ac = autocomplete('4'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") +{ + ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; + + check(R"( + type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} + local x: tagged = {tag="cat", fieldx=2} + if x.tag == "@1" or "@2" ~= x.tag then end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("cat")); + CHECK(ac.entryMap.count("dog")); + + // CLI-48823: assignment to x.tag should also autocomplete, but union l-values are not supported yet +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_boolean_singleton") +{ + ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; + + check(R"( +local function f(x: true) end +f(@1) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("true")); + CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); + REQUIRE(ac.entryMap.count("false")); + CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") +{ + ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; + + check(R"( + type tag = "strange\t\"cat\"" | 'nice\t"dog"' + local function f(x: tag) end + f(@1) + f("@2") + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("\"strange\\t\\\"cat\\\"\"")); + CHECK(ac.entryMap.count("\"nice\\t\\\"dog\\\"\"")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("strange\\t\\\"cat\\\"")); + CHECK(ac.entryMap.count("nice\\t\\\"dog\\\"")); +} + TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") { check(R"( diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 3dc57da..83dad72 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1074,6 +1074,39 @@ RETURN R1 1 )"); } +TEST_CASE("AndOrFoldLeft") +{ + // constant folding and/or expression is possible even if just the left hand is constant + CHECK_EQ("\n" + compileFunction0("local a = false if a and b then b() end"), R"( +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if a or b then b() end"), R"( +GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 +)"); + + // however, if right hand side is constant we can't constant fold the entire expression + // (note that we don't need to evaluate the right hand side, but we do need a branch) + CHECK_EQ("\n" + compileFunction0("local a = false if b and a then b() end"), R"( +GETIMPORT R0 1 +JUMPIFNOT R0 +4 +RETURN R0 0 +GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = true if b or a then b() end"), R"( +GETIMPORT R0 1 +JUMPIF R0 +0 +GETIMPORT R0 1 +CALL R0 0 0 +RETURN R0 0 +)"); +} + TEST_CASE("AndOrChainCodegen") { const char* source = R"( diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index a7e7ea3..9dc9fee 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -83,7 +83,7 @@ std::optional TestFileResolver::getEnvironmentForModule(const Modul return std::nullopt; } -Fixture::Fixture(bool freeze) +Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true}) , typeChecker(frontend.typeChecker) @@ -93,8 +93,11 @@ Fixture::Fixture(bool freeze) configResolver.defaultConfig.parseOptions.captureComments = true; registerBuiltinTypes(frontend.typeChecker); + if (prepareAutocomplete) + registerBuiltinTypes(frontend.typeCheckerForAutocomplete); registerTestTypes(); Luau::freeze(frontend.typeChecker.globalTypes); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); Luau::setPrintLine([](auto s) {}); } diff --git a/tests/Fixture.h b/tests/Fixture.h index 4e45a95..0d1233b 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -91,7 +91,7 @@ struct TestConfigResolver : ConfigResolver struct Fixture { - explicit Fixture(bool freeze = true); + explicit Fixture(bool freeze = true, bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 8a59acd..9fc0a00 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -384,6 +384,70 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_error_paths") CHECK_EQ(ce2->cycle[1], "game/Gui/Modules/A"); } +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") +{ + ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true}; + + fileResolver.source["game/A"] = R"( + return {hello = 2} + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/A"] = R"( + local me = require(game.A) + return {hello = 2} + )"; + frontend.markDirty("game/A"); + + result = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(result); + + auto ty = requireType("game/A", "me"); + CHECK_EQ(toString(ty), "any"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") +{ + ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true}; + + fileResolver.source["game/A"] = R"( + return {mod_a = 2} + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/B"] = R"( + local me = require(game.A) + return {mod_b = 4} + )"; + + result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["game/A"] = R"( + local me = require(game.B) + return {mod_a_prime = 3} + )"; + + frontend.markDirty("game/A"); + frontend.markDirty("game/B"); + + result = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(result); + + TypeId tyA = requireType("game/A", "me"); + CHECK_EQ(toString(tyA), "any"); + + result = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(result); + + TypeId tyB = requireType("game/B", "me"); + CHECK_EQ(toString(tyB), "any"); +} + TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") { fileResolver.source["Modules/A"] = R"( diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 82b7a35..de06312 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Clone.h" #include "Luau/Module.h" #include "Luau/Scope.h" diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 7dacc66..79f9eca 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2022,6 +2022,15 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); } +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors") +{ + ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true}; + + matchParseError("type Y = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context", + Location{{0, 20}, {0, 23}}); + matchParseError("type Y = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}}); +} + TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") { { @@ -2590,4 +2599,16 @@ type Y = (T...) -> U... CHECK_EQ(1, result.errors.size()); } +TEST_CASE_FIXTURE(Fixture, "recover_unexpected_type_pack") +{ + ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true}; + + ParseResult result = tryParse(R"( +type X = { a: T..., b: number } +type Y = { a: T..., b: number } +type Z = { a: string | T..., b: number } + )"); + REQUIRE_EQ(3, result.errors.size()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index dbae7b5..1713216 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1270,7 +1270,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; fileResolver.source["game/isAMagicMock"] = R"( --!nonstrict @@ -1294,7 +1294,7 @@ end TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; CheckResult result = check(R"( function string.len(): number @@ -1316,7 +1316,7 @@ print(string.len('hello')) TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite_2") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; CheckResult result = check(R"( @@ -1324,12 +1324,12 @@ local t: { f: ((x: number) -> number)? } = {} function t.f(x) print(x + 5) - return x .. "asd" + return x .. "asd" -- 1st error: we know that return type is a number, not a string end t.f = function(x) print(x + 5) - return x .. "asd" + return x .. "asd" -- 2nd error: we know that return type is a number, not a string end )"); @@ -1338,6 +1338,33 @@ end CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } +TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_unsealed_overwrite") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; + ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; + + CheckResult result = check(R"( +local t = { f = nil :: ((x: number) -> number)? } + +function t.f(x: string): string -- 1st error: new function value type is incompatible + return x .. "asd" +end + +t.f = function(x) + print(x + 5) + return x .. "asd" -- 2nd error: we know that return type is a number, not a string +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string) -> string' could not be converted into '((number) -> number)?' +caused by: + None of the union options are compatible. For example: Type '(string) -> string' could not be converted into '(number) -> number' +caused by: + Argument #1 type is not compatible. Type 'number' could not be converted into 'string')"); + CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); +} + TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") { ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; @@ -1352,7 +1379,7 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify3", true}; + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; CheckResult result = check(R"( local t: {[string]: () -> number} = {} diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index d146f4e..ac7a653 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -311,6 +311,8 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; + CheckResult result = check(R"( type X = { x: (number) -> number } type Y = { y: (string) -> string } @@ -326,10 +328,39 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") function xy:w(a:number) return a * 10 end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'y' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[2]), "Cannot add property 'w' to table 'X & Y'"); + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); +} + +TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; + + // After normalization, previous 'table_intersection_write_sealed_indirect' is identical to this one + CheckResult result = check(R"( + type XY = { x: (number) -> number, y: (string) -> string } + + local xy : XY = { + x = function(a: number) return -a end, + y = function(a: string) return a .. "b" end + } + function xy.z(a:number) return a * 10 end + function xy:y(a:number) return a * 10 end + function xy:w(a:number) return a * 10 end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'XY'"); + CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'XY'"); } TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 44b7b0d..3ddf981 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -95,6 +95,8 @@ end )"); LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Cannot add method to non-table type 'number'"); + CHECK_EQ(toString(result.errors[1]), "Type 'number' could not be converted into 'string'"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 0cc12d1..0484351 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2922,4 +2922,19 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") +{ + ScopedFastFlag sff{"LuauCheckImplicitNumbericKeys", true}; + + CheckResult result = check(R"( + local t: { [string]: number } = { 5, 6, 7 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index d8de259..c21e162 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -242,4 +242,30 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") state.tryUnify(&func, typeChecker.anyType); } +TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") +{ + ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true}; + + TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}}); + TypeId b = typeChecker.numberType; + + state.tryUnify(a, b); + state.log.commit(); + + CHECK_EQ(a->owningArena, &arena); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") +{ + ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true}; + + TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + TypePackId b = typeChecker.anyTypePack; + + state.tryUnify(a, b); + state.log.commit(); + + CHECK_EQ(a->owningArena, &arena); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 68b7c4f..ff207a1 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -513,4 +513,29 @@ TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_union_write_indirect") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; + + CheckResult result = check(R"( + type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } + + local a:A = nil + + function a.y(x) + return tostring(x * 2) + end + + function a.y(x: string): number + return tonumber(x) or 0 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + // NOTE: union normalization will improve this message + CHECK_EQ(toString(result.errors[0]), + R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); +} + + TEST_SUITE_END(); diff --git a/tools/lldb_formatters.py b/tools/lldb_formatters.py index 40f8d6b..b3d2b4f 100644 --- a/tools/lldb_formatters.py +++ b/tools/lldb_formatters.py @@ -37,7 +37,7 @@ def getType(target, typeName): return ty def luau_variant_summary(valobj, internal_dict, options): - type_id = valobj.GetChildMemberWithName("typeid").GetValueAsUnsigned() + type_id = valobj.GetChildMemberWithName("typeId").GetValueAsUnsigned() storage = valobj.GetChildMemberWithName("storage") params = templateParams(valobj.GetType().GetCanonicalType().GetName()) stored_type = params[type_id] @@ -89,7 +89,7 @@ class LuauVariantSyntheticChildrenProvider: return None def update(self): - self.type_index = self.valobj.GetChildMemberWithName("typeid").GetValueAsSigned() + self.type_index = self.valobj.GetChildMemberWithName("typeId").GetValueAsSigned() self.type_params = templateParams(self.valobj.GetType().GetCanonicalType().GetName()) if len(self.type_params) > self.type_index: From 02ed5373ecea3bdf8a68fd40e43733d43a9fbed6 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 14 Apr 2022 14:57:15 -0700 Subject: [PATCH 02/19] Sync to upstream/release/523 --- Analysis/include/Luau/Clone.h | 9 +- Analysis/include/Luau/Error.h | 10 +- Analysis/include/Luau/Frontend.h | 1 + Analysis/include/Luau/LValue.h | 4 + Analysis/include/Luau/Module.h | 2 +- Analysis/include/Luau/Normalize.h | 19 + Analysis/include/Luau/RecursionCounter.h | 26 +- Analysis/include/Luau/Substitution.h | 1 + Analysis/include/Luau/ToString.h | 3 + Analysis/include/Luau/TxnLog.h | 32 +- Analysis/include/Luau/TypeInfer.h | 24 +- Analysis/include/Luau/TypePack.h | 14 +- Analysis/include/Luau/TypeVar.h | 31 +- Analysis/include/Luau/Unifiable.h | 17 +- Analysis/include/Luau/Unifier.h | 22 +- Analysis/include/Luau/UnifierSharedState.h | 2 + Analysis/include/Luau/VisitTypeVar.h | 9 + Analysis/src/Autocomplete.cpp | 8 +- Analysis/src/Clone.cpp | 116 ++- Analysis/src/Error.cpp | 28 +- Analysis/src/Frontend.cpp | 41 +- Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/JsonEncoder.cpp | 21 +- Analysis/src/LValue.cpp | 17 + Analysis/src/Linter.cpp | 93 +- Analysis/src/Module.cpp | 39 +- Analysis/src/Normalize.cpp | 814 +++++++++++++++++ Analysis/src/Quantify.cpp | 36 + Analysis/src/Substitution.cpp | 138 ++- Analysis/src/ToDot.cpp | 31 + Analysis/src/ToString.cpp | 165 +++- Analysis/src/TopoSortStatements.cpp | 1 + Analysis/src/TxnLog.cpp | 44 +- Analysis/src/TypeAttach.cpp | 13 + Analysis/src/TypeInfer.cpp | 488 +++++++++-- Analysis/src/TypePack.cpp | 46 +- Analysis/src/TypeVar.cpp | 28 +- Analysis/src/Unifier.cpp | 337 +++++-- Ast/include/Luau/DenseHash.h | 118 ++- Ast/include/Luau/Lexer.h | 2 +- Ast/src/Lexer.cpp | 10 +- Ast/src/Parser.cpp | 5 +- Compiler/src/Compiler.cpp | 4 +- Compiler/src/CostModel.cpp | 258 ++++++ Compiler/src/CostModel.h | 18 + Sources.cmake | 6 + VM/src/ltable.cpp | 47 +- VM/src/ltablib.cpp | 2 +- VM/src/lvmexecute.cpp | 4 +- tests/CostModel.test.cpp | 101 +++ tests/Fixture.cpp | 4 +- tests/JsonEncoder.test.cpp | 23 + tests/Linter.test.cpp | 3 +- tests/Module.test.cpp | 75 +- tests/NonstrictMode.test.cpp | 34 + tests/Normalize.test.cpp | 967 +++++++++++++++++++++ tests/Parser.test.cpp | 20 + tests/ToDot.test.cpp | 77 +- tests/ToString.test.cpp | 2 + tests/TopoSort.test.cpp | 32 +- tests/Transpiler.test.cpp | 2 +- tests/TypeInfer.annotations.test.cpp | 10 + tests/TypeInfer.builtins.test.cpp | 13 +- tests/TypeInfer.classes.test.cpp | 3 + tests/TypeInfer.functions.test.cpp | 177 +++- tests/TypeInfer.generics.test.cpp | 77 +- tests/TypeInfer.intersectionTypes.test.cpp | 46 +- tests/TypeInfer.oop.test.cpp | 16 +- tests/TypeInfer.operators.test.cpp | 3 +- tests/TypeInfer.provisional.test.cpp | 131 ++- tests/TypeInfer.refinements.test.cpp | 12 +- tests/TypeInfer.singletons.test.cpp | 11 +- tests/TypeInfer.tables.test.cpp | 61 +- tests/TypeInfer.test.cpp | 66 +- tests/TypeInfer.typePacks.cpp | 38 +- tests/TypeInfer.unionTypes.test.cpp | 25 +- tests/conformance/nextvar.lua | 15 + 77 files changed, 4597 insertions(+), 653 deletions(-) create mode 100644 Analysis/include/Luau/Normalize.h create mode 100644 Analysis/src/Normalize.cpp create mode 100644 Compiler/src/CostModel.cpp create mode 100644 Compiler/src/CostModel.h create mode 100644 tests/CostModel.test.cpp create mode 100644 tests/Normalize.test.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 917ef80..78aa92c 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -14,12 +14,15 @@ using SeenTypePacks = std::unordered_map; struct CloneState { + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + int recursionCount = 0; bool encounteredFreeType = false; }; -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 53b946a..7068314 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -287,12 +287,20 @@ struct TypesAreUnrelated bool operator==(const TypesAreUnrelated& rhs) const; }; +struct NormalizationTooComplex +{ + bool operator==(const NormalizationTooComplex&) const + { + return true; + } +}; + using TypeErrorData = Variant; + MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, TypesAreUnrelated, NormalizationTooComplex>; struct TypeError { diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 2266f54..e24e433 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -70,6 +70,7 @@ struct SourceNode std::vector> requireLocations; bool dirty = true; bool dirtyAutocomplete = true; + double autocompleteLimitsMult = 1.0; }; struct FrontendOptions diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 3d510d5..afb7141 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -35,8 +35,12 @@ const LValue* baseof(const LValue& lvalue); std::optional tryGetLValue(const class AstExpr& expr); // Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. +// TODO: remove with FFlagLuauTypecheckOptPass std::pair> getFullName(const LValue& lvalue); +// Utility function: breaks down an LValue to get at the Symbol +Symbol getBaseSymbol(const LValue& lvalue); + template const T* get(const LValue& lvalue) { diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 9a32f61..0dd4418 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -113,7 +113,7 @@ struct Module // This helps us to force TypeVar ownership into a DAG rather than a DCG. // Returns true if there were any free types encountered in the public interface. This // indicates a bug in the type checker that we want to surface. - bool clonePublicInterface(); + bool clonePublicInterface(InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h new file mode 100644 index 0000000..262b54b --- /dev/null +++ b/Analysis/include/Luau/Normalize.h @@ -0,0 +1,19 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" +#include "Luau/Module.h" + +namespace Luau +{ + +struct InternalErrorReporter; + +bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice); + +std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, TypeArena& arena, InternalErrorReporter& ice); +std::pair normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice); + +} // namespace Luau diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 89632ce..03ae2c8 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -4,10 +4,21 @@ #include "Luau/Common.h" #include +#include + +LUAU_FASTFLAG(LuauRecursionLimitException); namespace Luau { +struct RecursionLimitException : public std::exception +{ + const char* what() const noexcept + { + return "Internal recursion counter limit exceeded"; + } +}; + struct RecursionCounter { RecursionCounter(int* count) @@ -28,11 +39,22 @@ private: struct RecursionLimiter : RecursionCounter { - RecursionLimiter(int* count, int limit) + // TODO: remove ctx after LuauRecursionLimitException is removed + RecursionLimiter(int* count, int limit, const char* ctx) : RecursionCounter(count) { + LUAU_ASSERT(ctx); if (limit > 0 && *count > limit) - throw std::runtime_error("Internal recursion counter limit exceeded"); + { + if (FFlag::LuauRecursionLimitException) + throw RecursionLimitException(); + else + { + std::string m = "Internal recursion counter limit exceeded: "; + m += ctx; + throw std::runtime_error(m); + } + } } }; diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 9662d5b..6f5931e 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -90,6 +90,7 @@ struct Tarjan std::vector lowlink; int childCount = 0; + int childLimit = 0; // This should never be null; ensure you initialize it before calling // substitution methods. diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 49ee82f..f4db5e3 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -28,6 +28,7 @@ struct ToStringOptions bool functionTypeArguments = false; // If true, output function type argument names when they are available bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. + bool indent = false; size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -73,6 +74,8 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp std::string dump(TypeId ty); std::string dump(TypePackId ty); +std::string dump(const std::shared_ptr& scope, const char* name); + std::string generateName(size_t n); } // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index c8ebaae..995ed6c 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -7,7 +7,7 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauShareTxnSeen); +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -64,13 +64,17 @@ T* getMutable(PendingTypePack* pending) struct TxnLog { TxnLog() - : ownedSeen() + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , ownedSeen() , sharedSeen(&ownedSeen) { } explicit TxnLog(TxnLog* parent) - : parent(parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , parent(parent) { if (parent) { @@ -83,14 +87,19 @@ struct TxnLog } explicit TxnLog(std::vector>* sharedSeen) - : sharedSeen(sharedSeen) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , sharedSeen(sharedSeen) { } TxnLog(TxnLog* parent, std::vector>* sharedSeen) - : parent(parent) + : typeVarChanges(nullptr) + , typePackChanges(nullptr) + , parent(parent) , sharedSeen(sharedSeen) { + LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); } TxnLog(const TxnLog&) = delete; @@ -243,6 +252,12 @@ struct TxnLog return Luau::getMutable(ty); } + template + const T* get(TID ty) const + { + return this->getMutable(ty); + } + // Returns whether a given type or type pack is a given state, respecting the // log's pending state. // @@ -263,11 +278,8 @@ private: // unique_ptr is used to give us stable pointers across insertions into the // map. Otherwise, it would be really easy to accidentally invalidate the // pointers returned from queue/pending. - // - // We can't use a DenseHashMap here because we need a non-const iterator - // over the map when we concatenate. - std::unordered_map, DenseHashPointer> typeVarChanges; - std::unordered_map, DenseHashPointer> typePackChanges; + DenseHashMap> typeVarChanges; + DenseHashMap> typePackChanges; TxnLog* parent = nullptr; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 215da67..ac88013 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -76,19 +76,32 @@ struct Instantiation : Substitution // A substitution which replaces free types by any struct Anyification : Substitution { - Anyification(TypeArena* arena, TypeId anyType, TypePackId anyTypePack) + Anyification(TypeArena* arena, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack) : Substitution(TxnLog::empty(), arena) + , iceHandler(iceHandler) , anyType(anyType) , anyTypePack(anyTypePack) { } + InternalErrorReporter* iceHandler; + TypeId anyType; TypePackId anyTypePack; + bool normalizationTooComplex = false; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; TypeId clean(TypeId ty) override; TypePackId clean(TypePackId tp) override; + + bool ignoreChildren(TypeId ty) override + { + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } }; // A substitution which replaces the type parameters of a type function by arguments @@ -139,6 +152,7 @@ struct TypeChecker TypeChecker& operator=(const TypeChecker&) = delete; ModulePtr check(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); + ModulePtr checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); std::vector> getScopes() const; @@ -160,6 +174,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); ExprResult checkExpr( @@ -172,6 +187,7 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); @@ -258,6 +274,8 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location); + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -395,6 +413,7 @@ private: void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; + bool useConstrainedIntersections() const; public: /** Extract the types in a type pack, given the assumption that the pack must have some exact length. @@ -421,7 +440,10 @@ public: std::vector requireCycles; + // Type inference limits std::optional finishTime; + std::optional instantiationChildLimit; + std::optional unifierIterationLimit; public: const TypeId nilType; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 85fa467..bbc65f9 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -40,6 +40,7 @@ struct TypePack struct VariadicTypePack { TypeId ty; + bool hidden = false; // if true, we don't display this when toString()ing a pack with this variadic as its tail. }; struct TypePackVar @@ -109,10 +110,10 @@ private: }; TypePackIterator begin(TypePackId tp); -TypePackIterator begin(TypePackId tp, TxnLog* log); +TypePackIterator begin(TypePackId tp, const TxnLog* log); TypePackIterator end(TypePackId tp); -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); @@ -122,7 +123,7 @@ TypePackId follow(TypePackId tp, std::function mapper); size_t size(TypePackId tp, TxnLog* log = nullptr); bool finite(TypePackId tp, TxnLog* log = nullptr); size_t size(const TypePack& tp, TxnLog* log = nullptr); -std::optional first(TypePackId tp); +std::optional first(TypePackId tp, bool ignoreHiddenVariadics = true); TypePackVar* asMutable(TypePackId tp); TypePack* asMutable(const TypePack* tp); @@ -154,5 +155,12 @@ bool isEmpty(TypePackId tp); /// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known std::pair, std::optional> flatten(TypePackId tp); +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log); + +/// Returs true if the type pack arose from a function that is declared to be variadic. +/// Returns *false* for function argument packs that are inferred to be safe to oversaturate! +bool isVariadic(TypePackId tp); +bool isVariadic(TypePackId tp, const TxnLog& log); + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index f61e404..ae7d137 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -109,6 +109,23 @@ struct PrimitiveTypeVar } }; +struct ConstrainedTypeVar +{ + explicit ConstrainedTypeVar(TypeLevel level) + : level(level) + { + } + + explicit ConstrainedTypeVar(TypeLevel level, const std::vector& parts) + : parts(parts) + , level(level) + { + } + + std::vector parts; + TypeLevel level; +}; + // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false struct BooleanSingleton @@ -248,6 +265,7 @@ struct FunctionTypeVar MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. bool hasSelf; Tags tags; + bool hasNoGenerics = false; }; enum class TableState @@ -418,8 +436,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -436,6 +454,7 @@ struct TypeVar final TypeVar(const TypeVariant& ty, bool persistent) : ty(ty) , persistent(persistent) + , normal(persistent) // We assume that all persistent types are irreducable. { } @@ -446,6 +465,10 @@ struct TypeVar final // Persistent TypeVars do not get cloned. bool persistent = false; + // Normalization sets this for types that are fully normalized. + // This implies that they are transitively immutable. + bool normal = false; + std::optional documentationSymbol; // Pointer to the type arena that allocated this type. @@ -458,7 +481,7 @@ struct TypeVar final TypeVar& operator=(TypeVariant&& rhs); }; -using SeenSet = std::set>; +using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); // Follow BoundTypeVars until we get to something real @@ -545,6 +568,8 @@ void persist(TypePackId tp); const TypeLevel* getLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty); +std::optional getLevel(TypePackId tp); + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index e8eafe6..64fa131 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -56,6 +56,14 @@ struct TypeLevel } }; +inline TypeLevel max(const TypeLevel& a, const TypeLevel& b) +{ + if (a.subsumes(b)) + return b; + else + return a; +} + inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) { if (a.subsumes(b)) @@ -64,7 +72,9 @@ inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) return b; } -namespace Unifiable +} // namespace Luau + +namespace Luau::Unifiable { using Name = std::string; @@ -125,7 +135,6 @@ private: }; template -using Variant = Variant, Generic, Error, Value...>; +using Variant = Luau::Variant, Generic, Error, Value...>; -} // namespace Unifiable -} // namespace Luau +} // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 474af50..340feb7 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -49,14 +49,14 @@ struct Unifier ErrorVec errors; Location location; Variance variance = Covariant; + bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, + UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -106,7 +106,12 @@ private: std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); + void tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy); + void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); + public: + void unifyLowerBound(TypePackId subTy, TypePackId superTy); + // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); void occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); @@ -115,12 +120,7 @@ public: Unifier makeChildUnifier(); - // A utility function that appends the given error to the unifier's error log. - // This allows setting a breakpoint wherever the unifier reports an error. - void reportError(TypeError error) - { - errors.push_back(error); - } + void reportError(TypeError err); private: bool isNonstrictMode() const; @@ -135,4 +135,6 @@ private: std::optional firstPackErrorPos; }; +void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, TypePackId tp); + } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index 9a3ba56..1a0b8b7 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -28,7 +28,9 @@ struct TypeIdPairHash struct UnifierCounters { int recursionCount = 0; + int recursionLimit = 0; int iterationCount = 0; + int iterationLimit = 0; }; struct UnifierSharedState diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 740854b..d11cbd0 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -82,6 +82,15 @@ void visit(TypeId ty, F& f, Set& seen) else if (auto etv = get(ty)) apply(ty, *etv, seen, f); + else if (auto ctv = get(ty)) + { + if (apply(ty, *ctv, seen, f)) + { + for (TypeId part : ctv->parts) + visit(part, f, seen); + } + } + else if (auto ptv = get(ty)) apply(ty, *ptv, seen, f); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index b7201ab..e0e79cb 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -151,8 +151,12 @@ static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTyp auto idxExpr = nodes.back()->as(); bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; - auto args = Luau::flatten(func->argTypes); - bool noArgFunction = (args.first.empty() || (hasImplicitSelf && args.first.size() == 1)) && !args.second.has_value(); + auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); + + if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) + return ParenthesesRecommendation::CursorInside; + + bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index ac9705a..8e7f7c0 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -6,7 +6,10 @@ #include "Luau/TypePack.h" #include "Luau/Unifiable.h" +LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) + LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -23,11 +26,11 @@ struct TypePackCloner; struct TypeCloner { - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState) : dest(dest) , typeId(typeId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) , cloneState(cloneState) { } @@ -46,6 +49,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const PrimitiveTypeVar& t); + void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); @@ -65,11 +69,11 @@ struct TypePackCloner SeenTypePacks& seenTypePacks; CloneState& cloneState; - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) + TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState) : dest(dest) , typePackId(typePackId) - , seenTypes(seenTypes) - , seenTypePacks(seenTypePacks) + , seenTypes(cloneState.seenTypes) + , seenTypePacks(cloneState.seenTypePacks) , cloneState(cloneState) { } @@ -103,13 +107,15 @@ struct TypePackCloner // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. void operator()(const Unifiable::Bound& t) { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypePackId cloned = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + cloned = dest.addTypePack(TypePackVar{BoundTypePack{cloned}}); seenTypePacks[typePackId] = cloned; } void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}}); seenTypePacks[typePackId] = cloned; } @@ -121,10 +127,10 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + destTp->head.push_back(clone(ty, dest, cloneState)); if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); + destTp->tail = clone(*t.tail, dest, cloneState); } }; @@ -150,7 +156,9 @@ void TypeCloner::operator()(const Unifiable::Generic& t) void TypeCloner::operator()(const Unifiable::Bound& t) { - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypeId boundTo = clone(t.boundTo, dest, cloneState); + if (FFlag::DebugLuauCopyBeforeNormalizing) + boundTo = dest.addType(BoundTypeVar{boundTo}); seenTypes[typeId] = boundTo; } @@ -164,6 +172,23 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const ConstrainedTypeVar& t) +{ + cloneState.encounteredFreeType = true; + + TypeId res = dest.addType(ConstrainedTypeVar{t.level}); + ConstrainedTypeVar* ctv = getMutable(res); + LUAU_ASSERT(ctv); + + seenTypes[typeId] = res; + + std::vector parts; + for (TypeId part : t.parts) + parts.push_back(clone(part, dest, cloneState)); + + ctv->parts = std::move(parts); +} + void TypeCloner::operator()(const SingletonTypeVar& t) { defaultClone(t); @@ -178,23 +203,26 @@ void TypeCloner::operator()(const FunctionTypeVar& t) seenTypes[typeId] = result; for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); + ftv->generics.push_back(clone(generic, dest, cloneState)); for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); + ftv->genericPacks.push_back(clone(genericPack, dest, cloneState)); ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); + ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); + ftv->retType = clone(t.retType, dest, cloneState); + + if (FFlag::LuauTypecheckOptPass) + ftv->hasNoGenerics = t.hasNoGenerics; } void TypeCloner::operator()(const TableTypeVar& t) { // If table is now bound to another one, we ignore the content of the original - if (t.boundTo) + if (!FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); + TypeId boundTo = clone(*t.boundTo, dest, cloneState); seenTypes[typeId] = boundTo; return; } @@ -209,18 +237,20 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; + if (FFlag::DebugLuauCopyBeforeNormalizing && t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, cloneState); + for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + arg = clone(arg, dest, cloneState); for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); + arg = clone(arg, dest, cloneState); if (ttv->state == TableState::Free) { @@ -240,8 +270,8 @@ void TypeCloner::operator()(const MetatableTypeVar& t) MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); + mtv->table = clone(t.table, dest, cloneState); + mtv->metatable = clone(t.metatable, dest, cloneState); } void TypeCloner::operator()(const ClassTypeVar& t) @@ -252,13 +282,13 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); + ctv->parent = clone(*t.parent, dest, cloneState); if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); + ctv->metatable = clone(*t.metatable, dest, cloneState); } void TypeCloner::operator()(const AnyTypeVar& t) @@ -272,7 +302,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) options.reserve(t.options.size()); for (TypeId ty : t.options) - options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + options.push_back(clone(ty, dest, cloneState)); TypeId result = dest.addType(UnionTypeVar{std::move(options)}); seenTypes[typeId] = result; @@ -287,7 +317,7 @@ void TypeCloner::operator()(const IntersectionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + option->parts.push_back(clone(ty, dest, cloneState)); } void TypeCloner::operator()(const LazyTypeVar& t) @@ -297,36 +327,36 @@ void TypeCloner::operator()(const LazyTypeVar& t) } // anonymous namespace -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { if (tp->persistent) return tp; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypePackId"); - TypePackId& res = seenTypePacks[tp]; + TypePackId& res = cloneState.seenTypePacks[tp]; if (res == nullptr) { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; + TypePackCloner cloner{dest, tp, cloneState}; Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } return res; } -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) { if (typeId->persistent) return typeId; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypeId"); - TypeId& res = seenTypes[typeId]; + TypeId& res = cloneState.seenTypes[typeId]; if (res == nullptr) { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; + TypeCloner cloner{dest, typeId, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. // Persistent types are not being cloned and we get the original type back which might be read-only @@ -337,33 +367,33 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks return res; } -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { TypeFun result; for (auto param : typeFun.typeParams) { - TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + TypeId ty = clone(param.ty, dest, cloneState); std::optional defaultValue; if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + defaultValue = clone(*param.defaultValue, dest, cloneState); result.typeParams.push_back({ty, defaultValue}); } for (auto param : typeFun.typePackParams) { - TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + TypePackId tp = clone(param.tp, dest, cloneState); std::optional defaultValue; if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + defaultValue = clone(*param.defaultValue, dest, cloneState); result.typePackParams.push_back({tp, defaultValue}); } - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); + result.type = clone(typeFun.type, dest, cloneState); return result; } diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 5eb2ea2..cbec0b1 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -8,7 +8,6 @@ #include -LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false); LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false); static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -252,14 +251,7 @@ struct ErrorConverter std::string operator()(const Luau::SyntaxError& e) const { - if (FFlag::BetterDiagnosticCodesInStudio) - { - return e.message; - } - else - { - return "Syntax error: " + e.message; - } + return e.message; } std::string operator()(const Luau::CodeTooComplex&) const @@ -451,6 +443,11 @@ struct ErrorConverter { return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; } + + std::string operator()(const NormalizationTooComplex&) const + { + return "Code is too complex to typecheck! Consider simplifying the code around this area"; + } }; struct InvalidNameChecker @@ -716,14 +713,14 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState) +void copyError(T& e, TypeArena& destArena, CloneState cloneState) { auto clone = [&](auto&& ty) { - return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState); + return ::Luau::clone(ty, destArena, cloneState); }; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks, cloneState); + copyError(e, destArena, cloneState); }; if constexpr (false) @@ -844,18 +841,19 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& e.left = clone(e.left); e.right = clone(e.right); } + else if constexpr (std::is_same_v) + { + } else static_assert(always_false_v, "Non-exhaustive type switch"); } void copyErrors(ErrorVec& errors, TypeArena& destArena) { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks, cloneState); + copyError(e, destArena, cloneState); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 000769f..8b0b221 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -11,16 +11,18 @@ #include "Luau/TimeTrace.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" -#include "Luau/Common.h" #include #include #include +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauCyclicModuleTypeSurface) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) namespace Luau @@ -97,13 +99,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t if (checkedModule->errors.size() > 0) return LoadDefinitionFileResult{false, parseResult, checkedModule}; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); + TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; @@ -113,7 +113,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); + TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -440,13 +440,42 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) + typeCheckerForAutocomplete.instantiationChildLimit = + std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckerForAutocomplete.unifierIterationLimit = + std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckerForAutocomplete.unifierIterationLimit = std::nullopt; + } + ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; + double duration = getTimestamp() - timestamp; + if (moduleForAutocomplete->timeout) + { checkResult.timeoutHits.push_back(moduleName); - stats.timeCheck += getTimestamp() - timestamp; + if (FFlag::LuauAutocompleteDynamicLimits) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + } + else if (FFlag::LuauAutocompleteDynamicLimits && duration < autocompleteTimeLimit / 2.0) + { + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + } + + stats.timeCheck += duration; stats.filesStrict += 1; sourceNode.dirtyAutocomplete = false; diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index a8f6758..0eaa485 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -184,6 +184,8 @@ static void errorToString(std::ostream& stream, const T& err) } else if constexpr (std::is_same_v) stream << "TypesAreUnrelated { left = '" + toString(err.left) + "', right = '" + toString(err.right) + "' }"; + else if constexpr (std::is_same_v) + stream << "NormalizationTooComplex { }"; else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 811e7c2..829ffa0 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -403,35 +403,26 @@ struct AstJsonEncoder : public AstVisitor void write(const AstExprTable::Item& item) { writeRaw("{"); - bool comma = pushComma(); + bool c = pushComma(); write("kind", item.kind); switch (item.kind) { case AstExprTable::Item::List: - write(item.value); + write("value", item.value); break; default: - write(item.key); - writeRaw(","); - write(item.value); + write("key", item.key); + write("value", item.value); break; } - popComma(comma); + popComma(c); writeRaw("}"); } void write(class AstExprTable* node) { writeNode(node, "AstExprTable", [&]() { - bool comma = false; - for (const auto& prop : node->items) - { - if (comma) - writeRaw(","); - else - comma = true; - write(prop); - } + PROP(items); }); } diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp index c9466a4..72555ab 100644 --- a/Analysis/src/LValue.cpp +++ b/Analysis/src/LValue.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAG(LuauTypecheckOptPass) + namespace Luau { @@ -79,6 +81,8 @@ std::optional tryGetLValue(const AstExpr& node) std::pair> getFullName(const LValue& lvalue) { + LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); + const LValue* current = &lvalue; std::vector keys; while (auto field = get(*current)) @@ -92,6 +96,19 @@ std::pair> getFullName(const LValue& lvalue) return {*symbol, std::vector(keys.rbegin(), keys.rend())}; } +Symbol getBaseSymbol(const LValue& lvalue) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOptPass); + + const LValue* current = &lvalue; + while (auto field = get(*current)) + current = baseof(*current); + + const Symbol* symbol = get(*current); + LUAU_ASSERT(symbol); + return *symbol; +} + void merge(RefinementMap& l, const RefinementMap& r, std::function f) { for (const auto& [k, a] : r) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index b7480e3..5608e4b 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,7 +14,6 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) -LUAU_FASTFLAGVARIABLE(LuauLintNoRobloxBits, false) namespace Luau { @@ -1140,25 +1139,8 @@ private: Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. Kind_Vector, // 'vector' but only used when type is used Kind_Userdata, // custom userdata type - - // TODO: remove these with LuauLintNoRobloxBits - Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc. - Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc. }; - bool containsPropName(TypeId ty, const std::string& propName) - { - LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); - - if (auto ctv = get(ty)) - return lookupClassProp(ctv, propName) != nullptr; - - if (auto ttv = get(ty)) - return ttv->props.find(propName) != ttv->props.end(); - - return false; - } - TypeKind getTypeKind(const std::string& name) { if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || @@ -1168,23 +1150,10 @@ private: if (name == "vector") return Kind_Vector; - if (FFlag::LuauLintNoRobloxBits) - { - if (std::optional maybeTy = context->scope->lookupType(name)) - return Kind_Userdata; + if (std::optional maybeTy = context->scope->lookupType(name)) + return Kind_Userdata; - return Kind_Unknown; - } - else - { - if (std::optional maybeTy = context->scope->lookupType(name)) - // Kind_Userdata is probably not 100% precise but is close enough - return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; - else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) - return Kind_Enum; - - return Kind_Unknown; - } + return Kind_Unknown; } void validateType(AstExprConstantString* expr, std::initializer_list expected, const char* expectedString) @@ -1202,67 +1171,11 @@ private: { if (kind == ek) return; - - // as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type - if (!FFlag::LuauLintNoRobloxBits && ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) - return; } emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString); } - bool acceptsClassName(AstName method) - { - LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); - - return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" || - method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA"); - } - - bool visit(AstExprCall* node) override - { - // TODO: Simply remove the override - if (FFlag::LuauLintNoRobloxBits) - return true; - - if (AstExprIndexName* index = node->func->as()) - { - AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as() : NULL; - - if (arg0) - { - if (node->self && index->index == "IsA" && node->args.size == 1) - { - validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type"); - } - else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1) - { - AstExprGlobal* g = index->expr->as(); - - if (g && (g->name == "game" || g->name == "Game")) - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - else if (node->self && acceptsClassName(index->index) && node->args.size == 1) - { - validateType(arg0, {Kind_Class}, "class type"); - } - else if (!node->self && index->index == "new" && node->args.size <= 2) - { - AstExprGlobal* g = index->expr->as(); - - if (g && g->name == "Instance") - { - validateType(arg0, {Kind_Class}, "class type"); - } - } - } - } - - return true; - } - bool visit(AstExprBinary* node) override { if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 6bb4524..e2e3b43 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,8 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" -#include "Luau/Common.h" #include "Luau/Clone.h" +#include "Luau/Common.h" +#include "Luau/Normalize.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -14,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) namespace Luau { @@ -143,32 +145,51 @@ Module::~Module() unfreeze(internalTypes); } -bool Module::clonePublicInterface() +bool Module::clonePublicInterface(InternalErrorReporter& ice) { LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; ScopePtr moduleScope = getModuleScope(); - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState); + moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, cloneState); if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); + moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, cloneState); + + if (FFlag::LuauLowerBoundsCalculation) + { + normalize(moduleScope->returnType, interfaceTypes, ice); + if (moduleScope->varargPack) + normalize(*moduleScope->varargPack, interfaceTypes, ice); + } for (auto& [name, tf] : moduleScope->exportedTypeBindings) - tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState); + { + tf = clone(tf, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(tf.type, interfaceTypes, ice); + } for (TypeId ty : moduleScope->returnType) + { if (get(follow(ty))) - *asMutable(ty) = AnyTypeVar{}; + { + auto t = asMutable(ty); + t->ty = AnyTypeVar{}; + t->normal = true; + } + } if (FFlag::LuauCloneDeclaredGlobals) { for (auto& [name, ty] : declaredGlobals) - ty = clone(ty, interfaceTypes, seenTypes, seenTypePacks, cloneState); + { + ty = clone(ty, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(ty, interfaceTypes, ice); + } } freeze(internalTypes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp new file mode 100644 index 0000000..40341ac --- /dev/null +++ b/Analysis/src/Normalize.cpp @@ -0,0 +1,814 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Normalize.h" + +#include + +#include "Luau/Clone.h" +#include "Luau/DenseHash.h" +#include "Luau/Substitution.h" +#include "Luau/Unifier.h" +#include "Luau/VisitTypeVar.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) + +// This could theoretically be 2000 on amd64, but x86 requires this. +LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineIntersectionFix, false); + +namespace Luau +{ + +namespace +{ + +struct Replacer : Substitution +{ + TypeId sourceType; + TypeId replacedType; + DenseHashMap replacedTypes{nullptr}; + DenseHashMap replacedPacks{nullptr}; + + Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) + : Substitution(TxnLog::empty(), arena) + , sourceType(sourceType) + , replacedType(replacedType) + { + } + + bool isDirty(TypeId ty) override + { + if (!sourceType) + return false; + + auto vecHasSourceType = [sourceType = sourceType](const auto& vec) { + return end(vec) != std::find(begin(vec), end(vec), sourceType); + }; + + // Walk every kind of TypeVar and find pointers to sourceType + if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return vecHasSourceType(t->parts); + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + { + if (vecHasSourceType(t->generics)) + return true; + + return false; + } + else if (auto t = get(ty)) + { + if (t->boundTo) + return *t->boundTo == sourceType; + + for (const auto& [_name, prop] : t->props) + { + if (prop.type == sourceType) + return true; + } + + if (auto indexer = t->indexer) + { + if (indexer->indexType == sourceType || indexer->indexResultType == sourceType) + return true; + } + + if (vecHasSourceType(t->instantiatedTypeParams)) + return true; + + return false; + } + else if (auto t = get(ty)) + return t->table == sourceType || t->metatable == sourceType; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return false; + else if (auto t = get(ty)) + return vecHasSourceType(t->options); + else if (auto t = get(ty)) + return vecHasSourceType(t->parts); + else if (auto t = get(ty)) + return false; + + LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type"); + LUAU_UNREACHABLE(); + } + + bool isDirty(TypePackId tp) override + { + if (auto it = replacedPacks.find(tp)) + return false; + + if (auto pack = get(tp)) + { + for (TypeId ty : pack->head) + { + if (ty == sourceType) + return true; + } + return false; + } + else if (auto vtp = get(tp)) + return vtp->ty == sourceType; + else + return false; + } + + TypeId clean(TypeId ty) override + { + LUAU_ASSERT(sourceType && replacedType); + + // Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType + // Before returning, memoize the result for later use. + + // Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This + // function returns the identity for things like primitives. + TypeId res = clone(ty); + + if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + for (TypeId& part : t->parts) + { + if (part == sourceType) + part = replacedType; + } + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + // The constituent typepacks are cleaned separately. We just need to walk the generics array. + for (TypeId& g : t->generics) + { + if (g == sourceType) + g = replacedType; + } + } + else if (auto t = getMutable(res)) + { + for (auto& [_key, prop] : t->props) + { + if (prop.type == sourceType) + prop.type = replacedType; + } + } + else if (auto t = getMutable(res)) + { + if (t->table == sourceType) + t->table = replacedType; + if (t->metatable == sourceType) + t->table = replacedType; + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else if (auto t = getMutable(res)) + { + for (TypeId& option : t->options) + { + if (option == sourceType) + option = replacedType; + } + } + else if (auto t = getMutable(res)) + { + for (TypeId& part : t->parts) + { + if (part == sourceType) + part = replacedType; + } + } + else if (auto t = get(res)) + LUAU_ASSERT(!"Impossible"); + else + LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type"); + + replacedTypes[ty] = res; + return res; + } + + TypePackId clean(TypePackId tp) override + { + TypePackId res = clone(tp); + + if (auto pack = getMutable(res)) + { + for (TypeId& type : pack->head) + { + if (type == sourceType) + type = replacedType; + } + } + else if (auto vtp = getMutable(res)) + { + if (vtp->ty == sourceType) + vtp->ty = replacedType; + } + + replacedPacks[tp] = res; + return res; + } + + TypeId smartClone(TypeId t) + { + std::optional res = replace(t); + LUAU_ASSERT(res.has_value()); // TODO think about this + if (*res == t) + return clone(t); + return *res; + } +}; + +} // anonymous namespace + +bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + +template +static bool areNormal_(const T& t, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + int count = 0; + auto isNormal = [&](TypeId ty) { + ++count; + if (count >= FInt::LuauNormalizeIterationLimit) + ice.ice("Luau::areNormal hit iteration limit"); + + return ty->normal || seen.find(asMutable(ty)); + }; + + return std::all_of(begin(t), end(t), isNormal); +} + +static bool areNormal(const std::vector& types, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + return areNormal_(types, seen, ice); +} + +static bool areNormal(TypePackId tp, const DenseHashSet& seen, InternalErrorReporter& ice) +{ + tp = follow(tp); + if (get(tp)) + return false; + + auto [head, tail] = flatten(tp); + + if (!areNormal_(head, seen, ice)) + return false; + + if (!tail) + return true; + + if (auto vtp = get(*tail)) + return vtp->ty->normal || seen.find(asMutable(vtp->ty)); + + return true; +} + +#define CHECK_ITERATION_LIMIT(...) \ + do \ + { \ + if (iterationLimit > FInt::LuauNormalizeIterationLimit) \ + { \ + limitExceeded = true; \ + return __VA_ARGS__; \ + } \ + ++iterationLimit; \ + } while (false) + +struct Normalize +{ + TypeArena& arena; + InternalErrorReporter& ice; + + // Debug data. Types being normalized are invalidated but trying to see what's going on is painful. + // To actually see the original type, read it by using the pointer of the type being normalized. + // e.g. in lldb, `e dump(originalTys[ty])`. + SeenTypes originalTys; + SeenTypePacks originalTps; + + int iterationLimit = 0; + bool limitExceeded = false; + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + template + void cycle(TID) + { + } + + bool operator()(TypeId ty, const FreeTypeVar&) + { + LUAU_ASSERT(!ty->normal); + return false; + } + + bool operator()(TypeId ty, const BoundTypeVar& btv) + { + // It should never be the case that this TypeVar is normal, but is bound to a non-normal type. + LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); + + asMutable(ty)->normal = btv.boundTo->normal; + return !ty->normal; + } + + bool operator()(TypeId ty, const PrimitiveTypeVar&) + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool operator()(TypeId ty, const GenericTypeVar&) + { + if (!ty->normal) + asMutable(ty)->normal = true; + + return false; + } + + bool operator()(TypeId ty, const ErrorTypeVar&) + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + ConstrainedTypeVar* ctv = const_cast(&ctvRef); + + std::vector parts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId part : parts) + visit_detail::visit(part, *this, seen); + + std::vector newParts = normalizeUnion(parts); + + const bool normal = areNormal(newParts, seen, ice); + + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const FunctionTypeVar& ftv) = delete; + bool operator()(TypeId ty, const FunctionTypeVar& ftv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + visit_detail::visit(ftv.argTypes, *this, seen); + visit_detail::visit(ftv.retType, *this, seen); + + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); + + return false; + } + + bool operator()(TypeId ty, const TableTypeVar& ttv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + bool normal = true; + + auto checkNormal = [&](TypeId t) { + // if t is on the stack, it is possible that this type is normal. + // If t is not normal and it is not on the stack, this type is definitely not normal. + if (!t->normal && !seen.find(asMutable(t))) + normal = false; + }; + + if (ttv.boundTo) + { + visit_detail::visit(*ttv.boundTo, *this, seen); + asMutable(ty)->normal = (*ttv.boundTo)->normal; + return false; + } + + for (const auto& [_name, prop] : ttv.props) + { + visit_detail::visit(prop.type, *this, seen); + checkNormal(prop.type); + } + + if (ttv.indexer) + { + visit_detail::visit(ttv.indexer->indexType, *this, seen); + checkNormal(ttv.indexer->indexType); + visit_detail::visit(ttv.indexer->indexResultType, *this, seen); + checkNormal(ttv.indexer->indexResultType); + } + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const MetatableTypeVar& mtv, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + visit_detail::visit(mtv.table, *this, seen); + visit_detail::visit(mtv.metatable, *this, seen); + + asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; + + return false; + } + + bool operator()(TypeId ty, const ClassTypeVar& ctv) + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool operator()(TypeId ty, const AnyTypeVar&) + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool operator()(TypeId ty, const UnionTypeVar& utvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + UnionTypeVar* utv = &const_cast(utvRef); + std::vector options = std::move(utv->options); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId option : options) + visit_detail::visit(option, *this, seen); + + std::vector newOptions = normalizeUnion(options); + + const bool normal = areNormal(newOptions, seen, ice); + + LUAU_ASSERT(!newOptions.empty()); + + if (newOptions.size() == 1) + *asMutable(ty) = BoundTypeVar{newOptions[0]}; + else + utv->options = std::move(newOptions); + + asMutable(ty)->normal = normal; + + return false; + } + + bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, DenseHashSet& seen) + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + IntersectionTypeVar* itv = &const_cast(itvRef); + + std::vector oldParts = std::move(itv->parts); + + for (TypeId part : oldParts) + visit_detail::visit(part, *this, seen); + + std::vector tables; + for (TypeId part : oldParts) + { + part = follow(part); + if (get(part)) + tables.push_back(part); + else + { + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, itv, part); + } + } + + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + itv->parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); + + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) + { + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); + } + + itv->parts.push_back(newTable); + } + + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; + } + + return false; + } + + bool operator()(TypeId ty, const LazyTypeVar&) + { + return false; + } + + std::vector normalizeUnion(const std::vector& options) + { + if (options.size() == 1) + return options; + + std::vector result; + + for (TypeId part : options) + combineIntoUnion(result, part); + + return result; + } + + void combineIntoUnion(std::vector& result, TypeId ty) + { + ty = follow(ty); + if (auto utv = get(ty)) + { + for (TypeId t : utv) + combineIntoUnion(result, t); + return; + } + + for (TypeId& part : result) + { + if (isSubtype(ty, part, ice)) + return; // no need to do anything + else if (isSubtype(part, ty, ice)) + { + part = ty; // replace the less general type by the more general one + return; + } + } + + result.push_back(ty); + } + + /** + * @param replacer knows how to clone a type such that any recursive references point at the new containing type. + * @param result is an intersection that is safe for us to mutate in-place. + */ + void combineIntoIntersection(Replacer& replacer, IntersectionTypeVar* result, TypeId ty) + { + // Note: this check guards against running out of stack space + // so if you increase the size of a stack frame, you'll need to decrease the limit. + CHECK_ITERATION_LIMIT(); + + ty = follow(ty); + if (auto itv = get(ty)) + { + for (TypeId part : itv->parts) + combineIntoIntersection(replacer, result, part); + return; + } + + // Let's say that the last part of our result intersection is always a table, if any table is part of this intersection + if (get(ty)) + { + if (result->parts.empty()) + result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); + + TypeId theTable = result->parts.back(); + + if (!get(FFlag::LuauNormalizeCombineIntersectionFix ? follow(theTable) : theTable)) + { + result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); + theTable = result->parts.back(); + } + + TypeId newTable = replacer.smartClone(theTable); + result->parts.back() = newTable; + + combineIntoTable(replacer, getMutable(newTable), ty); + } + else if (auto ftv = get(ty)) + { + bool merged = false; + for (TypeId& part : result->parts) + { + if (isSubtype(part, ty, ice)) + { + merged = true; + break; // no need to do anything + } + else if (isSubtype(ty, part, ice)) + { + merged = true; + part = ty; // replace the less general type by the more general one + break; + } + } + + if (!merged) + result->parts.push_back(ty); + } + else + result->parts.push_back(ty); + } + + TableState combineTableStates(TableState lhs, TableState rhs) + { + if (lhs == rhs) + return lhs; + + if (lhs == TableState::Free || rhs == TableState::Free) + return TableState::Free; + + if (lhs == TableState::Unsealed || rhs == TableState::Unsealed) + return TableState::Unsealed; + + return lhs; + } + + /** + * @param replacer gives us a way to clone a type such that recursive references are rewritten to the new + * "containing" type. + * @param table always points into a table that is safe for us to mutate. + */ + void combineIntoTable(Replacer& replacer, TableTypeVar* table, TypeId ty) + { + // Note: this check guards against running out of stack space + // so if you increase the size of a stack frame, you'll need to decrease the limit. + CHECK_ITERATION_LIMIT(); + + LUAU_ASSERT(table); + + ty = follow(ty); + + TableTypeVar* tyTable = getMutable(ty); + LUAU_ASSERT(tyTable); + + for (const auto& [propName, prop] : tyTable->props) + { + if (auto it = table->props.find(propName); it != table->props.end()) + { + /** + * If we are going to recursively merge intersections of tables, we need to ensure that we never mutate + * a table that comes from somewhere else in the type graph. + * + * smarClone() does some nice things for us: It will perform a clone that is as shallow as possible + * while still rewriting any cyclic references back to the new 'root' table. + * + * replacer also keeps a mapping of types that have previously been copied, so we have the added + * advantage here of knowing that, whether or not a new copy was actually made, the resulting TypeVar is + * safe for us to mutate in-place. + */ + TypeId clone = replacer.smartClone(it->second.type); + it->second.type = combine(replacer, clone, prop.type); + } + else + table->props.insert({propName, prop}); + } + + table->state = combineTableStates(table->state, tyTable->state); + table->level = max(table->level, tyTable->level); + } + + /** + * @param a is always cloned by the caller. It is safe to mutate in-place. + * @param b will never be mutated. + */ + TypeId combine(Replacer& replacer, TypeId a, TypeId b) + { + if (FFlag::LuauNormalizeCombineTableFix && a == b) + return a; + + if (!get(a) && !get(a)) + { + if (!FFlag::LuauNormalizeCombineTableFix && a == b) + return a; + else + return arena.addType(IntersectionTypeVar{{a, b}}); + } + + if (auto itv = getMutable(a)) + { + combineIntoIntersection(replacer, itv, b); + return a; + } + else if (auto ttv = getMutable(a)) + { + if (FFlag::LuauNormalizeCombineTableFix && !get(follow(b))) + return arena.addType(IntersectionTypeVar{{a, b}}); + combineIntoTable(replacer, ttv, b); + return a; + } + + LUAU_ASSERT(!"Impossible"); + LUAU_UNREACHABLE(); + } +}; + +#undef CHECK_ITERATION_LIMIT + +/** + * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) + */ +std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice) +{ + CloneState state; + if (FFlag::DebugLuauCopyBeforeNormalizing) + (void)clone(ty, arena, state); + + Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, n, seen); + + return {ty, !n.limitExceeded}; +} + +// TODO: Think about using a temporary arena and cloning types out of it so that we +// reclaim memory used by wantonly allocated intermediate types here. +// The main wrinkle here is that we don't want clone() to copy a type if the source and dest +// arena are the same. +std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice) +{ + return normalize(ty, module->internalTypes, ice); +} + +/** + * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) + */ +std::pair normalize(TypePackId tp, TypeArena& arena, InternalErrorReporter& ice) +{ + CloneState state; + if (FFlag::DebugLuauCopyBeforeNormalizing) + (void)clone(tp, arena, state); + + Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(tp, n, seen); + + return {tp, !n.limitExceeded}; +} + +std::pair normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice) +{ + return normalize(tp, module->internalTypes, ice); +} + +} // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 94e169f..305f83c 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,6 +4,8 @@ #include "Luau/VisitTypeVar.h" +LUAU_FASTFLAG(LuauTypecheckOptPass) + namespace Luau { @@ -12,6 +14,8 @@ struct Quantifier TypeLevel level; std::vector generics; std::vector genericPacks; + bool seenGenericType = false; + bool seenMutableType = false; Quantifier(TypeLevel level) : level(level) @@ -23,6 +27,9 @@ struct Quantifier bool operator()(TypeId ty, const FreeTypeVar& ftv) { + if (FFlag::LuauTypecheckOptPass) + seenMutableType = true; + if (!level.subsumes(ftv.level)) return false; @@ -44,17 +51,40 @@ struct Quantifier return true; } + bool operator()(TypeId ty, const ConstrainedTypeVar&) + { + return true; + } + bool operator()(TypeId ty, const TableTypeVar&) { TableTypeVar& ttv = *getMutable(ty); + if (FFlag::LuauTypecheckOptPass) + { + if (ttv.state == TableState::Generic) + seenGenericType = true; + + if (ttv.state == TableState::Free) + seenMutableType = true; + } + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) return false; if (!level.subsumes(ttv.level)) + { + if (FFlag::LuauTypecheckOptPass && ttv.state == TableState::Unsealed) + seenMutableType = true; return false; + } if (ttv.state == TableState::Free) + { ttv.state = TableState::Generic; + + if (FFlag::LuauTypecheckOptPass) + seenGenericType = true; + } else if (ttv.state == TableState::Unsealed) ttv.state = TableState::Sealed; @@ -65,6 +95,9 @@ struct Quantifier bool operator()(TypePackId tp, const FreeTypePack& ftp) { + if (FFlag::LuauTypecheckOptPass) + seenMutableType = true; + if (!level.subsumes(ftp.level)) return false; @@ -84,6 +117,9 @@ void quantify(TypeId ty, TypeLevel level) LUAU_ASSERT(ftv); ftv->generics = q.generics; ftv->genericPacks = q.genericPacks; + + if (FFlag::LuauTypecheckOptPass && ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; } } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 770c7a4..8648b21 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -7,24 +7,36 @@ #include #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTFLAG(LuauTypecheckOptPass) +LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) namespace Luau { void Tarjan::visitChildren(TypeId ty, int index) { - ty = log->follow(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); if (ignoreChildren(ty)) return; - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto pty = log->pending(ty)) + ty = &pty->pending; + } + + if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { visitChild(ftv->argTypes); visitChild(ftv->retType); } - else if (const TableTypeVar* ttv = log->getMutable(ty)) + else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -41,38 +53,52 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = log->getMutable(ty)) + else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { for (TypeId part : itv->parts) visitChild(part); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + for (TypeId part : ctv->parts) + visitChild(part); + } } void Tarjan::visitChildren(TypePackId tp, int index) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); if (ignoreChildren(tp)) return; - if (const TypePack* tpp = log->getMutable(tp)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + } + + if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { for (TypeId tv : tpp->head) visitChild(tv); if (tpp->tail) visitChild(*tpp->tail); } - else if (const VariadicTypePack* vtp = log->getMutable(tp)) + else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { visitChild(vtp->ty); } @@ -80,7 +106,10 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - ty = log->follow(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); bool fresh = !typeToIndex.contains(ty); int& index = typeToIndex[ty]; @@ -98,7 +127,10 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); bool fresh = !packToIndex.contains(tp); int& index = packToIndex[tp]; @@ -141,7 +173,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount) + if (childLimit > 0 && childLimit < childCount) return TarjanResult::TooManyChildren; stack.push_back(index); @@ -229,6 +261,9 @@ TarjanResult Tarjan::loop() TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + ty = log->follow(ty); auto [index, fresh] = indexify(ty); @@ -239,6 +274,9 @@ TarjanResult Tarjan::visitRoot(TypeId ty) TarjanResult Tarjan::visitRoot(TypePackId tp) { childCount = 0; + if (childLimit == 0) + childLimit = FInt::LuauTarjanChildLimit; + tp = log->follow(tp); auto [index, fresh] = indexify(tp); @@ -347,7 +385,13 @@ TypeId Substitution::clone(TypeId ty) TypeId result = ty; - if (const FunctionTypeVar* ftv = log->getMutable(ty)) + if (FFlag::LuauTypecheckOptPass) + { + if (auto pty = log->pending(ty)) + ty = &pty->pending; + } + + if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; @@ -357,7 +401,7 @@ TypeId Substitution::clone(TypeId ty) clone.argNames = ftv->argNames; result = addType(std::move(clone)); } - else if (const TableTypeVar* ttv = log->getMutable(ty)) + else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -370,24 +414,29 @@ TypeId Substitution::clone(TypeId ty) clone.tags = ttv->tags; result = addType(std::move(clone)); } - else if (const MetatableTypeVar* mtv = log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; clone.syntheticName = mtv->syntheticName; result = addType(std::move(clone)); } - else if (const UnionTypeVar* utv = log->getMutable(ty)) + else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { UnionTypeVar clone; clone.options = utv->options; result = addType(std::move(clone)); } - else if (const IntersectionTypeVar* itv = log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) { IntersectionTypeVar clone; clone.parts = itv->parts; result = addType(std::move(clone)); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + ConstrainedTypeVar clone{ctv->level, ctv->parts}; + result = addType(std::move(clone)); + } asMutable(result)->documentationSymbol = ty->documentationSymbol; return result; @@ -396,14 +445,21 @@ TypeId Substitution::clone(TypeId ty) TypePackId Substitution::clone(TypePackId tp) { tp = log->follow(tp); - if (const TypePack* tpp = log->getMutable(tp)) + + if (FFlag::LuauTypecheckOptPass) + { + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; + } + + if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { TypePack clone; clone.head = tpp->head; clone.tail = tpp->tail; return addTypePack(std::move(clone)); } - else if (const VariadicTypePack* vtp = log->getMutable(tp)) + else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) { VariadicTypePack clone; clone.ty = vtp->ty; @@ -415,25 +471,34 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - ty = log->follow(ty); - if (isDirty(ty)) - newTypes[ty] = clean(ty); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); else - newTypes[ty] = clone(ty); + ty = log->follow(ty); + + if (isDirty(ty)) + newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(ty)) : clean(ty); + else + newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(ty)) : clone(ty); } void Substitution::foundDirty(TypePackId tp) { - tp = log->follow(tp); - if (isDirty(tp)) - newPacks[tp] = clean(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); else - newPacks[tp] = clone(tp); + tp = log->follow(tp); + + if (isDirty(tp)) + newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(tp)) : clean(tp); + else + newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(tp)) : clone(tp); } TypeId Substitution::replace(TypeId ty) { ty = log->follow(ty); + if (TypeId* prevTy = newTypes.find(ty)) return *prevTy; else @@ -443,6 +508,7 @@ TypeId Substitution::replace(TypeId ty) TypePackId Substitution::replace(TypePackId tp) { tp = log->follow(tp); + if (TypePackId* prevTp = newPacks.find(tp)) return *prevTp; else @@ -451,7 +517,13 @@ TypePackId Substitution::replace(TypePackId tp) void Substitution::replaceChildren(TypeId ty) { - ty = log->follow(ty); + if (BoundTypeVar* btv = log->getMutable(ty); FFlag::LuauLowerBoundsCalculation && btv) + btv->boundTo = replace(btv->boundTo); + + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(ty == log->follow(ty)); + else + ty = log->follow(ty); if (ignoreChildren(ty)) return; @@ -493,11 +565,19 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& part : itv->parts) part = replace(part); } + else if (ConstrainedTypeVar* ctv = getMutable(ty)) + { + for (TypeId& part : ctv->parts) + part = replace(part); + } } void Substitution::replaceChildren(TypePackId tp) { - tp = log->follow(tp); + if (FFlag::LuauTypecheckOptPass) + LUAU_ASSERT(tp == log->follow(tp)); + else + tp = log->follow(tp); if (ignoreChildren(tp)) return; diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index df9d418..cb54bfc 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -237,6 +237,15 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } + else if (const ConstrainedTypeVar* ctv = get(ty)) + { + formatAppend(result, "ConstrainedTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : ctv->parts) + visitChild(part, index); + } else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); @@ -258,6 +267,28 @@ void StateDot::visitChildren(TypeId ty, int index) if (ctv->metatable) visitChild(*ctv->metatable, index, "[metatable]"); } + else if (const SingletonTypeVar* stv = get(ty)) + { + std::string res; + + if (const StringSingleton* ss = get(stv)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(stv)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonTypeVar %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); + } else { LUAU_ASSERT(!"unknown type kind"); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 59ee6de..610842d 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) + /* * Prefix generic typenames with gen- * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 @@ -33,8 +35,8 @@ struct FindCyclicTypes bool exhaustive = false; std::unordered_set visited; std::unordered_set visitedPacks; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; void cycle(TypeId ty) { @@ -86,7 +88,7 @@ struct FindCyclicTypes }; template -void findCyclicTypes(std::unordered_set& cycles, std::unordered_set& cycleTPs, TID ty, bool exhaustive) +void findCyclicTypes(std::set& cycles, std::set& cycleTPs, TID ty, bool exhaustive) { FindCyclicTypes fct; fct.exhaustive = exhaustive; @@ -124,6 +126,7 @@ struct StringifierState std::unordered_map cycleTpNames; std::unordered_set seen; std::unordered_set usedNames; + size_t indentation = 0; bool exhaustive; @@ -216,6 +219,34 @@ struct StringifierState result.name += s; } + + void indent() + { + indentation += 4; + } + + void dedent() + { + indentation -= 4; + } + + void newline() + { + if (!opts.useLineBreaks) + return emit(" "); + + emit("\n"); + emitIndentation(); + } + +private: + void emitIndentation() + { + if (!opts.indent) + return; + + emit(std::string(indentation, ' ')); + } }; struct TypeVarStringifier @@ -321,7 +352,7 @@ struct TypeVarStringifier stringify(btv.boundTo); } - void operator()(TypeId ty, const Unifiable::Generic& gtv) + void operator()(TypeId ty, const GenericTypeVar& gtv) { if (gtv.explicitName) { @@ -332,6 +363,26 @@ struct TypeVarStringifier state.emit(state.getName(ty)); } + void operator()(TypeId, const ConstrainedTypeVar& ctv) + { + state.result.invalid = true; + + state.emit("[["); + + bool first = true; + for (TypeId ty : ctv.parts) + { + if (first) + first = false; + else + state.emit("|"); + + stringify(ty); + } + + state.emit("]]"); + } + void operator()(TypeId, const PrimitiveTypeVar& ptv) { switch (ptv.type) @@ -415,10 +466,25 @@ struct TypeVarStringifier state.emit(") -> "); bool plural = true; - if (auto retPack = get(follow(ftv.retType))) + + if (FFlag::LuauLowerBoundsCalculation) { - if (retPack->head.size() == 1 && !retPack->tail) - plural = false; + auto retBegin = begin(ftv.retType); + auto retEnd = end(ftv.retType); + if (retBegin != retEnd) + { + ++retBegin; + if (retBegin == retEnd && !retBegin.tail()) + plural = false; + } + } + else + { + if (auto retPack = get(follow(ftv.retType))) + { + if (retPack->head.size() == 1 && !retPack->tail) + plural = false; + } } if (plural) @@ -511,6 +577,7 @@ struct TypeVarStringifier } state.emit(openbrace); + state.indent(); bool comma = false; if (ttv.indexer) @@ -527,7 +594,10 @@ struct TypeVarStringifier for (const auto& [name, prop] : ttv.props) { if (comma) - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + { + state.emit(","); + state.newline(); + } size_t length = state.result.name.length() - oldLength; @@ -553,6 +623,7 @@ struct TypeVarStringifier ++index; } + state.dedent(); state.emit(closedbrace); state.unsee(&ttv); @@ -563,7 +634,8 @@ struct TypeVarStringifier state.result.invalid = true; state.emit("{ @metatable "); stringify(mtv.metatable); - state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + state.emit(","); + state.newline(); stringify(mtv.table); state.emit(" }"); } @@ -784,13 +856,16 @@ struct TypePackStringifier if (tp.tail && !isEmpty(*tp.tail)) { - const auto& tail = *tp.tail; - if (first) - first = false; - else - state.emit(", "); + TypePackId tail = follow(*tp.tail); + if (auto vtp = get(tail); !vtp || (!FFlag::DebugLuauVerboseTypeNames && !vtp->hidden)) + { + if (first) + first = false; + else + state.emit(", "); - stringify(tail); + stringify(tail); + } } state.unsee(&tp); @@ -805,6 +880,8 @@ struct TypePackStringifier void operator()(TypePackId, const VariadicTypePack& pack) { state.emit("..."); + if (FFlag::DebugLuauVerboseTypeNames && pack.hidden) + state.emit(""); stringify(pack.ty); } @@ -858,15 +935,12 @@ void TypeVarStringifier::stringify(TypePackId tpid, const std::vector& cycles, const std::unordered_set& cycleTPs, +static void assignCycleNames(const std::set& cycles, const std::set& cycleTPs, std::unordered_map& cycleNames, std::unordered_map& cycleTpNames, bool exhaustive) { int nextIndex = 1; - std::vector sortedCycles{cycles.begin(), cycles.end()}; - std::sort(sortedCycles.begin(), sortedCycles.end(), std::less{}); - - for (TypeId cycleTy : sortedCycles) + for (TypeId cycleTy : cycles) { std::string name; @@ -888,10 +962,7 @@ static void assignCycleNames(const std::unordered_set& cycles, const std cycleNames[cycleTy] = std::move(name); } - std::vector sortedCycleTps{cycleTPs.begin(), cycleTPs.end()}; - std::sort(sortedCycleTps.begin(), sortedCycleTps.end(), std::less()); - - for (TypePackId tp : sortedCycleTps) + for (TypePackId tp : cycleTPs) { std::string name = "tp" + std::to_string(nextIndex); ++nextIndex; @@ -913,8 +984,8 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) StringifierState state{opts, result, opts.nameMap}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, ty, opts.exhaustive); @@ -1016,8 +1087,8 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) ToStringResult result; StringifierState state{opts, result, opts.nameMap}; - std::unordered_set cycles; - std::unordered_set cycleTPs; + std::set cycles; + std::set cycleTPs; findCyclicTypes(cycles, cycleTPs, tp, opts.exhaustive); @@ -1058,7 +1129,7 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) state.emit(name); state.emit(" = "); Luau::visit( - [&tvs, cycleTy = cycleTy](auto&& t) { + [&tvs, cycleTy = cycleTy](auto t) { return tvs(cycleTy, t); }, cycleTy->ty); @@ -1163,14 +1234,18 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp if (argPackIter.tail()) { - if (!first) - state.emit(", "); + if (auto vtp = get(*argPackIter.tail()); !vtp || !vtp->hidden) + { + if (!first) + state.emit(", "); - state.emit("...: "); - if (auto vtp = get(*argPackIter.tail())) - tvs.stringify(vtp->ty); - else - tvs.stringify(*argPackIter.tail()); + state.emit("...: "); + + if (vtp) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } } state.emit("): "); @@ -1210,6 +1285,24 @@ std::string dump(TypePackId ty) return s; } +std::string dump(const ScopePtr& scope, const char* name) +{ + auto binding = scope->linearSearchForBinding(name); + if (!binding) + { + printf("No binding %s\n", name); + return {}; + } + + TypeId ty = binding->typeId; + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; +} + std::string generateName(size_t i) { std::string n; diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 678001b..1ea2e27 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -215,6 +215,7 @@ struct ArcCollector : public AstVisitor } } + // Adds a dependency from the current node to the named node. void add(const Identifier& name) { Node** it = map.find(name); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 5fbb596..a5f9d26 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false) +LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false) namespace Luau { @@ -161,18 +162,37 @@ void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + if (FFlag::LuauJustOneCallFrameForHaveSeen && !FFlag::LuauTypecheckOptPass) { - return true; - } + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); - if (parent) + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (current->sharedSeen->end() != std::find(current->sharedSeen->begin(), current->sharedSeen->end(), sortedPair)) + return true; + } + + return false; + } + else { - return parent->haveSeen(lhs, rhs); - } + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + { + return true; + } - return false; + if (!FFlag::LuauTypecheckOptPass && parent) + { + return parent->haveSeen(lhs, rhs); + } + + return false; + } } void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) @@ -222,8 +242,8 @@ PendingType* TxnLog::pending(TypeId ty) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) - return it->second.get(); + if (auto it = current->typeVarChanges.find(ty)) + return it->get(); } return nullptr; @@ -237,8 +257,8 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const for (const TxnLog* current = this; current; current = current->parent) { - if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) - return it->second.get(); + if (auto it = current->typePackChanges.find(tp)) + return it->get(); } return nullptr; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index d575e02..bc8d0d4 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -94,6 +94,16 @@ public: } } + AstType* operator()(const ConstrainedTypeVar& ctv) + { + AstArray types; + types.size = ctv.parts.size(); + types.data = static_cast(allocator->allocate(sizeof(AstType*) * ctv.parts.size())); + for (size_t i = 0; i < ctv.parts.size(); ++i) + types.data[i] = Luau::visit(*this, ctv.parts[i]->ty); + return allocator->alloc(Location(), types); + } + AstType* operator()(const SingletonTypeVar& stv) { if (const BooleanSingleton* bs = get(&stv)) @@ -364,6 +374,9 @@ public: AstTypePack* operator()(const VariadicTypePack& vtp) const { + if (vtp.hidden) + return nullptr; + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 1093024..af42a4e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -3,12 +3,15 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" +#include "Luau/Normalize.h" +#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" #include "Luau/TypePack.h" +#include "Luau/ToString.h" #include "Luau/TypeUtils.h" #include "Luau/ToString.h" #include "Luau/TypeVar.h" @@ -19,14 +22,17 @@ LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) +LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. +LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) @@ -39,6 +45,7 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -53,6 +60,8 @@ LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) +LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) +LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); namespace Luau { @@ -140,6 +149,34 @@ bool hasBreak(AstStat* node) } } +static bool hasReturn(const AstStat* node) +{ + struct Searcher : AstVisitor + { + bool result = false; + + bool visit(AstStat*) override + { + return !result; // if we've already found a return statement, don't bother to traverse inward anymore + } + + bool visit(AstStatReturn*) override + { + result = true; + return false; + } + + bool visit(AstExprFunction*) override + { + return false; // We don't care if the function uses a lambda that itself returns + } + }; + + Searcher searcher; + const_cast(node)->visit(&searcher); + return searcher.result; +} + // returns the last statement before the block exits, or nullptr if the block never exits const AstStat* getFallthrough(const AstStat* node) { @@ -253,6 +290,26 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) +{ + if (FFlag::LuauRecursionLimitException) + { + try + { + return checkWithoutRecursionCheck(module, mode, environmentScope); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } + } + else + { + return checkWithoutRecursionCheck(module, mode, environmentScope); + } +} + +ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) { LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); @@ -268,6 +325,12 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona iceHandler->moduleName = module.name; + if (FFlag::LuauAutocompleteDynamicLimits) + { + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; + } + ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); @@ -312,7 +375,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona prepareErrorsForDisplay(currentModule->errors); - bool encounteredFreeType = currentModule->clonePublicInterface(); + bool encounteredFreeType = currentModule->clonePublicInterface(*iceHandler); if (encounteredFreeType) { reportError(TypeError{module.root->location, @@ -415,7 +478,26 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } + if (FFlag::LuauRecursionLimitException) + { + try + { + checkBlockWithoutRecursionCheck(scope, block); + } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(block.location); + return; + } + } + else + { + checkBlockWithoutRecursionCheck(scope, block); + } +} +void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) +{ int subLevel = 0; std::vector sorted(block.body.data, block.body.data + block.body.size); @@ -435,6 +517,16 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) std::unordered_map> functionDecls; + auto isLocalLambda = [](AstStat* stat) -> AstStatLocal* { + AstStatLocal* local = stat->as(); + + if (FFlag::LuauLowerBoundsCalculation && local && local->vars.size == 1 && local->values.size == 1 && + local->values.data[0]->is()) + return local; + else + return nullptr; + }; + auto checkBody = [&](AstStat* stat) { if (auto fun = stat->as()) { @@ -482,7 +574,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCallOrReturn(**protoIter)) + if (containsFunctionCallOrReturn(**protoIter) || (FFlag::LuauLowerBoundsCalculation && isLocalLambda(*protoIter))) { while (checkIter != protoIter) { @@ -513,7 +605,8 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); + TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level)); + unify(funTy, leftType, fun->location); } else if (auto fun = (*protoIter)->as()) @@ -658,6 +751,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } +void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location) +{ + Unifier state = mkUnifier(location); + state.unifyLowerBound(subTy, superTy); + + state.log.commit(); + + reportErrors(state.errors); +} + void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; @@ -682,6 +785,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; + if (useConstrainedIntersections()) + { + unifyLowerBound(retPack, scope->returnType, return_.location); + return; + } + // HACK: Nonstrict mode gets a bit too smart and strict for us when we // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) @@ -1209,9 +1318,11 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco else if (tableSelf->state == TableState::Sealed) reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); + const bool tableIsExtendable = tableSelf && tableSelf->state != TableState::Sealed; + ty = follow(ty); - if (tableSelf && tableSelf->state != TableState::Sealed) + if (tableIsExtendable) tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; const FunctionTypeVar* funTy = get(ty); @@ -1224,7 +1335,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (tableSelf && tableSelf->state != TableState::Sealed) + if (tableIsExtendable) tableSelf->props[indexName->index.value] = { follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; } @@ -1372,7 +1483,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias for (auto param : binding->typePackParams) clone.instantiatedTypePackParams.push_back(param.tp); + bool isNormal = ty->normal; ty = addType(std::move(clone)); + + if (FFlag::LuauLowerBoundsCalculation) + asMutable(ty)->normal = isNormal; } } else @@ -1400,6 +1515,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (FFlag::LuauTwoPassAliasDefinitionFix && ok) bindingType = ty; + + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(bindingType, currentModule, *iceHandler); + bindingType = t; + if (!ok) + reportError(typealias.location, NormalizationTooComplex{}); + } } } @@ -1673,10 +1796,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (get(retPack)) + else if (const FreeTypePack* ftp = get(retPack)) { - TypeId head = freshType(scope); - TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); + TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level; + TypeId head = freshType(level); + TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}}); unify(pack, retPack, expr.location); return {head, std::move(result.predicates)}; } @@ -1793,7 +1917,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : utv) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeForType unions"); // Not needed when we normalize types. if (get(follow(t))) @@ -1817,12 +1941,25 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(addType(UnionTypeVar{std::move(goodOptions)}), currentModule, + *iceHandler); // FIXME Inefficient. We craft a UnionTypeVar and immediately throw it away. - if (result.size() == 1) - return result[0]; + if (!ok) + reportError(location, NormalizationTooComplex{}); - return addType(UnionTypeVar{std::move(result)}); + return t; + } + else + { + std::vector result = reduceUnion(goodOptions); + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); + } } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1830,7 +1967,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : itv->parts) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeFromType intersections"); if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) parts.push_back(*ty); @@ -1982,7 +2119,6 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { if (!std::any_of(begin(utv), end(utv), isNil)) return ty; - } if (std::optional strippedUnion = tryStripUnionFromNil(ty)) @@ -2124,7 +2260,26 @@ TypeId TypeChecker::checkExprTable( ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + if (FFlag::LuauTableUseCounterInstead) + { + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) + { + reportErrorCodeTooComplex(expr.location); + return {errorRecoveryType(scope)}; + } + + return checkExpr_(scope, expr, expectedType); + } + else + { + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "checkExpr for tables"); + return checkExpr_(scope, expr, expectedType); + } +} + +ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +{ std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; @@ -3176,6 +3331,10 @@ std::pair TypeChecker::checkFunctionSignature( funScope->varargPack = anyTypePack; } } + else if (FFlag::LuauLowerBoundsCalculation && !isNonstrictMode()) + { + funScope->varargPack = addTypePack(TypePackVar{VariadicTypePack{anyType, /*hidden*/ true}}); + } std::vector argTypes; @@ -3311,9 +3470,24 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE { check(scope, *function.body); - // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (get_if(&funTy->retType->ty)) - *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + if (useConstrainedIntersections()) + { + TypePackId retPack = follow(funTy->retType); + // It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type + // if it is expected to conform to some other interface. (eg the function may be a lambda passed as a callback) + if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get(retPack)) + { + auto level = getLevel(retPack); + if (level && scope->level.subsumes(*level)) + *asMutable(retPack) = TypePack{{}, std::nullopt}; + } + } + else + { + // We explicitly don't follow here to check if we have a 'true' free type instead of bound one + if (get_if(&funTy->retType->ty)) + *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + } bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; @@ -3418,6 +3592,19 @@ void TypeChecker::checkArgumentList( size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack); + auto reportCountMismatchError = [&state, &argLocations, minParams, paramPack, argPack]() { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + + size_t mp = minParams; + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + mp = getMinParameterCount(&state.log, paramPack); + + state.reportError(TypeError{location, CountMismatch{mp, std::distance(begin(argPack), end(argPack))}}); + }; + while (true) { state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; @@ -3472,6 +3659,8 @@ void TypeChecker::checkArgumentList( } else if (auto vtp = state.log.getMutable(tail)) { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. while (paramIter != endIter) { state.tryUnify(vtp->ty, *paramIter); @@ -3506,14 +3695,22 @@ void TypeChecker::checkArgumentList( else if (state.log.getMutable(t)) { } // ok - else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.getMutable(t)) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.get(t)) { } // ok else { if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) minParams = getMinParameterCount(&state.log, paramPack); - bool isVariadic = FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic && !finite(paramPack, &state.log); + + bool isVariadic = false; + if (FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic) + { + std::optional tail = flatten(paramPack, state.log).second; + if (tail) + isVariadic = Luau::isVariadic(*tail); + } + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); return; } @@ -3532,14 +3729,7 @@ void TypeChecker::checkArgumentList( unify(errorRecoveryType(scope), *argIter, state.location); ++argIter; } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } TypePackId tail = state.log.follow(*paramIter.tail()); @@ -3551,6 +3741,21 @@ void TypeChecker::checkArgumentList( } else if (auto vtp = state.log.getMutable(tail)) { + if (FFlag::LuauLowerBoundsCalculation && vtp->hidden) + { + // We know that this function can technically be oversaturated, but we have its definition and we + // know that it's useless. + + TypeId e = errorRecoveryType(scope); + while (argIter != endIter) + { + unify(e, *argIter, state.location); + ++argIter; + } + + reportCountMismatchError(); + return; + } // Function is variadic and requires that all subsequent parameters // be compatible with a type. size_t argIndex = paramIndex; @@ -3595,14 +3800,7 @@ void TypeChecker::checkArgumentList( } else if (state.log.getMutable(tail)) { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + reportCountMismatchError(); return; } } @@ -3661,7 +3859,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = follow(actualFunctionType); TypePackId retPack; - if (!FFlag::LuauWidenIfSupertypeIsFree2) + if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) { retPack = freshTypePack(scope->level); } @@ -3809,21 +4007,49 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {{errorRecoveryTypePack(scope)}}; } - if (get(fn)) + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. - TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - if (FFlag::LuauWidenIfSupertypeIsFree2) + if (useConstrainedIntersections()) { - UnifierOptions options; - options.isFunctionCall = true; - unify(r, fn, expr.location, options); + // This ternary is phrased deliberately. We need ties between sibling scopes to bias toward ftv->level. + const TypeLevel level = scope->level.subsumes(ftv->level) ? scope->level : ftv->level; + + std::vector adjustedArgTypes; + auto it = begin(argPack); + auto endIt = end(argPack); + Widen widen{¤tModule->internalTypes}; + for (; it != endIt; ++it) + { + TypeId t = *it; + TypeId widened = widen.substitute(t).value_or(t); // Surely widening is infallible + adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widened}})); + } + + TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()}); + + TxnLog log; + promoteTypeLevels(log, ¤tModule->internalTypes, level, retPack); + log.commit(); + + *asMutable(fn) = FunctionTypeVar{level, adjustedArgPack, retPack}; + return {{retPack}}; } else - unify(fn, r, expr.location); - return {{retPack}}; + { + TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); + if (FFlag::LuauWidenIfSupertypeIsFree2) + { + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, expr.location, options); + } + else + unify(fn, r, expr.location); + return {{retPack}}; + } } std::vector metaArgLocations; @@ -4363,10 +4589,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s bool Instantiation::isDirty(TypeId ty) { - if (log->getMutable(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + return false; + return true; + } else + { return false; + } } bool Instantiation::isDirty(TypePackId tp) @@ -4414,14 +4647,21 @@ TypePackId Instantiation::clean(TypePackId tp) bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + return true; + // We aren't recursing in the case of a generic function which // binds the same generics. This can happen if, for example, there's recursive types. // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. // It's OK to use vector equality here, since we always generate fresh generics // whenever we quantify, so the vectors overlap if and only if they are equal. return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); + } else + { return false; + } } bool ReplaceGenerics::isDirty(TypeId ty) @@ -4464,16 +4704,24 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) bool Anyification::isDirty(TypeId ty) { + if (ty->persistent) + return false; + if (const TableTypeVar* ttv = log->getMutable(ty)) return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); else if (log->getMutable(ty)) return true; + else if (get(ty)) + return true; else return false; } bool Anyification::isDirty(TypePackId tp) { + if (tp->persistent) + return false; + if (log->getMutable(tp)) return true; else @@ -4494,7 +4742,16 @@ TypeId Anyification::clean(TypeId ty) clone.syntheticName = ttv->syntheticName; clone.tags = ttv->tags; } - return addType(std::move(clone)); + TypeId res = addType(std::move(clone)); + asMutable(res)->normal = ty->normal; + return res; + } + else if (auto ctv = get(ty)) + { + auto [t, ok] = normalize(ty, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; } else return anyType; @@ -4511,16 +4768,34 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location ty = follow(ty); const FunctionTypeVar* ftv = get(ty); - if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) - return ty; + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + + if (FFlag::LuauLowerBoundsCalculation && ftv) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } - Luau::quantify(ty, scope->level); return ty; } TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { + if (FFlag::LuauTypecheckOptPass) + { + const FunctionTypeVar* ftv = get(follow(ty)); + if (ftv && ftv->hasNoGenerics) + return ty; + } + Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; + + if (FFlag::LuauAutocompleteDynamicLimits && instantiationChildLimit) + instantiation.childLimit = *instantiationChildLimit; + std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; @@ -4533,8 +4808,18 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + ty = t; + } + + Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); + if (anyification.normalizationTooComplex) + reportError(location, NormalizationTooComplex{}); if (any.has_value()) return *any; else @@ -4546,7 +4831,15 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + ty = t; + } + + Anyification anyification{¤tModule->internalTypes, iceHandler, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4830,6 +5123,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation ToStringOptions opts; opts.exhaustive = true; opts.maxTableLength = 0; + opts.useLineBreaks = true; TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); @@ -5283,7 +5577,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, bool needsClone = follow(tf.type) == target; bool shouldMutate = (!FFlag::LuauOnlyMutateInstantiatedTables || getTableType(tf.type)); TableTypeVar* ttv = getMutableTableType(target); - + if (shouldMutate && ttv && needsClone) { // Substitution::clone is a shallow clone. If this is a metatable type, we @@ -5487,25 +5781,82 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV // We need to search in the provided Scope. Find t.x.y first. // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. - const auto& [symbol, keys] = getFullName(lvalue); + if (!FFlag::LuauTypecheckOptPass) + { + const auto& [symbol, keys] = getFullName(lvalue); + + ScopePtr currentScope = scope; + while (currentScope) + { + std::optional found; + + std::vector childKeys; + const LValue* currentLValue = &lvalue; + while (currentLValue) + { + if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) + { + found = it->second; + break; + } + + childKeys.push_back(*currentLValue); + currentLValue = baseof(*currentLValue); + } + + if (!found) + { + // Should not be using scope->lookup. This is already recursive. + if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + found = it->second.typeId; + else + { + // Nothing exists in this Scope. Just skip and try the parent one. + currentScope = currentScope->parent; + continue; + } + } + + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) + { + const LValue& key = *it; + + // Symbol can happen. Skip. + if (get(key)) + continue; + else if (auto field = get(key)) + { + found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + if (!found) + return std::nullopt; // Turns out this type doesn't have the property at all. We're done. + } + else + LUAU_ASSERT(!"New LValue alternative not handled here."); + } + + return found; + } + + // No entry for it at all. Can happen when LValue root is a global. + return std::nullopt; + } + + const Symbol symbol = getBaseSymbol(lvalue); ScopePtr currentScope = scope; while (currentScope) { std::optional found; - std::vector childKeys; - const LValue* currentLValue = &lvalue; - while (currentLValue) + const LValue* topLValue = nullptr; + + for (topLValue = &lvalue; topLValue; topLValue = baseof(*topLValue)) { - if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) + if (auto it = currentScope->refinements.find(*topLValue); it != currentScope->refinements.end()) { found = it->second; break; } - - childKeys.push_back(*currentLValue); - currentLValue = baseof(*currentLValue); } if (!found) @@ -5521,9 +5872,15 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV } } + // We need to walk the l-value path in reverse, so we collect components into a vector + std::vector childKeys; + + for (const LValue* curr = &lvalue; curr != topLValue; curr = baseof(*curr)) + childKeys.push_back(curr); + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) { - const LValue& key = *it; + const LValue& key = **it; // Symbol can happen. Skip. if (get(key)) @@ -5938,6 +6295,11 @@ bool TypeChecker::isNonstrictMode() const return (currentModule->mode == Mode::Nonstrict) || (currentModule->mode == Mode::NoCheck); } +bool TypeChecker::useConstrainedIntersections() const +{ + return FFlag::LuauLowerBoundsCalculation && !isNonstrictMode(); +} + std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp, size_t expectedLength, const Location& location) { TypePackId expectedTypePack = addTypePack({}); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5bb0523..3050323 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -104,7 +104,7 @@ TypePackIterator begin(TypePackId tp) return TypePackIterator{tp}; } -TypePackIterator begin(TypePackId tp, TxnLog* log) +TypePackIterator begin(TypePackId tp, const TxnLog* log) { return TypePackIterator{tp, log}; } @@ -256,7 +256,7 @@ size_t size(const TypePack& tp, TxnLog* log) return result; } -std::optional first(TypePackId tp) +std::optional first(TypePackId tp, bool ignoreHiddenVariadics) { auto it = begin(tp); auto endIter = end(tp); @@ -266,7 +266,7 @@ std::optional first(TypePackId tp) if (auto tail = it.tail()) { - if (auto vtp = get(*tail)) + if (auto vtp = get(*tail); vtp && (!vtp->hidden || !ignoreHiddenVariadics)) return vtp->ty; } @@ -299,6 +299,46 @@ std::pair, std::optional> flatten(TypePackId tp) return {res, iter.tail()}; } +std::pair, std::optional> flatten(TypePackId tp, const TxnLog& log) +{ + tp = log.follow(tp); + + std::vector flattened; + std::optional tail = std::nullopt; + + TypePackIterator it(tp, &log); + + for (; it != end(tp); ++it) + { + flattened.push_back(*it); + } + + tail = it.tail(); + + return {flattened, tail}; +} + +bool isVariadic(TypePackId tp) +{ + return isVariadic(tp, *TxnLog::empty()); +} + +bool isVariadic(TypePackId tp, const TxnLog& log) +{ + std::optional tail = flatten(tp, log).second; + + if (!tail) + return false; + + if (log.get(*tail)) + return true; + + if (auto vtp = log.get(*tail); vtp && !vtp->hidden) + return true; + + return false; +} + TypePackVar* asMutable(TypePackId tp) { return const_cast(tp); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index dbc412f..0fbfdbf 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -177,7 +177,7 @@ bool maybeString(TypeId ty) if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables) { ty = follow(ty); - + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) return true; @@ -366,7 +366,7 @@ bool maybeSingleton(TypeId ty) bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) { - RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit, "hasLength"); ty = follow(ty); @@ -654,9 +654,9 @@ static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persi static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; -static TypeVar anyType_{AnyTypeVar{}}; -static TypeVar errorType_{ErrorTypeVar{}}; -static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; +static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; +static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; +static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}, /*persistent*/ true}; static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; static TypePackVar errorTypePack_{Unifiable::Error{}}; @@ -698,7 +698,7 @@ TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, &booleanType_}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, booleanType}}); const TypePackId oneStringPack = arena->addTypePack({stringType}); const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); @@ -802,6 +802,7 @@ void persist(TypeId ty) continue; asMutable(t)->persistent = true; + asMutable(t)->normal = true; // all persistent types are assumed to be normal if (auto btv = get(t)) queue.push_back(btv->boundTo); @@ -838,6 +839,11 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto ctv = get(t)) + { + for (TypeId opt : ctv->parts) + queue.push_back(opt); + } else if (auto mtv = get(t)) { queue.push_back(mtv->table); @@ -899,6 +905,16 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } +std::optional getLevel(TypePackId tp) +{ + tp = follow(tp); + + if (auto ftv = get(tp)) + return ftv->level; + else + return std::nullopt; +} + const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { while (cls) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 398dc9e..f9ea58c 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,9 +14,12 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); +LUAU_FASTINT(LuauTypeInferIterationLimit); +LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); +LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) @@ -27,6 +30,7 @@ LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) +LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -126,7 +130,6 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel visitTypeVarOnce(ty, ptl, seen); } -// TODO: use this and make it static. void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -305,8 +308,7 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, - TxnLog* parentLog) +Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) , log(parentLog) @@ -326,6 +328,7 @@ Unifier::Unifier(TypeArena* types, Mode mode, std::vector 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + if (FFlag::LuauAutocompleteDynamicLimits) { - reportError(TypeError{location, UnificationTooComplex{}}); - return; + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } + } + else + { + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } } superTy = log.follow(superTy); @@ -354,6 +369,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superTy == subTy) return; + if (log.get(superTy)) + return tryUnifyWithConstrainedSuperTypeVar(subTy, superTy); + auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); @@ -442,7 +460,18 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(superTy) || get(superTy)) return tryUnifyWithAny(subTy, superTy); - if (get(subTy) || get(subTy)) + if (get(subTy)) + { + if (anyIsTop) + { + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + return; + } + else + return tryUnifyWithAny(superTy, subTy); + } + + if (get(subTy)) return tryUnifyWithAny(superTy, subTy); bool cacheEnabled; @@ -484,7 +513,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (const UnionTypeVar* uv = log.getMutable(subTy)) + if (log.get(subTy)) + tryUnifyWithConstrainedSubTypeVar(subTy, superTy); + else if (const UnionTypeVar* uv = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, uv, superTy); } @@ -946,7 +977,7 @@ struct WeirdIter LUAU_ASSERT(log.getMutable(newTail)); level = log.getMutable(packId)->level; - log.replace(packId, Unifiable::Bound(newTail)); + log.replace(packId, BoundTypePack(newTail)); packId = newTail; pack = log.getMutable(newTail); index = 0; @@ -994,39 +1025,32 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall tryUnify_(subTp, superTp, isFunctionCall); } -static std::pair, std::optional> logAwareFlatten(TypePackId tp, const TxnLog& log) -{ - tp = log.follow(tp); - - std::vector flattened; - std::optional tail = std::nullopt; - - TypePackIterator it(tp, &log); - - for (; it != end(tp); ++it) - { - flattened.push_back(*it); - } - - tail = it.tail(); - - return {flattened, tail}; -} - /* * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. * If one is longer than the other, but the short end is free, we grow it to the required length. */ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypePackId tryUnify_"); ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + if (FFlag::LuauAutocompleteDynamicLimits) { - reportError(TypeError{location, UnificationTooComplex{}}); - return; + if (sharedState.counters.iterationLimit > 0 && sharedState.counters.iterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } + } + else + { + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) + { + reportError(TypeError{location, UnificationTooComplex{}}); + return; + } } superTp = log.follow(superTp); @@ -1087,8 +1111,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If the size of two heads does not match, but both packs have free tail // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = logAwareFlatten(superTp, log); - auto [subTypes, subTail] = logAwareFlatten(subTp, log); + auto [superTypes, superTail] = flatten(superTp, log); + auto [subTypes, subTail] = flatten(subTp, log); bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && (subTail && log.getMutable(*subTail)); @@ -1165,19 +1189,20 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal else { // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) + if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && superIter.good() && isOptional(*superIter)) { superIter.advance(); continue; } - else if (subIter.good() && isOptional(*subIter)) + else if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && subIter.good() && isOptional(*subIter)) { subIter.advance(); continue; } // In nonstrict mode, any also marks an optional argument. - else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && + log.getMutable(log.follow(*superIter))) { superIter.advance(); continue; @@ -1195,7 +1220,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal return; } - if (!isFunctionCall && subIter.good()) + if ((!FFlag::LuauLowerBoundsCalculation || isNonstrictMode()) && !isFunctionCall && subIter.good()) { // Sometimes it is ok to pass too many arguments return; @@ -1418,14 +1443,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (FFlag::LuauAnyInIsOptionalIsOptional) { - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) + if (subIter == subTable->props.end() && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) missingProperties.push_back(propName); } else { bool isAny = log.getMutable(log.follow(superProp.type)); - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) + if (subIter == subTable->props.end() && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && + !isAny) missingProperties.push_back(propName); } } @@ -1438,8 +1466,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // And vice versa if we're invariant - if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && - superTable->state != TableState::Free) + if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && superTable->state != TableState::Free) { for (const auto& [propName, subProp] : subTable->props) { @@ -1453,7 +1480,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else { bool isAny = log.is(log.follow(subProp.type)); - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) + if (superIter == superTable->props.end() && + (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) extraProperties.push_back(propName); } } @@ -1499,13 +1527,15 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } - else if (FFlag::LuauAnyInIsOptionalIsOptional && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + else if (FFlag::LuauAnyInIsOptionalIsOptional && + (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) { } - else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) + else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && + (isOptional(prop.type) || get(follow(prop.type)))) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: should isOptional(anyType) be true? @@ -1664,9 +1694,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (FFlag::LuauTxnLogDontRetryForIndexers) { - // Changing the indexer can invalidate the table pointers. - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); + // Changing the indexer can invalidate the table pointers. + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); } else if (FFlag::LuauTxnLogCheckForInvalidation) { @@ -1921,8 +1951,6 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!superTable || !subTable) ice("passed non-table types to unifySealedTables"); - Unifier innerState = makeChildUnifier(); - std::vector missingPropertiesInSuper; bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; bool errorReported = false; @@ -1944,6 +1972,8 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } + Unifier innerState = makeChildUnifier(); + // Tables must have exactly the same props and their types must all unify for (const auto& it : superTable->props) { @@ -2376,6 +2406,180 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); } +void Unifier::tryUnifyWithConstrainedSubTypeVar(TypeId subTy, TypeId superTy) +{ + const ConstrainedTypeVar* subConstrained = get(subTy); + if (!subConstrained) + ice("tryUnifyWithConstrainedSubTypeVar received non-ConstrainedTypeVar subTy!"); + + const std::vector& subTyParts = subConstrained->parts; + + // A | B <: T if A <: T and B <: T + bool failed = false; + std::optional unificationTooComplex; + + const size_t count = subTyParts.size(); + + for (size_t i = 0; i < count; ++i) + { + TypeId type = subTyParts[i]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy); + + if (i == count - 1) + log.concat(std::move(innerState.log)); + + ++i; + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + + if (!innerState.errors.empty()) + { + failed = true; + break; + } + } + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (failed) + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + else + log.replace(subTy, BoundTypeVar{superTy}); +} + +void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) +{ + ConstrainedTypeVar* superC = log.getMutable(superTy); + if (!superC) + ice("tryUnifyWithConstrainedSuperTypeVar received non-ConstrainedTypeVar superTy!"); + + // subTy could be a + // table + // metatable + // class + // function + // primitive + // free + // generic + // intersection + // union + // Do we really just tack it on? I think we might! + // We can certainly do some deduplication. + // Is there any point to deducing Player|Instance when we could just reduce to Instance? + // Is it actually ok to have multiple free types in a single intersection? What if they are later unified into the same type? + // Maybe we do a simplification step during quantification. + + auto it = std::find(superC->parts.begin(), superC->parts.end(), subTy); + if (it != superC->parts.end()) + return; + + superC->parts.push_back(subTy); +} + +void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) +{ + // The duplication between this and regular typepack unification is tragic. + + auto superIter = begin(superTy, &log); + auto superEndIter = end(superTy); + + auto subIter = begin(subTy, &log); + auto subEndIter = end(subTy); + + int count = FInt::LuauTypeInferLowerBoundsIterationLimit; + + for (; subIter != subEndIter; ++subIter) + { + if (0 >= --count) + ice("Internal recursion counter limit exceeded in Unifier::unifyLowerBound"); + + if (superIter != superEndIter) + { + tryUnify_(*subIter, *superIter); + ++superIter; + continue; + } + + if (auto t = superIter.tail()) + { + TypePackId tailPack = follow(*t); + + if (log.get(tailPack)) + occursCheck(tailPack, subTy); + + FreeTypePack* freeTailPack = log.getMutable(tailPack); + if (!freeTailPack) + return; + + TypeLevel level = freeTailPack->level; + + TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); + + for (; subIter != subEndIter; ++subIter) + { + tp->head.push_back(types->addType(ConstrainedTypeVar{level, {follow(*subIter)}})); + } + + tp->tail = subIter.tail(); + } + + return; + } + + if (superIter != superEndIter) + { + if (auto subTail = subIter.tail()) + { + TypePackId subTailPack = follow(*subTail); + if (get(subTailPack)) + { + TypePack* tp = getMutable(log.replace(subTailPack, TypePack{})); + + for (; superIter != superEndIter; ++superIter) + tp->head.push_back(*superIter); + } + } + else + { + while (superIter != superEndIter) + { + if (!isOptional(*superIter)) + { + errors.push_back(TypeError{location, CountMismatch{size(superTy), size(subTy), CountMismatch::Return}}); + return; + } + ++superIter; + } + } + + return; + } + + // Both iters are at their respective tails + auto subTail = subIter.tail(); + auto superTail = superIter.tail(); + if (subTail && superTail) + tryUnify(*subTail, *superTail); + else if (subTail) + { + const FreeTypePack* freeSubTail = log.getMutable(*subTail); + if (freeSubTail) + { + log.replace(*subTail, TypePack{}); + } + } + else if (superTail) + { + const FreeTypePack* freeSuperTail = log.getMutable(*superTail); + if (freeSuperTail) + { + log.replace(*superTail, TypePack{}); + } + } +} + void Unifier::occursCheck(TypeId needle, TypeId haystack) { sharedState.tempSeenTy.clear(); @@ -2385,7 +2589,8 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypeId"); auto check = [&](TypeId tv) { occursCheck(seen, needle, tv); @@ -2425,6 +2630,11 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays for (TypeId ty : a->parts) check(ty); } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->parts) + check(ty); + } } void Unifier::occursCheck(TypePackId needle, TypePackId haystack) @@ -2450,7 +2660,8 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (!log.getMutable(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypePackId"); while (!log.getMutable(haystack)) { @@ -2474,7 +2685,23 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + if (FFlag::LuauTypecheckOptPass) + { + Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; + u.anyIsTop = anyIsTop; + return u; + } + + Unifier u = Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + u.anyIsTop = anyIsTop; + return u; +} + +// A utility function that appends the given error to the unifier's error log. +// This allows setting a breakpoint wherever the unifier reports an error. +void Unifier::reportError(TypeError err) +{ + errors.push_back(std::move(err)); } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 65939be..f854311 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -32,6 +32,7 @@ class DenseHashTable { public: class const_iterator; + class iterator; DenseHashTable(const Key& empty_key, size_t buckets = 0) : count(0) @@ -43,7 +44,7 @@ public: // don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs: // https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547 if (buckets) - data.resize(buckets, ItemInterface::create(empty_key)); + resize_data(buckets); } void clear() @@ -125,7 +126,7 @@ public: if (data.empty() && data.capacity() >= newsize) { LUAU_ASSERT(count == 0); - data.resize(newsize, ItemInterface::create(empty_key)); + resize_data(newsize); return; } @@ -169,6 +170,21 @@ public: return const_iterator(this, data.size()); } + iterator begin() + { + size_t start = 0; + + while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key)) + start++; + + return iterator(this, start); + } + + iterator end() + { + return iterator(this, data.size()); + } + size_t size() const { return count; @@ -233,7 +249,82 @@ public: size_t index; }; + class iterator + { + public: + iterator() + : set(0) + , index(0) + { + } + + iterator(DenseHashTable* set, size_t index) + : set(set) + , index(index) + { + } + + MutableItem& operator*() const + { + return *reinterpret_cast(&set->data[index]); + } + + MutableItem* operator->() const + { + return reinterpret_cast(&set->data[index]); + } + + bool operator==(const iterator& other) const + { + return set == other.set && index == other.index; + } + + bool operator!=(const iterator& other) const + { + return set != other.set || index != other.index; + } + + iterator& operator++() + { + size_t size = set->data.size(); + + do + { + index++; + } while (index < size && set->eq(ItemInterface::getKey(set->data[index]), set->empty_key)); + + return *this; + } + + iterator operator++(int) + { + iterator res = *this; + ++*this; + return res; + } + + private: + DenseHashTable* set; + size_t index; + }; + private: + template + void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) + { + data.resize(count, ItemInterface::create(empty_key)); + } + + template + void resize_data(size_t count, typename std::enable_if_t>* dummy = nullptr) + { + size_t size = data.size(); + data.resize(count); + + for (size_t i = size; i < count; i++) + data[i].first = empty_key; + } + std::vector data; size_t count; Key empty_key; @@ -290,6 +381,7 @@ class DenseHashSet public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashSet(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -336,6 +428,16 @@ public: { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; // This is a faster alternative of unordered_map, but it does not implement the same interface (i.e. it does not support erasing and has @@ -348,6 +450,7 @@ class DenseHashMap public: typedef typename Impl::const_iterator const_iterator; + typedef typename Impl::iterator iterator; DenseHashMap(const Key& empty_key, size_t buckets = 0) : impl(empty_key, buckets) @@ -401,10 +504,21 @@ public: { return impl.begin(); } + const_iterator end() const { return impl.end(); } + + iterator begin() + { + return impl.begin(); + } + + iterator end() + { + return impl.end(); + } }; } // namespace Luau diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index d7d867f..4f3dbbd 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -173,7 +173,7 @@ public: } const Lexeme& next(); - const Lexeme& next(bool skipComments); + const Lexeme& next(bool skipComments, bool updatePrevLocation); void nextline(); Lexeme lookahead(); diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 70c6c78..5dd4f04 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -349,13 +349,11 @@ void Lexer::setReadNames(bool read) const Lexeme& Lexer::next() { - return next(this->skipComments); + return next(this->skipComments, true); } -const Lexeme& Lexer::next(bool skipComments) +const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { - bool first = true; - // in skipComments mode we reject valid comments do { @@ -363,11 +361,11 @@ const Lexeme& Lexer::next(bool skipComments) while (isSpace(peekch())) consume(); - if (!FFlag::LuauParseLocationIgnoreCommentSkip || first) + if (!FFlag::LuauParseLocationIgnoreCommentSkip || updatePrevLocation) prevLocation = lexeme.location; lexeme = readNext(); - first = false; + updatePrevLocation = false; } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f9d3217..badd3fd 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,6 +11,7 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false) +LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false) namespace Luau { @@ -2789,7 +2790,7 @@ void Parser::nextLexeme() { if (options.captureComments) { - Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) { @@ -2813,7 +2814,7 @@ void Parser::nextLexeme() hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } - type = lexer.next(/* skipComments= */ false).type; + type = lexer.next(/* skipComments= */ false, !FFlag::LuauParseLocationIgnoreCommentSkipInCapture).type; } } else diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 6330bf1..8ef69e7 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -1386,8 +1386,8 @@ struct Compiler const Constant* cv = constants.find(expr->index); - if (cv && cv->type == Constant::Type_Number && double(int(cv->valueNumber)) == cv->valueNumber && cv->valueNumber >= 1 && - cv->valueNumber <= 256) + if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && + double(int(cv->valueNumber)) == cv->valueNumber) { uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp new file mode 100644 index 0000000..d8511bd --- /dev/null +++ b/Compiler/src/CostModel.cpp @@ -0,0 +1,258 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "CostModel.h" + +#include "Luau/Common.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +namespace Compile +{ + +inline uint64_t parallelAddSat(uint64_t x, uint64_t y) +{ + uint64_t s = x + y; + uint64_t m = s & 0x8080808080808080ull; // saturation mask + + return (s ^ m) | (m - (m >> 7)); +} + +struct Cost +{ + static const uint64_t kLiteral = ~0ull; + + // cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant + uint64_t model; + // constant mask: 8-byte 0xff mask; equal to all ff's for literals, for variables only byte #i (1+) is set to align with model + uint64_t constant; + + Cost(int cost = 0, uint64_t constant = 0) + : model(cost < 0x7f ? cost : 0x7f) + , constant(constant) + { + } + + Cost operator+(const Cost& other) const + { + Cost result; + result.model = parallelAddSat(model, other.model); + return result; + } + + Cost& operator+=(const Cost& other) + { + model = parallelAddSat(model, other.model); + constant = 0; + return *this; + } + + static Cost fold(const Cost& x, const Cost& y) + { + uint64_t newmodel = parallelAddSat(x.model, y.model); + uint64_t newconstant = x.constant & y.constant; + + // the extra cost for folding is 1; the discount is 1 for the variable that is shared by x&y (or whichever one is used in x/y if the other is + // literal) + uint64_t extra = (newconstant == kLiteral) ? 0 : (1 | (0x0101010101010101ull & newconstant)); + + Cost result; + result.model = parallelAddSat(newmodel, extra); + result.constant = newconstant; + + return result; + } +}; + +struct CostVisitor : AstVisitor +{ + DenseHashMap vars; + Cost result; + + CostVisitor() + : vars(nullptr) + { + } + + Cost model(AstExpr* node) + { + if (AstExprGroup* expr = node->as()) + { + return model(expr->expr); + } + else if (node->is() || node->is() || node->is() || + node->is()) + { + return Cost(0, Cost::kLiteral); + } + else if (AstExprLocal* expr = node->as()) + { + const uint64_t* i = vars.find(expr->local); + + return Cost(0, i ? *i : 0); // locals typically don't require extra instructions to compute + } + else if (node->is()) + { + return 1; + } + else if (node->is()) + { + return 3; + } + else if (AstExprCall* expr = node->as()) + { + Cost cost = 3; + cost += model(expr->func); + + for (size_t i = 0; i < expr->args.size; ++i) + { + Cost ac = model(expr->args.data[i]); + // for constants/locals we still need to copy them to the argument list + cost += ac.model == 0 ? Cost(1) : ac; + } + + return cost; + } + else if (AstExprIndexName* expr = node->as()) + { + return model(expr->expr) + 1; + } + else if (AstExprIndexExpr* expr = node->as()) + { + return model(expr->expr) + model(expr->index) + 1; + } + else if (AstExprFunction* expr = node->as()) + { + return 10; // high baseline cost due to allocation + } + else if (AstExprTable* expr = node->as()) + { + Cost cost = 10; // high baseline cost due to allocation + + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + if (item.key) + cost += model(item.key); + + cost += model(item.value); + cost += 1; + } + + return cost; + } + else if (AstExprUnary* expr = node->as()) + { + return Cost::fold(model(expr->expr), Cost(0, Cost::kLiteral)); + } + else if (AstExprBinary* expr = node->as()) + { + return Cost::fold(model(expr->left), model(expr->right)); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + return model(expr->expr); + } + else if (AstExprIfElse* expr = node->as()) + { + return model(expr->condition) + model(expr->trueExpr) + model(expr->falseExpr) + 2; + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + return {}; + } + } + + void assign(AstExpr* expr) + { + // variable assignments reset variable mask, so that further uses of this variable aren't discounted + // this doesn't work perfectly with backwards control flow like loops, but is good enough for a single pass + if (AstExprLocal* lv = expr->as()) + if (uint64_t* i = vars.find(lv->local)) + *i = 0; + } + + bool visit(AstExpr* node) override + { + // note: we short-circuit the visitor traversal through any expression trees by returning false + // recursive traversal is happening inside model() which makes it easier to get the resulting value of the subexpression + result += model(node); + + return false; + } + + bool visit(AstStat* node) override + { + if (node->is()) + result += 2; + else if (node->is() || node->is() || node->is() || node->is()) + result += 2; + else if (node->is() || node->is()) + result += 1; + + return true; + } + + bool visit(AstStatLocal* node) override + { + for (size_t i = 0; i < node->values.size; ++i) + { + Cost arg = model(node->values.data[i]); + + // propagate constant mask from expression through variables + if (arg.constant && i < node->vars.size) + vars[node->vars.data[i]] = arg.constant; + + result += arg; + } + + return false; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + return true; + } + + bool visit(AstStatCompoundAssign* node) override + { + assign(node->var); + + // if lhs is not a local, setting it requires an extra table operation + result += node->var->is() ? 1 : 2; + + return true; + } +}; + +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount) +{ + CostVisitor visitor; + for (size_t i = 0; i < varCount && i < 7; ++i) + visitor.vars[vars[i]] = 0xffull << (i * 8 + 8); + + root->visit(&visitor); + + return visitor.result.model; +} + +int computeCost(uint64_t model, const bool* varsConst, size_t varCount) +{ + int cost = int(model & 0x7f); + + // don't apply discounts to what is likely a saturated sum + if (cost == 0x7f) + return cost; + + for (size_t i = 0; i < varCount && i < 7; ++i) + cost -= int((model >> (8 * i + 8)) & 0x7f) * varsConst[i]; + + return cost; +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/CostModel.h b/Compiler/src/CostModel.h new file mode 100644 index 0000000..c27861e --- /dev/null +++ b/Compiler/src/CostModel.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" + +namespace Luau +{ +namespace Compile +{ + +// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); + +// cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant +int computeCost(uint64_t model, const bool* varsConst, size_t varCount); + +} // namespace Compile +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index 6f110f1..60e5dfd 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -32,11 +32,13 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/Compiler.cpp Compiler/src/Builtins.cpp Compiler/src/ConstantFolding.cpp + Compiler/src/CostModel.cpp Compiler/src/TableShape.cpp Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp Compiler/src/Builtins.h Compiler/src/ConstantFolding.h + Compiler/src/CostModel.h Compiler/src/TableShape.h Compiler/src/ValueTracking.h ) @@ -58,6 +60,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/LValue.h Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h + Analysis/include/Luau/Normalize.h Analysis/include/Luau/Predicate.h Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h @@ -94,6 +97,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Linter.cpp Analysis/src/LValue.cpp Analysis/src/Module.cpp + Analysis/src/Normalize.cpp Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp @@ -216,6 +220,7 @@ if(TARGET Luau.UnitTest) tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp + tests/CostModel.test.cpp tests/Config.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp @@ -224,6 +229,7 @@ if(TARGET Luau.UnitTest) tests/LValue.test.cpp tests/Module.test.cpp tests/NonstrictMode.test.cpp + tests/Normalize.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/StringUtils.test.cpp diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 1c75c0b..dc40b6e 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -34,7 +34,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) -LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary, false) +LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false) // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 @@ -390,6 +390,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) setarrayvector(L, t, nasize); /* create new hash part with appropriate size */ setnodevector(L, t, nhsize); + /* used for the migration check at the end */ + LuaNode* nnew = t->node; if (nasize < oldasize) { /* array part must shrink? */ t->sizearray = nasize; @@ -413,6 +415,8 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) /* shrink array */ luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat); } + /* used for the migration check at the end */ + TValue* anew = t->array; /* re-insert elements from hash part */ if (FFlag::LuauTableRehashRework) { @@ -441,14 +445,30 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) } } + /* make sure we haven't recursively rehashed during element migration */ + LUAU_ASSERT(nnew == t->node); + LUAU_ASSERT(anew == t->array); + if (nold != dummynode) luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */ } +static int adjustasize(Table* t, int size, const TValue* ek) +{ + LUAU_ASSERT(FFlag::LuauTableNewBoundary2); + bool tbound = t->node != dummynode || size < t->sizearray; + int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; + /* move the array size up until the boundary is guaranteed to be inside the array part */ + while (size + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, size + 1)))) + size++; + return size; +} + void luaH_resizearray(lua_State* L, Table* t, int nasize) { int nsize = (t->node == dummynode) ? 0 : sizenode(t); - resize(L, t, nasize, nsize); + int asize = FFlag::LuauTableNewBoundary2 ? adjustasize(t, nasize, NULL) : nasize; + resize(L, t, asize, nsize); } void luaH_resizehash(lua_State* L, Table* t, int nhsize) @@ -470,21 +490,12 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) totaluse++; /* compute new size for array part */ int na = computesizes(nums, &nasize); + int nh = totaluse - na; /* enforce the boundary invariant; for performance, only do hash lookups if we must */ - if (FFlag::LuauTableNewBoundary) - { - bool tbound = t->node != dummynode || nasize < t->sizearray; - int ekindex = ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; - /* move the array size up until the boundary is guaranteed to be inside the array part */ - while (nasize + 1 == ekindex || (tbound && !ttisnil(luaH_getnum(t, nasize + 1)))) - { - nasize++; - na++; - } - } + if (FFlag::LuauTableNewBoundary2) + nasize = adjustasize(t, nasize, ek); /* resize the table to new computed sizes */ - LUAU_ASSERT(na <= totaluse); - resize(L, t, nasize, totaluse - na); + resize(L, t, nasize, nh); } /* @@ -544,7 +555,7 @@ static LuaNode* getfreepos(Table* t) static TValue* newkey(lua_State* L, Table* t, const TValue* key) { /* enforce boundary invariant */ - if (FFlag::LuauTableNewBoundary && ttisnumber(key) && nvalue(key) == t->sizearray + 1) + if (FFlag::LuauTableNewBoundary2 && ttisnumber(key) && nvalue(key) == t->sizearray + 1) { rehash(L, t, key); /* grow table */ @@ -735,7 +746,7 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) { - LUAU_ASSERT(!FFlag::LuauTableNewBoundary); + LUAU_ASSERT(!FFlag::LuauTableNewBoundary2); unsigned int i = j; /* i is zero or a present index */ j++; /* find `i' and `j' such that i is present and j is not */ @@ -820,7 +831,7 @@ int luaH_getn(Table* t) maybesetaboundary(t, boundary); return boundary; } - else if (FFlag::LuauTableNewBoundary) + else if (FFlag::LuauTableNewBoundary2) { /* validate boundary invariant */ LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 41887f4..9c1f387 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -199,7 +199,7 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - void (*telemetrycb)(lua_State* L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; + void (*telemetrycb)(lua_State * L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f) { diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 34949ef..39c60ea 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,7 +16,7 @@ #include -LUAU_FASTFLAG(LuauTableNewBoundary) +LUAU_FASTFLAG(LuauTableNewBoundary2) // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -2268,7 +2268,7 @@ static void luau_execute(lua_State* L) VM_NEXT(); } } - else if (FFlag::LuauTableNewBoundary || (h->lsizenode == 0 && ttisnil(gval(h->node)))) + else if (FFlag::LuauTableNewBoundary2 || (h->lsizenode == 0 && ttisnil(gval(h->node)))) { // fallthrough to exit VM_NEXT(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp new file mode 100644 index 0000000..ec04932 --- /dev/null +++ b/tests/CostModel.test.cpp @@ -0,0 +1,101 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" + +#include "doctest.h" + +using namespace Luau; + +namespace Luau +{ +namespace Compile +{ + +uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); +int computeCost(uint64_t model, const bool* varsConst, size_t varCount); + +} // namespace Compile +} // namespace Luau + +TEST_SUITE_BEGIN("CostModel"); + +static uint64_t modelFunction(const char* source) +{ + Allocator allocator; + AstNameTable names(allocator); + + ParseResult result = Parser::parse(source, strlen(source), names, allocator); + REQUIRE(result.root != nullptr); + + AstStatFunction* func = result.root->body.data[0]->as(); + REQUIRE(func); + + return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size); +} + +TEST_CASE("Expression") +{ + uint64_t model = modelFunction(R"( +function test(a, b, c) + return a + (b + 1) * (b + 1) - c +end +)"); + + const bool args1[] = {false, false, false}; + const bool args2[] = {false, true, false}; + + CHECK_EQ(5, Luau::Compile::computeCost(model, args1, 3)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 3)); +} + +TEST_CASE("PropagateVariable") +{ + uint64_t model = modelFunction(R"( +function test(a) + local b = a * a * a + return b * b +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(0, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("LoopAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,3 do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // loop baseline cost is 2 + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("MutableVariable") +{ + uint64_t model = modelFunction(R"( +function test(a, b) + local x = a * a + x += b + return x * x +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 9dc9fee..d8b37a6 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -231,7 +231,7 @@ ModulePtr Fixture::getMainModule() SourceModule* Fixture::getMainSourceModule() { - return frontend.getSourceModule(fromString("MainModule")); + return frontend.getSourceModule(fromString(mainModuleName)); } std::optional Fixture::getPrimitiveType(TypeId ty) @@ -259,7 +259,7 @@ std::optional Fixture::getType(const std::string& name) TypeId Fixture::requireType(const std::string& name) { std::optional ty = getType(name); - REQUIRE(bool(ty)); + REQUIRE_MESSAGE(bool(ty), "Unable to requireType \"" << name << "\""); return follow(*ty); } diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index cb50807..1d2ad64 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Ast.h" #include "Luau/JsonEncoder.h" +#include "Luau/Parser.h" #include "doctest.h" @@ -50,4 +51,26 @@ TEST_CASE("encode_AstStatBlock") toJson(&block)); } +TEST_CASE("encode_tables") +{ + std::string src = R"( + local x: { + foo: number + } = { + foo = 123, + } + )"; + + Allocator allocator; + AstNameTable names(allocator); + ParseResult parseResult = Parser::parse(src.c_str(), src.length(), names, allocator); + + REQUIRE(parseResult.errors.size() == 0); + std::string json = toJson(parseResult.root); + + CHECK( + json == + R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 9ce9a4c..05ee9a7 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -597,8 +597,6 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - ScopedFastFlag sff("LuauLintNoRobloxBits", true); - unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -1439,6 +1437,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") { unfreeze(typeChecker.globalTypes); TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}}); + persist(instanceType); typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; getMutable(instanceType)->props = { diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index de06312..738893d 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -2,6 +2,7 @@ #include "Luau/Clone.h" #include "Luau/Module.h" #include "Luau/Scope.h" +#include "Luau/RecursionCounter.h" #include "Fixture.h" @@ -9,6 +10,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("ModuleTests"); TEST_CASE_FIXTURE(Fixture, "is_within_comment") @@ -42,29 +45,23 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment") TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks, cloneState); + TypeId newNumber = clone(typeChecker.numberType, dest, cloneState); CHECK_EQ(newNumber, typeChecker.numberType); } TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); freeze(typeChecker.globalTypes); - TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks, cloneState); + TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); @@ -90,12 +87,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TypeId counterType = requireType("Cyclic"); - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - TypeArena dest; - TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks, cloneState); + CloneState cloneState; + TypeId counterCopy = clone(counterType, dest, cloneState); TableTypeVar* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); @@ -112,8 +106,11 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") REQUIRE(methodReturnType); CHECK_EQ(methodReturnType, counterCopy); - CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type - CHECK_EQ(2, dest.typeVars.size()); // One table and one function + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(3, dest.typePacks.size()); // function args, its return type, and the hidden any... pack + else + CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type + CHECK_EQ(2, dest.typeVars.size()); // One table and one function } TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") @@ -143,15 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") TEST_CASE_FIXTURE(Fixture, "deepClone_union") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks, cloneState); + TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); @@ -161,15 +155,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") { TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks, cloneState); + TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); @@ -191,12 +182,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") std::nullopt, &exampleMetaClass, {}, {}}}; TypeArena dest; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks, cloneState); + TypeId cloned = clone(&exampleClass, dest, cloneState); const ClassTypeVar* ctv = get(cloned); REQUIRE(ctv != nullptr); @@ -216,16 +204,14 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, cloneState); + TypeId clonedTy = clone(&freeTy, dest, cloneState); CHECK_EQ("any", toString(clonedTy)); CHECK(cloneState.encounteredFreeType); cloneState = {}; - TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, cloneState); + TypePackId clonedTp = clone(&freeTp, dest, cloneState); CHECK_EQ("...any", toString(clonedTp)); CHECK(cloneState.encounteredFreeType); } @@ -237,16 +223,32 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") ttv->state = TableState::Free; TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, cloneState); + TypeId cloned = clone(&tableTy, dest, cloneState); const TableTypeVar* clonedTtv = get(cloned); CHECK_EQ(clonedTtv->state, TableState::Sealed); CHECK(cloneState.encounteredFreeType); } +TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") +{ + TypeArena src; + + TypeId constrained = src.addType(ConstrainedTypeVar{TypeLevel{}, {getSingletonTypes().numberType, getSingletonTypes().stringType}}); + + TypeArena dest; + CloneState cloneState; + + TypeId cloned = clone(constrained, dest, cloneState); + CHECK_NE(constrained, cloned); + + const ConstrainedTypeVar* ctv = get(cloned); + REQUIRE_EQ(2, ctv->parts.size()); + CHECK_EQ(getSingletonTypes().numberType, ctv->parts[0]); + CHECK_EQ(getSingletonTypes().stringType, ctv->parts[1]); +} + TEST_CASE_FIXTURE(Fixture, "clone_self_property") { ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; @@ -284,6 +286,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; + ScopedFastFlag sff{"LuauRecursionLimitException", true}; TypeArena src; @@ -299,11 +302,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") } TypeArena dest; - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; CloneState cloneState; - CHECK_THROWS_AS(clone(table, dest, seenTypes, seenTypePacks, cloneState), std::runtime_error); + CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index d3faea2..a8a12b6 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -275,4 +275,38 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); } +TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): (boolean, string?) + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "returning_too_many_values") +{ + CheckResult result = check(R"( + --!nonstrict + + function foo(): boolean + if true then + return true, "hello" + else + return false + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp new file mode 100644 index 0000000..5a84201 --- /dev/null +++ b/tests/Normalize.test.cpp @@ -0,0 +1,967 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/Normalize.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +struct NormalizeFixture : Fixture +{ + ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true}; + ScopedFastFlag sff2{"LuauTableSubtypingVariance2", true}; +}; + +void createSomeClasses(TypeChecker& typeChecker) +{ + auto& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr}); + + ClassTypeVar* parentClass = getMutable(parentType); + parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; + + parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; + + addGlobalBinding(typeChecker, "Parent", {parentType}); + typeChecker.globalScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; + + TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr}); + + ClassTypeVar* childClass = getMutable(childType); + childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; + + addGlobalBinding(typeChecker, "Child", {childType}); + typeChecker.globalScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; + + TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr}); + + addGlobalBinding(typeChecker, "Unrelated", {unrelatedType}); + typeChecker.globalScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; + + freeze(arena); +} + +static bool isSubtype(TypeId a, TypeId b) +{ + InternalErrorReporter ice; + return isSubtype(a, b, ice); +} + +TEST_SUITE_BEGIN("isSubtype"); + +TEST_CASE_FIXTURE(NormalizeFixture, "primitives") +{ + check(R"( + local a = 41 + local b = 32 + + local c = "hello" + local d = "world" + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(d, c)); + CHECK(!isSubtype(d, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions") +{ + check(R"( + function a(x: number): number return x end + function b(x: number): number return x end + + function c(x: number?): number return x end + function d(x: number): number? return x end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(d, a)); + CHECK(isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_and_any") +{ + check(R"( + function a(n: number) return "string" end + function b(q: any) return 5 :: any end + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + // Intuition: + // We cannot use b where a is required because we cannot rely on b to return a string. + // We cannot use a where b is required because we cannot rely on a to accept non-number arguments. + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions_of_different_arities") +{ + check(R"( + type A = (any) -> () + type B = (any, any) -> () + type T = A & B + + local a: A + local b: B + local t: T + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(a, b)); // !! + CHECK(!isSubtype(b, a)); + + CHECK("((any) -> ()) & ((any, any) -> ())" == toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity") +{ + check(R"( + local a: (number) -> () + local b: () -> () + + local c: () -> number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(c, a)); + + CHECK(!isSubtype(a, b)); + CHECK(!isSubtype(c, b)); + + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "functions_with_mismatching_arity_but_optional_parameters") +{ + /* + * (T0..TN) <: (T0..TN, A?) + * (T0..TN) <: (T0..TN, any) + * (T0..TN, A?) R <: U -> S if U <: T and R <: S + * A | B <: T if A <: T and B <: T + * T <: A | B if T <: A or T <: B + */ + check(R"( + local a: (number?) -> () + local b: (number) -> () + local c: (number, number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, number?) -> () <: (number) -> (number) + * The packs have inequal lengths, but (number) <: (number, number?) + * and number <: number + */ + CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * because (number, number?) () () + * because (number, number?) () + local b: (number) -> () + local c: (number, any) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + /* + * (number) -> () () + * because number? () + * because number? () <: (number) -> () + * because number <: number? (because number <: number) + */ + CHECK(isSubtype(a, b)); + + /* + * (number, any) -> () (number) + * The packs have inequal lengths + */ + CHECK(!isSubtype(c, b)); + + /* + * (number?) -> () () + * The packs have inequal lengths + */ + CHECK(!isSubtype(a, c)); + + /* + * (number) -> () () + * The packs have inequal lengths + */ + CHECK(!isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "variadic_functions_with_no_head") +{ + check(R"( + local a: (...number) -> () + local b: (...number?) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "variadic_function_with_head") +{ + check(R"( + local a: (...number) -> () + local b: (number, number) -> () + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "union") +{ + check(R"( + local a: number | string + local b: number + local c: string + local d: number? + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(b, d)); + CHECK(!isSubtype(d, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_union_prop") +{ + check(R"( + local a: {x: number} + local b: {x: number?} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") +{ + check(R"( + local a: {x: number} + local b: {x: any} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection") +{ + check(R"( + local a: number & string + local b: number + local c: string + local d: number & nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); + + CHECK(!isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") +{ + check(R"( + local a: number & string + local b: number | nil + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(!isSubtype(b, a)); + CHECK(isSubtype(a, b)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_with_table_prop") +{ + check(R"( + type T = {x: {y: number}} & {x: {y: string}} + local a: T + )"); + + CHECK_EQ("{| x: {| y: number & string |} |}", toString(requireType("a"))); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "tables") +{ + check(R"( + local a: {x: number} + local b: {x: any} + local c: {y: number} + local d: {x: number, y: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + + CHECK(isSubtype(a, b)); + CHECK(!isSubtype(b, a)); + + CHECK(!isSubtype(c, a)); + CHECK(!isSubtype(a, c)); + + CHECK(isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(d, b)); + CHECK(!isSubtype(b, d)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") +{ + check(R"( + local a: {[string]: number} + local b: {[string]: any} + local c: {[string]: number} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(!isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "mismatched_indexers") +{ + check(R"( + local a: {x: number} + local b: {[string]: number} + local c: {} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(!isSubtype(c, b)); + CHECK(isSubtype(b, c)); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_table") +{ + check(R"( + type A = {method: (A) -> ()} + local a: A + + type B = {method: (any) -> ()} + local b: B + + type C = {method: (C) -> ()} + local c: C + + type D = {method: (D) -> (), another: (D) -> ()} + local d: D + + type E = {method: (A) -> (), another: (E) -> ()} + local e: E + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + TypeId d = requireType("d"); + TypeId e = requireType("e"); + + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, b)); + + CHECK(isSubtype(c, a)); + CHECK(isSubtype(a, c)); + + CHECK(!isSubtype(d, a)); + CHECK(!isSubtype(a, d)); + + CHECK(isSubtype(e, a)); + CHECK(!isSubtype(a, e)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "classes") +{ + createSomeClasses(typeChecker); + + TypeId p = typeChecker.globalScope->lookupType("Parent")->type; + TypeId c = typeChecker.globalScope->lookupType("Child")->type; + TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; + + CHECK(isSubtype(c, p)); + CHECK(!isSubtype(p, c)); + CHECK(!isSubtype(u, p)); + CHECK(!isSubtype(p, u)); +} + +#if 0 +TEST_CASE_FIXTURE(NormalizeFixture, "metatable" * doctest::expected_failures{1}) +{ + check(R"( + local T = {} + T.__index = T + function T.new() + return setmetatable({}, T) + end + + function T:method() end + + local a: typeof(T.new) + local b: {method: (any) -> ()} + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); +} +#endif + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_tables") +{ + check(R"( + type T = {x: number} & ({x: number} & {y: string?}) + local t: T + )"); + + CHECK("{| x: number, y: string? |}" == toString(requireType("t"))); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Normalize"); + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_disjoint_tables") +{ + check(R"( + type T = {a: number} & {b: number} + local t: T + )"); + + CHECK_EQ("{| a: number, b: number |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_overlapping_tables") +{ + check(R"( + type T = {a: number, b: string} & {b: number, c: string} + local t: T + )"); + + CHECK_EQ("{| a: number, b: number & string, c: string |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_confluent_overlapping_tables") +{ + check(R"( + type T = {a: number, b: string} & {b: string, c: string} + local t: T + )"); + + CHECK_EQ("{| a: number, b: string, c: string |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_subtype_relationship") +{ + check(R"( + local t: {x: number} | {x: number?} + )"); + + ModulePtr tempModule{new Module}; + + // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze + // the arena that the type lives in. + ModulePtr mainModule = getMainModule(); + unfreeze(mainModule->internalTypes); + + TypeId tType = requireType("t"); + normalize(tType, tempModule, *typeChecker.iceHandler); + + CHECK_EQ("{| x: number? |}", toString(tType, {true})); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_functions") +{ + check(R"( + type T = ((any) -> string) & ((number) -> string) + local t: T + )"); + + CHECK_EQ("(any) -> string", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + --!nonstrict + + if Math.random() then + return function(initialState, handlers) + return function(state, action) + return state + end + end + else + return function(initialState, handlers) + return function(state, action) + return state + end + end + end + )"); + + CHECK_EQ("(any, any) -> (...any)", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_is_not_a_constrained_intersection") +{ + check(R"( + function foo(x:number, y:number) + return x + y + end + )"); + + CHECK_EQ("(number, number) -> number", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function") +{ + check(R"( + function apply(f, x) + return f(x) + end + + local a = apply(function(x: number) return x + x end, 5) + )"); + + TypeId aType = requireType("a"); + CHECK_MESSAGE(isNumber(follow(aType)), "Expected a number but got ", toString(aType)); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_with_annotation") +{ + check(R"( + function apply(f: (a) -> b, x) + return f(x) + end + )"); + + CHECK_EQ("((a) -> b, a) -> b", toString(requireType("apply"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + type Fiber = { + return_: Fiber? + } + + local f: Fiber + )"); + + TypeId t = requireType("f"); + CHECK(t->normal); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_tail_is_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + type Weirdo = (...{x: number}) -> () + + local w: Weirdo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("w"); + auto ftv = get(t); + REQUIRE(ftv); + + auto [argHead, argTail] = flatten(ftv->argTypes); + CHECK(argHead.empty()); + REQUIRE(argTail.has_value()); + + auto vtp = get(*argTail); + REQUIRE(vtp); + CHECK(vtp->ty->normal); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") +{ + CheckResult result = check(R"( + local Cyclic = {} + function Cyclic.get() + return Cyclic + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = requireType("Cyclic"); + CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_distinct_free_types") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function fussy(a, b) + if math.random() > 0.5 then + return a + else + return b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("(a, b) -> a | b" == toString(requireType("fussy"))); +} + +TEST_CASE_FIXTURE(Fixture, "constrained_intersection_of_intersections") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + local f : (() -> number) | ((number) -> number) + local g : (() -> number) | ((string) -> number) + + function h() + if math.random() then + return f + else + return g + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId h = requireType("h"); + + CHECK("() -> (() -> number) | ((number) -> number) | ((string) -> number)" == toString(h)); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + type X = {} + type Y = {y: number} + type Z = {z: string} + type W = {w: boolean} + type T = {x: Y & X} & {x:Z & W} + + local x: X + local y: Y + local z: Z + local w: W + local t: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("{| |}" == toString(requireType("x"), {true})); + CHECK("{| y: number |}" == toString(requireType("y"), {true})); + CHECK("{| z: string |}" == toString(requireType("z"), {true})); + CHECK("{| w: boolean |}" == toString(requireType("w"), {true})); + CHECK("{| x: {| w: boolean, y: number, z: string |} |}" == toString(requireType("t"), {true})); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_2") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(w, x, y, z) + y.y = 5 + z.z = "five" + w.w = true + + type Z = {x: typeof(x) & typeof(y)} & {x: typeof(w) & typeof(z)} + + return ((nil :: any) :: Z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(4 == args.size()); + CHECK("{+ w: boolean +}" == toString(args[0])); + CHECK("a" == toString(args[1])); + CHECK("{+ y: number +}" == toString(args[2])); + CHECK("{+ z: string +}" == toString(args[3])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("{| x: a & {- w: boolean, y: number, z: string -} |}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_3") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y, z) + x.x = true + y.y = y + z.z = "five" + + type Z = {x: typeof(y)} & {x: typeof(x) & typeof(z)} + + return ((nil :: any) :: Z) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(3 == args.size()); + CHECK("{+ x: boolean +}" == toString(args[0])); + CHECK("t1 where t1 = {+ y: t1 +}" == toString(args[1])); + CHECK("{+ z: string +}" == toString(args[2])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("{| x: {- x: boolean, y: t1, z: string -} |} where t1 = {+ y: t1 +}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_4") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y, z) + x.x = true + z.z = "five" + + type R = {x: typeof(y)} & {x: typeof(x) & typeof(z)} + local r: R + + y.y = r + + return r + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId t = requireType("strange"); + auto ftv = get(t); + REQUIRE(ftv != nullptr); + + std::vector args = flatten(ftv->argTypes).first; + + REQUIRE(3 == args.size()); + CHECK("{+ x: boolean +}" == toString(args[0])); + CHECK("{+ y: t1 +} where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(args[1])); + CHECK("{+ z: string +}" == toString(args[2])); + + std::vector ret = flatten(ftv->retType).first; + + REQUIRE(1 == ret.size()); + CHECK("t1 where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(ret[0])); +} + +TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeCombineTableFix", true}, + }; + // CLI-52787 + // ends up combining {_:any} with any, recursively + // which used to ICE because this combines a table with a non-table. + CheckResult result = check(R"( + export type t0 = any & { _: {_:any} } & { _:any } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeCombineIntersectionFix", true}, + }; + + CheckResult result = check(R"( + export type t0 = {_:{_:any} & {_:any|string}} & {_:{_:{}}} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 79f9eca..b941103 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1618,6 +1618,26 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); } +TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") +{ + ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; + ScopedFastFlag luauParseLocationIgnoreCommentSkipInCapture{"LuauParseLocationIgnoreCommentSkipInCapture", true}; + + // Same should hold when comments are captured + ParseOptions opts; + opts.captureComments = true; + + AstStatBlock* block = parse(R"( + type F = number + --comment + print('hello') + )", + opts); + + REQUIRE_EQ(2, block->body.size); + CHECK_EQ((Position{1, 23}), block->body.data[0]->location.end); +} + TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") { matchParseError("break", "break statement must be inside a loop"); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 29bdd86..f3fda54 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauLowerBoundsCalculation) + using namespace Luau; struct ToDotClassFixture : Fixture @@ -101,9 +103,34 @@ local function f(a, ...: string) return a end )"); LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(a, ...string) -> a", toString(requireType("f"))); + ToDotOptions opts; opts.showPointers = false; - CHECK_EQ(R"(digraph graphname { + + if (FFlag::LuauLowerBoundsCalculation) + { + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionTypeVar 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericTypeVar 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n6 -> n7; +n7 [label="BoundTypeVar 7"]; +n7 -> n3; +})", + toDot(requireType("f"), opts)); + } + else + { + CHECK_EQ(R"(digraph graphname { n1 [label="FunctionTypeVar 1"]; n1 -> n2 [label="arg"]; n2 [label="TypePack 2"]; @@ -119,7 +146,8 @@ n6 -> n7; n7 [label="TypePack 7"]; n7 -> n3; })", - toDot(requireType("f"), opts)); + toDot(requireType("f"), opts)); + } } TEST_CASE_FIXTURE(Fixture, "union") @@ -361,4 +389,49 @@ n3 [label="number"]; toDot(*ty, opts)); } +TEST_CASE_FIXTURE(Fixture, "constrained") +{ + // ConstrainedTypeVars never appear in the final type graph, so we have to create one directly + // to dotify it. + TypeVar t{ConstrainedTypeVar{TypeLevel{}, {typeChecker.numberType, typeChecker.stringType, typeChecker.nilType}}}; + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="ConstrainedTypeVar 1"]; +n1 -> n2; +n2 [label="number"]; +n1 -> n3; +n3 [label="string"]; +n1 -> n4; +n4 [label="nil"]; +})", + toDot(&t, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "singletontypes") +{ + CheckResult result = check(R"( + local x: "hi" | "\"hello\"" | true | false + )"); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="SingletonTypeVar string: hi"]; +n1 -> n3; +)" +"n3 [label=\"SingletonTypeVar string: \\\"hello\\\"\"];" +R"( +n1 -> n4; +n4 [label="SingletonTypeVar boolean: true"]; +n1 -> n5; +n5 [label="SingletonTypeVar boolean: false"]; +})", toDot(requireType("x"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 3051e20..ccf5c58 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); + TEST_SUITE_BEGIN("ToString"); TEST_CASE_FIXTURE(Fixture, "primitive") diff --git a/tests/TopoSort.test.cpp b/tests/TopoSort.test.cpp index 9b99086..1f14ae8 100644 --- a/tests/TopoSort.test.cpp +++ b/tests/TopoSort.test.cpp @@ -340,26 +340,28 @@ TEST_CASE_FIXTURE(Fixture, "nested_type_annotations_depends_on_later_typealiases TEST_CASE_FIXTURE(Fixture, "return_comes_last") { - CheckResult result = check(R"( -export type Module = { bar: (number) -> boolean, foo: () -> string } + AstStatBlock* program = parse(R"( + local module = {} -return function() : Module - local module = {} + local function confuseCompiler() return module.foo() end - local function confuseCompiler() return module.foo() end - - module.foo = function() return "" end + module.foo = function() return "" end - function module.bar(x:number) - confuseCompiler() - return true - end - - return module -end + function module.bar(x:number) + confuseCompiler() + return true + end + + return module )"); - LUAU_REQUIRE_NO_ERRORS(result); + auto sorted = toposort(*program); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[2], program->body.data[1]); + CHECK_EQ(sorted[1], program->body.data[2]); + CHECK_EQ(sorted[3], program->body.data[3]); + CHECK_EQ(sorted[4], program->body.data[4]); } TEST_CASE_FIXTURE(Fixture, "break_comes_last") diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 5ac45ff..0c324cd 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -388,7 +388,7 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") std::string actual = decorateWithTypes(code); - CHECK_EQ(expected, decorateWithTypes(code)); + CHECK_EQ(expected, actual); } TEST_CASE_FIXTURE(Fixture, "function_type_location") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2ad11d0..e2971ad 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -753,4 +753,14 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_typevar") REQUIRE(ocf); } +TEST_CASE_FIXTURE(Fixture, "instantiation_clone_has_to_follow") +{ + CheckResult result = check(R"( + export type t8 = (t0)&(((true)|(any))->"") + export type t0 = ({})&({_:{[any]:number},}) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index c6fbebe..1ae6594 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("BuiltinTests"); TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") @@ -557,9 +559,9 @@ TEST_CASE_FIXTURE(Fixture, "xpcall") )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("boolean", toString(requireType("a"))); - REQUIRE_EQ("number", toString(requireType("b"))); - REQUIRE_EQ("boolean", toString(requireType("c"))); + CHECK_EQ("boolean", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("boolean", toString(requireType("c"))); } TEST_CASE_FIXTURE(Fixture, "see_thru_select") @@ -881,7 +883,10 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); + else + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 98fa66e..8e3629e 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -91,6 +91,9 @@ struct ClassFixture : Fixture typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; addGlobalBinding(typeChecker, "Vector2", vector2Type, "@test"); + for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + persist(tf.type); + freeze(arena); } }; diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 1713216..6599368 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TypeInferFunctions"); TEST_CASE_FIXTURE(Fixture, "tc_function") @@ -98,7 +100,7 @@ TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") end return result - end + end return T )"); @@ -274,6 +276,10 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( function f(g) return f(f) @@ -281,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); + CHECK_EQ("t1 where t1 = (t1) -> (a...)", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") @@ -481,10 +487,10 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") std::vector fArgs = flatten(fType->argTypes).first; - TypeId xType = argVec[1]; + TypeId xType = follow(argVec[1]); CHECK_EQ(1, fArgs.size()); - CHECK_EQ(xType, fArgs[0]); + CHECK_EQ(xType, follow(fArgs[0])); } TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") @@ -1043,13 +1049,16 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); + if (!FFlag::LuauLowerBoundsCalculation) + { + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") @@ -1142,13 +1151,16 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); + if (!FFlag::LuauLowerBoundsCalculation) + { + // Return type doesn't inference 'nil' + result = check(R"( + function f(a: (number) -> nil) return a(4) end + f(function(x) print(x) end) + )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") @@ -1338,6 +1350,126 @@ end CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } +TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types") +{ + const ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(a: boolean, b: number) + if a then + return nil + else + return b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(boolean, number) -> number?", toString(requireType("foo"))); + + // TODO: Test multiple returns + // Think of various cases where typepacks need to grow. maybe consult other tests + // Basic normalization of ConstrainedTypeVars during quantification +} + +TEST_CASE_FIXTURE(Fixture, "inconsistent_higher_order_function") +{ + const ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(f) + f(5) + f("six") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((number | string) -> (a...)) -> ()", toString(requireType("foo"))); +} + + +/* The bug here is that we are using the same level 2.0 for both the body of resolveDispatcher and the + * lambda useCallback. + * + * I think what we want to do is, at each scope level, never reuse the same sublevel. + * + * We also adjust checkBlock to consider the syntax `local x = function() ... end` to be sortable + * in the same way as `local function x() ... end`. This causes the function `resolveDispatcher` to be + * checked before the lambda. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useCallback: (any) -> any} + end + + local useCallback = function(deps: any) + return resolveDispatcher().useCallback(deps) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s: %s\n", toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time2") +{ + CheckResult result = check(R"( + --!strict + + local function resolveDispatcher() + return (nil :: any) :: {useContext: (number?) -> any} + end + + local useContext + useContext = function(unstable_observedBits: number?) + resolveDispatcher().useContext(unstable_observedBits) + end + )"); + + // LUAU_REQUIRE_NO_ERRORS is particularly unhelpful when this test is broken. + // You get a TypeMismatch error where both types stringify the same. + + CHECK(result.errors.empty()); + if (!result.errors.empty()) + { + for (const auto& e : result.errors) + printf("%s %s: %s\n", e.moduleName.c_str(), toString(e.location).c_str(), toString(e).c_str()); + } +} + +TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_the_right_time3") +{ + CheckResult result = check(R"( + local foo + + foo():bar(function() + return foo() + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_unsealed_overwrite") { ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; @@ -1471,4 +1603,17 @@ pcall(wrapper, test) CHECK(acm->isVariadic); } +TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") +{ + CheckResult result = check(R"( + function f() + return 5, f() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index f360a77..49d31fc 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -230,8 +230,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") @@ -253,8 +253,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); - CHECK_EQ(args[0], idFun->generics[0]); - CHECK_EQ(rets[0], idFun->generics[0]); + CHECK_EQ(follow(args[0]), follow(idFun->generics[0])); + CHECK_EQ(follow(rets[0]), follow(idFun->generics[0])); } TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") @@ -705,10 +705,10 @@ end TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") { ScopedFastFlag sffs[] = { - { "LuauTableSubtypingVariance2", true }, - { "LuauUnsealedTableLiteral", true }, - { "LuauPropertiesGetExpectedType", true }, - { "LuauRecursiveTypeParameterRestriction", true }, + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + {"LuauPropertiesGetExpectedType", true}, + {"LuauRecursiveTypeParameterRestriction", true}, }; CheckResult result = check(R"( @@ -843,6 +843,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_function") LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(a) -> a", toString(requireType("id"))); CHECK_EQ(*typeChecker.numberType, *requireType("a")); CHECK_EQ(*typeChecker.nilType, *requireType("b")); } @@ -1037,25 +1038,39 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end -return sum(2, 3, function(a, b) return a + b end) + local function sum(x: a, y: a, f: (a, a) -> a) + return f(x, y) + end + return sum(2, 3, function(a, b) return a + b end) )"); LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end -local a = {1, 2, 3} -local r = map(a, function(a) return a + a > 100 end) + local function map(arr: {a}, f: (a) -> b) + local r = {} + for i,v in ipairs(arr) do + table.insert(r, f(v)) + end + return r + end + local a = {1, 2, 3} + local r = map(a, function(a) return a + a > 100 end) )"); LUAU_REQUIRE_NO_ERRORS(result); REQUIRE_EQ("{boolean}", toString(requireType("r"))); check(R"( -local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end -local a = {1, 2, 3} -local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + local function foldl(arr: {a}, init: b, f: (b, a) -> b) + local r = init + for i,v in ipairs(arr) do + r = f(r, v) + end + return r + end + local a = {1, 2, 3} + local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1065,25 +1080,19 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") { CheckResult result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) -local g12: typeof(g1) & typeof(g2) - -g12(1, function(x) return x + x end) -g12(1, 2, function(x, y) return x + y end) + g12(1, function(x) return x + x end) + g12(1, 2, function(x, y) return x + y end) )"); LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + local g12: ((T, (T) -> T) -> T) & ((T, T, (T, T) -> T) -> T) -local g12: typeof(g1) & typeof(g2) - -g12({x=1}, function(x) return {x=-x.x} end) -g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) + g12({x=1}, function(x) return {x=-x.x} end) + g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1121,12 +1130,12 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { CheckResult result = check(R"( -type A = { x: number } -local a: A = { x = 1 } -local b = a -type B = typeof(b) -type X = T -local c: X + type A = { x: number } + local a: A = { x = 1 } + local b = a + type B = typeof(b) + type X = T + local c: X )"); LUAU_REQUIRE_NO_ERRORS(result); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index ac7a653..3675919 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("IntersectionTypes"); TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") @@ -306,7 +308,10 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table '{| x: number, y: number |}'"); + else + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") @@ -314,27 +319,34 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; CheckResult result = check(R"( - type X = { x: (number) -> number } - type Y = { y: (string) -> string } + type X = { x: (number) -> number } + type Y = { y: (string) -> string } - type XY = X & Y + type XY = X & Y - local xy : XY = { - x = function(a: number) return -a end, - y = function(a: string) return a .. "b" end - } - function xy.z(a:number) return a * 10 end - function xy:y(a:number) return a * 10 end - function xy:w(a:number) return a * 10 end + local xy : XY = { + x = function(a: number) return -a end, + y = function(a: string) return a .. "b" end + } + function xy.z(a:number) return a * 10 end + function xy:y(a:number) return a * 10 end + function xy:w(a:number) return a * 10 end )"); LUAU_REQUIRE_ERROR_COUNT(4, result); CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' caused by: Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table '{| x: (number) -> number, y: (string) -> string |}'"); + else + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); - CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); + + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table '{| x: (number) -> number, y: (string) -> string |}'"); + else + CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); } TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") @@ -375,6 +387,8 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") { + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; + CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -393,6 +407,8 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") { + ScopedFastFlag flags[] = {{"LuauLowerBoundsCalculation", false}}; + CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -427,8 +443,8 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") repeat type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) function _(l0):(t0)&(t0) - while nil do - end + while nil do + end end until _(_)(_)._ )"); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 40831bf..5cd3f3b 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -199,16 +199,16 @@ end TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") { CheckResult result = check(R"( ---!nonstrict -local f = {} -function f:foo(a: number, b: number) end + --!nonstrict + local f = {} + function f:foo(a: number, b: number) end -function bar(...) - f.foo(f, 1, ...) -end + function bar(...) + f.foo(f, 1, ...) + end -bar(2) -)"); + bar(2) + )"); LUAU_REQUIRE_NO_ERRORS(result); } diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 6a8a9d9..5f2e240 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -91,7 +91,8 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") const FunctionTypeVar* functionType = get(requireType("add")); std::optional retType = first(functionType->retType); - CHECK_EQ(std::optional(typeChecker.numberType), retType); + REQUIRE(retType.has_value()); + CHECK_EQ(typeChecker.numberType, follow(*retType)); CHECK_EQ(requireType("n"), typeChecker.numberType); CHECK_EQ(requireType("s"), typeChecker.stringType); } diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2e16b21..8c8059d 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -527,6 +528,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_NO_ERRORS(result); } +// FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { ScopedFastFlag sff[]{ @@ -556,10 +558,19 @@ TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type LUAU_REQUIRE_NO_ERRORS(result); - // f and g should have the type () -> () - CHECK_EQ("() -> (a...)", toString(requireType("f"))); - CHECK_EQ("() -> (a...)", toString(requireType("g"))); - CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + if (FFlag::LuauLowerBoundsCalculation) + { + CHECK_EQ("() -> ()", toString(requireType("f"))); + CHECK_EQ("() -> ()", toString(requireType("g"))); + CHECK_EQ("nil", toString(requireType("x"))); + } + else + { + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now + } } TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") @@ -575,6 +586,10 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", false}, + }; + CheckResult result = check(R"( local function f() return end local g = function() return f() end @@ -585,6 +600,10 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", false}, + }; + CheckResult result = check(R"( --!strict local function f(...) return ... end @@ -594,4 +613,108 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } +TEST_CASE_FIXTURE(Fixture, "lower_bounds_calculation_is_too_permissive_with_overloaded_higher_order_functions") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + function foo(f) + f(5, 'a') + f('b', 6) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // We incorrectly infer that the argument to foo could be called with (number, number) or (string, string) + // even though that is strictly more permissive than the actual source text shows. + CHECK("((number | string, number | string) -> (a...)) -> ()" == toString(requireType("foo"))); +} + +// Once fixed, move this to Normalize.test.cpp +TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_tables") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. + // This exposes a bug where the type of y is mutated. + CheckResult result = check(R"( + function strange(x, y) + x.x = y + y.x = x + + type R = {x: typeof(x)} & {x: typeof(y)} + local r: R + + return r + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(nullptr != get(result.errors[0])); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "pcall_returns_at_least_two_value_but_function_returns_nothing") +{ + CheckResult result = check(R"( + local function f(): () end + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_ERRORS(result); + // LUAU_REQUIRE_NO_ERRORS(result); + // CHECK_EQ("boolean", toString(requireType("ok"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "choose_the_right_overload_for_pcall") +{ + CheckResult result = check(R"( + local function f(): number + if math.random() > 0.5 then + return 5 + else + error("something") + end + end + + local ok, res = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); +} + +// Belongs in TypeInfer.builtins.test.cpp. +TEST_CASE_FIXTURE(Fixture, "function_returns_many_things_but_first_of_it_is_forgotten") +{ + CheckResult result = check(R"( + local function f(): (number, string, boolean) + if math.random() > 0.5 then + return 5, "hello", true + else + error("something") + end + end + + local ok, res, s, b = pcall(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("boolean", toString(requireType("ok"))); + CHECK_EQ("number", toString(requireType("res"))); + // CHECK_EQ("any", toString(requireType("res"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("boolean", toString(requireType("b"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index cddeab6..ce22bcb 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" @@ -8,6 +9,7 @@ LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) +LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -48,6 +50,7 @@ struct RefinementClassFixture : Fixture {"Y", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}}, }; + normalize(vec3, arena, *typeChecker.iceHandler); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); @@ -55,17 +58,21 @@ struct RefinementClassFixture : Fixture TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + normalize(isA, arena, *typeChecker.iceHandler); getMutable(inst)->props = { {"Name", Property{typeChecker.stringType}}, {"IsA", Property{isA}}, }; + normalize(inst, arena, *typeChecker.iceHandler); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + normalize(folder, arena, *typeChecker.iceHandler); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); getMutable(part)->props = { {"Position", Property{vec3}}, }; + normalize(part, arena, *typeChecker.iceHandler); typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; @@ -697,7 +704,10 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("{| x: number, y: number |}", toString(requireTypeAtPosition({4, 28}))); + else + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index d39341e..2b01c29 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -5,8 +5,6 @@ #include "doctest.h" #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(BetterDiagnosticCodesInStudio) - using namespace Luau; TEST_SUITE_BEGIN("TypeSingletons"); @@ -261,14 +259,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::BetterDiagnosticCodesInStudio) - { - CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); - } - else - { - CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); - } + CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 0484351..ca1b8de 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -1211,7 +1213,10 @@ TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_c )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(get(result.errors[0])); + if (FFlag::LuauLowerBoundsCalculation) + CHECK(get(result.errors[0])); + else + CHECK(get(result.errors[0])); } // This unit test could be flaky if the fix has regressed. @@ -2922,6 +2927,60 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the LUAU_REQUIRE_NO_ERRORS(result); } +// The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. +TEST_CASE_FIXTURE(Fixture, "dont_leak_free_table_props") +{ + CheckResult result = check(R"( + local function a(state) + print(state.blah) + end + + local function b(state) -- The bug was that we inferred state: {blah: any, gwar: any} + print(state.gwar) + end + + return function() + return function(state) + a(state) + b(state) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ blah: a +}) -> ()", toString(requireType("a"))); + CHECK_EQ("({+ gwar: a +}) -> ()", toString(requireType("b"))); + CHECK_EQ("() -> ({+ blah: a, gwar: b +}) -> ()", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") +{ + ScopedFastFlag sff[] = { + {"LuauLowerBoundsCalculation", true}, + }; + + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + )"); + + CHECK_EQ("(t1) -> {| Byte: (b) -> (a...), PeekByte: (c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}", + toString(requireType("Base64FileReader"))); +} + TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") { ScopedFastFlag sff{"LuauCheckImplicitNumbericKeys", true}; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 660ddcf..6abd96b 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -13,6 +13,7 @@ #include +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) @@ -177,7 +178,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_case") )"); LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); } TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") @@ -293,7 +293,7 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") // In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type // checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. -TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") +TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") { #if defined(LUAU_ENABLE_ASAN) int limit = 250; @@ -302,12 +302,14 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") #else int limit = 600; #endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0}; - CHECK_NOTHROW(check("print('Hello!')")); - CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error); + ScopedFastFlag sff{"LuauTableUseCounterInstead", true}; + ScopedFastInt sfi{"LuauCheckRecursionLimit", limit}; + + CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") @@ -721,9 +723,9 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") local l0 do end while _ do - function _:_() - _ += _(_._(_:n0(xpcall,_))) - end + function _:_() + _ += _(_._(_:n0(xpcall,_))) + end end )"); @@ -978,4 +980,48 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") +{ + ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2); + ScopedFastFlag sff{"LuauRecursionLimitException", true}; + + CheckResult result = check(R"( + function complex() + function _(l0:t0): (any, ()->()) + return 0,_ + end + type t0 = t0 | {} + _(nil) + end + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") +{ + ScopedFastFlag substituteFollowNewTypes{"LuauSubstituteFollowNewTypes", true}; + + CheckResult result = check(R"( + local obj = {} + + function obj:Method() + self.fieldA = function(object) + if object.a then + self.arr[object] = true + elseif object.b then + self.fieldB[object] = object:Connect(function(arg) + self.arr[arg] = nil + end) + end + end + end + + return obj + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 130f33d..f141622 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -27,8 +29,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const auto& [returns, tail] = flatten(takeTwoType->retType); CHECK_EQ(2, returns.size()); - CHECK_EQ(typeChecker.numberType, returns[0]); - CHECK_EQ(typeChecker.numberType, returns[1]); + CHECK_EQ(typeChecker.numberType, follow(returns[0])); + CHECK_EQ(typeChecker.numberType, follow(returns[1])); CHECK(!tail); } @@ -74,9 +76,9 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const auto& [rets, tail] = flatten(takeOneMoreType->retType); REQUIRE_EQ(3, rets.size()); - CHECK_EQ(typeChecker.numberType, rets[0]); - CHECK_EQ(typeChecker.numberType, rets[1]); - CHECK_EQ(typeChecker.numberType, rets[2]); + CHECK_EQ(typeChecker.numberType, follow(rets[0])); + CHECK_EQ(typeChecker.numberType, follow(rets[1])); + CHECK_EQ(typeChecker.numberType, follow(rets[2])); CHECK(!tail); } @@ -91,26 +93,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") LUAU_REQUIRE_NO_ERRORS(result); - const FunctionTypeVar* applyType = get(requireType("apply")); - REQUIRE(applyType != nullptr); - - std::vector applyArgs = flatten(applyType->argTypes).first; - REQUIRE_EQ(3, applyArgs.size()); - - const FunctionTypeVar* fType = get(follow(applyArgs[0])); - REQUIRE(fType != nullptr); - - const FunctionTypeVar* gType = get(follow(applyArgs[1])); - REQUIRE(gType != nullptr); - - std::vector gArgs = flatten(gType->argTypes).first; - REQUIRE_EQ(1, gArgs.size()); - - // function(function(t1, T2...): (t3, T4...), function(t5): (t1, T2...), t5): (t3, T4...) - - REQUIRE_EQ(*gArgs[0], *applyArgs[2]); - REQUIRE_EQ(toString(fType->argTypes), toString(gType->retType)); - REQUIRE_EQ(toString(fType->retType), toString(applyType->retType)); + CHECK_EQ("((b...) -> (c...), (a) -> (b...), a) -> (c...)", toString(requireType("apply"))); } TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") @@ -328,7 +311,10 @@ local c: Packed auto ttvA = get(requireType("a")); REQUIRE(ttvA); CHECK_EQ(toString(requireType("a")), "Packed"); - CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> number |}"); + else + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); REQUIRE(ttvA->instantiatedTypeParams.size() == 1); REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ff207a1..96bdd53 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -6,6 +6,7 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauEqConstraint) using namespace Luau; @@ -254,11 +255,11 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { CheckResult result = check(R"( -local a = {} -function a.foo(x:number, y:number) return x + y end -type A = typeof(a) -local b: A? = a -local c = b.foo(1, 2) + local a = {} + function a.foo(x:number, y:number) return x + y end + type A = typeof(a) + local b: A? = a + local c = b.foo(1, 2) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -356,7 +357,10 @@ a.x = 2 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", toString(result.errors[0])); + else + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") @@ -533,8 +537,13 @@ TEST_CASE_FIXTURE(Fixture, "table_union_write_indirect") LUAU_REQUIRE_ERROR_COUNT(1, result); // NOTE: union normalization will improve this message - CHECK_EQ(toString(result.errors[0]), - R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); + if (FFlag::LuauLowerBoundsCalculation) + CHECK_EQ(toString(result.errors[0]), "Type '(string) -> number' could not be converted into '(number) -> string'\n" + "caused by:\n" + " Argument #1 type is not compatible. Type 'number' could not be converted into 'string'"); + else + CHECK_EQ(toString(result.errors[0]), + R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); } diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index ab9be42..c817645 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -581,4 +581,19 @@ do assert(#arr == 5) end +-- test boundary invariant maintenance when table is filled using SETLIST opcode +do + local arr = {[2]=2,1} + assert(#arr == 2) +end + +-- test boundary invariant maintenance when table is filled using table.move +do + local t1 = {1, 2, 3, 4, 5} + local t2 = {[6] = 6} + + table.move(t1, 1, 5, 1, t2) + assert(#t2 == 6) +end + return"OK" From 25f90eae7d0650aaf1509880d8b26ac3c4cbebdb Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 14 Apr 2022 16:48:36 -0700 Subject: [PATCH 03/19] Fix test in debug --- tests/TypeInfer.provisional.test.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 8c8059d..6b3741f 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -636,6 +636,10 @@ TEST_CASE_FIXTURE(Fixture, "lower_bounds_calculation_is_too_permissive_with_over // Once fixed, move this to Normalize.test.cpp TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_tables") { +#if defined(_DEBUG) || defined(_NOOPT) + ScopedFastInt sfi("LuauNormalizeIterationLimit", 500); +#endif + ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, }; From f2677f697574da4f634ba9c3c322f3a4a6541262 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 21 Apr 2022 14:04:22 -0700 Subject: [PATCH 04/19] Sync to upstream/release/524 --- Analysis/include/Luau/Clone.h | 2 +- Analysis/include/Luau/Frontend.h | 21 +- Analysis/include/Luau/TypeVar.h | 6 +- Analysis/include/Luau/Unifier.h | 1 - Analysis/include/Luau/VisitTypeVar.h | 2 +- Analysis/src/Autocomplete.cpp | 4 +- Analysis/src/Clone.cpp | 42 ++- Analysis/src/Error.cpp | 23 +- Analysis/src/Frontend.cpp | 50 ++-- Analysis/src/Module.cpp | 18 +- Analysis/src/Normalize.cpp | 44 +-- Analysis/src/Substitution.cpp | 15 +- Analysis/src/ToDot.cpp | 2 +- Analysis/src/Transpiler.cpp | 53 ++-- Analysis/src/TypeAttach.cpp | 14 + Analysis/src/TypeInfer.cpp | 37 ++- Analysis/src/TypeVar.cpp | 6 + Analysis/src/Unifier.cpp | 110 ++------ Ast/include/Luau/StringUtils.h | 1 + Ast/src/Parser.cpp | 8 + Ast/src/StringUtils.cpp | 2 +- CLI/Repl.cpp | 72 ++++- CMakeLists.txt | 6 + Compiler/include/Luau/BytecodeBuilder.h | 7 + Compiler/src/BytecodeBuilder.cpp | 70 ++++- Compiler/src/Compiler.cpp | 127 ++++++++- Compiler/src/ConstantFolding.cpp | 13 +- Compiler/src/ConstantFolding.h | 3 +- VM/src/lapi.cpp | 24 +- VM/src/lgc.cpp | 26 +- VM/src/lgc.h | 2 +- VM/src/lstate.h | 7 +- bench/bench.py | 13 +- fuzz/proto.cpp | 13 +- tests/Autocomplete.test.cpp | 36 +++ tests/Compiler.test.cpp | 359 +++++++++++++++++++++++- tests/CostModel.test.cpp | 125 +++++++++ tests/JsonEncoder.test.cpp | 332 +++++++++++++++++++++- tests/Linter.test.cpp | 2 +- tests/Module.test.cpp | 24 +- tests/NonstrictMode.test.cpp | 46 ++- tests/Normalize.test.cpp | 44 ++- tests/ToDot.test.cpp | 4 +- tests/Transpiler.test.cpp | 17 ++ tests/TypeInfer.classes.test.cpp | 34 ++- tests/TypeInfer.definitions.test.cpp | 2 - tests/TypeInfer.functions.test.cpp | 30 +- tests/TypeInfer.modules.test.cpp | 4 - tests/TypeInfer.provisional.test.cpp | 18 +- tests/TypeInfer.refinements.test.cpp | 8 +- tests/TypeInfer.tables.test.cpp | 2 - tests/TypeInfer.test.cpp | 18 +- tests/TypeVar.test.cpp | 6 +- 53 files changed, 1600 insertions(+), 355 deletions(-) diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 78aa92c..9b6ffa6 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -18,7 +18,7 @@ struct CloneState SeenTypePacks seenTypePacks; int recursionCount = 0; - bool encounteredFreeType = false; + bool encounteredFreeType = false; // TODO: Remove with LuauLosslessClone. }; TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index e24e433..5912547 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -13,6 +13,7 @@ #include LUAU_FASTFLAG(LuauSeparateTypechecks) +LUAU_FASTFLAG(LuauDirtySourceModule) namespace Luau { @@ -57,19 +58,27 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa struct SourceNode { - bool isDirty(bool forAutocomplete) const + bool hasDirtySourceModule() const + { + LUAU_ASSERT(FFlag::LuauDirtySourceModule); + + return dirtySourceModule; + } + + bool hasDirtyModule(bool forAutocomplete) const { if (FFlag::LuauSeparateTypechecks) - return forAutocomplete ? dirtyAutocomplete : dirty; + return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; else - return dirty; + return dirtyModule; } ModuleName name; std::unordered_set requires; std::vector> requireLocations; - bool dirty = true; - bool dirtyAutocomplete = true; + bool dirtySourceModule = true; + bool dirtyModule = true; + bool dirtyModuleForAutocomplete = true; double autocompleteLimitsMult = 1.0; }; @@ -163,7 +172,7 @@ struct Frontend void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); private: - std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete_DEPRECATED); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index ae7d137..8457675 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -373,15 +373,17 @@ struct ClassTypeVar std::optional metatable; // metaclass? Tags tags; std::shared_ptr userData; + ModuleName definitionModuleName; - ClassTypeVar( - Name name, Props props, std::optional parent, std::optional metatable, Tags tags, std::shared_ptr userData) + ClassTypeVar(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, + std::shared_ptr userData, ModuleName definitionModuleName) : name(name) , props(props) , parent(parent) , metatable(metatable) , tags(tags) , userData(userData) + , definitionModuleName(definitionModuleName) { } }; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 340feb7..418d4ca 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -92,7 +92,6 @@ private: bool canCacheResult(TypeId subTy, TypeId superTy); void cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount); - void cacheResult_DEPRECATED(TypeId subTy, TypeId superTy); public: void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index d11cbd0..045190e 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -52,7 +52,7 @@ inline void unsee(std::unordered_set& seen, const void* tv) inline void unsee(DenseHashSet& seen, const void* tv) { - // When DenseHashSet is used for 'visitOnce', where don't forget visited elements + // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements } template diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index e0e79cb..dec12d0 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteSingletonTypes, false); +LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) static const std::unordered_set kStatementStartingKeywords = { @@ -462,7 +463,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, cls); + autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, + FFlag::LuauFixAutocompleteClassSecurityLevel ? containingClass : cls); } else if (auto tbl = get(ty)) fillProps(tbl->props); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 8e7f7c0..d5bd9da 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -10,6 +10,7 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypecheckOptPass) +LUAU_FASTFLAGVARIABLE(LuauLosslessClone, false) namespace Luau { @@ -87,11 +88,18 @@ struct TypePackCloner void operator()(const Unifiable::Free& t) { - cloneState.encounteredFreeType = true; + if (FFlag::LuauLosslessClone) + { + defaultClone(t); + } + else + { + cloneState.encounteredFreeType = true; - TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); - TypePackId cloned = dest.addTypePack(*err); - seenTypePacks[typePackId] = cloned; + TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; + } } void operator()(const Unifiable::Generic& t) @@ -143,10 +151,18 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { - cloneState.encounteredFreeType = true; - TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); - TypeId cloned = dest.addType(*err); - seenTypes[typeId] = cloned; + if (FFlag::LuauLosslessClone) + { + defaultClone(t); + } + else + { + cloneState.encounteredFreeType = true; + + TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; + } } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -174,7 +190,8 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const ConstrainedTypeVar& t) { - cloneState.encounteredFreeType = true; + if (!FFlag::LuauLosslessClone) + cloneState.encounteredFreeType = true; TypeId res = dest.addType(ConstrainedTypeVar{t.level}); ConstrainedTypeVar* ctv = getMutable(res); @@ -252,7 +269,7 @@ void TypeCloner::operator()(const TableTypeVar& t) for (TypePackId& arg : ttv->instantiatedTypePackParams) arg = clone(arg, dest, cloneState); - if (ttv->state == TableState::Free) + if (!FFlag::LuauLosslessClone && ttv->state == TableState::Free) { cloneState.encounteredFreeType = true; @@ -276,7 +293,7 @@ void TypeCloner::operator()(const MetatableTypeVar& t) void TypeCloner::operator()(const ClassTypeVar& t) { - TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); ClassTypeVar* ctv = getMutable(result); seenTypes[typeId] = result; @@ -361,7 +378,10 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) // Persistent types are not being cloned and we get the original type back which might be read-only if (!res->persistent) + { asMutable(res)->documentationSymbol = typeId->documentationSymbol; + asMutable(res)->normal = typeId->normal; + } } return res; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index cbec0b1..24ed4ac 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -8,8 +8,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false); - static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { std::string s = "expects "; @@ -59,27 +57,20 @@ struct ErrorConverter std::string result; - if (FFlag::LuauTypeMismatchModuleName) + if (givenTypeName == wantedTypeName) { - if (givenTypeName == wantedTypeName) + if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) { - if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) + if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) { - if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) - { - result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + - "' from '" + *wantedDefinitionModule + "'"; - } + result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + + "' from '" + *wantedDefinitionModule + "'"; } } + } - if (result.empty()) - result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; - } - else - { + if (result.empty()) result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; - } if (tm.error) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 8b0b221..34ccdac 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) +LUAU_FASTFLAGVARIABLE(LuauDirtySourceModule, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) namespace Luau @@ -358,7 +359,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.isDirty(frontendOptions.forAutocomplete)) + if (it != sourceNodes.end() && !it->second.hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. if (FFlag::LuauSeparateTypechecks) @@ -402,7 +403,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalerrors.begin(), module->errors.end()); moduleResolver.modules[moduleName] = std::move(module); - sourceNode.dirty = false; + sourceNode.dirtyModule = false; } return checkResult; @@ -618,7 +619,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec // this relies on the fact that markDirty marks reverse-dependencies dirty as well // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need // to be built, *and* can't form a cycle with any nodes we did process. - if (!it->second.isDirty(forAutocomplete)) + if (!it->second.hasDirtyModule(forAutocomplete)) continue; // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization @@ -768,7 +769,7 @@ LintResult Frontend::lint(const SourceModule& module, std::optionalsecond.isDirty(forAutocomplete); + return it == sourceNodes.end() || it->second.hasDirtyModule(forAutocomplete); } /* @@ -810,20 +811,31 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked if (markedDirty) markedDirty->push_back(next); - if (FFlag::LuauSeparateTypechecks) + if (FFlag::LuauDirtySourceModule) { - if (sourceNode.dirty && sourceNode.dirtyAutocomplete) + LUAU_ASSERT(FFlag::LuauSeparateTypechecks); + + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) continue; - sourceNode.dirty = true; - sourceNode.dirtyAutocomplete = true; + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + } + else if (FFlag::LuauSeparateTypechecks) + { + if (sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + continue; + + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; } else { - if (sourceNode.dirty) + if (sourceNode.dirtyModule) continue; - sourceNode.dirty = true; + sourceNode.dirtyModule = true; } if (0 == reverseDeps.count(name)) @@ -851,13 +863,14 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. -std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete) +std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete_DEPRECATED) { LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && !it->second.isDirty(forAutocomplete)) + if (it != sourceNodes.end() && + (FFlag::LuauDirtySourceModule ? !it->second.hasDirtySourceModule() : !it->second.hasDirtyModule(forAutocomplete_DEPRECATED))) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) @@ -901,17 +914,20 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.requires.clear(); sourceNode.requireLocations.clear(); + if (FFlag::LuauDirtySourceModule) + sourceNode.dirtySourceModule = false; + if (FFlag::LuauSeparateTypechecks) { if (it == sourceNodes.end()) { - sourceNode.dirty = true; - sourceNode.dirtyAutocomplete = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; } } else { - sourceNode.dirty = true; + sourceNode.dirtyModule = true; } for (const auto& [moduleName, location] : requireTrace.requires) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index e2e3b43..bafd437 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,8 +14,8 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauLosslessClone) namespace Luau { @@ -182,20 +182,20 @@ bool Module::clonePublicInterface(InternalErrorReporter& ice) } } - if (FFlag::LuauCloneDeclaredGlobals) + for (auto& [name, ty] : declaredGlobals) { - for (auto& [name, ty] : declaredGlobals) - { - ty = clone(ty, interfaceTypes, cloneState); - if (FFlag::LuauLowerBoundsCalculation) - normalize(ty, interfaceTypes, ice); - } + ty = clone(ty, interfaceTypes, cloneState); + if (FFlag::LuauLowerBoundsCalculation) + normalize(ty, interfaceTypes, ice); } freeze(internalTypes); freeze(interfaceTypes); - return cloneState.encounteredFreeType; + if (FFlag::LuauLosslessClone) + return false; // TODO: make function return void. + else + return cloneState.encounteredFreeType; } } // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 40341ac..043526e 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -5,7 +5,6 @@ #include #include "Luau/Clone.h" -#include "Luau/DenseHash.h" #include "Luau/Substitution.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" @@ -254,7 +253,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) } template -static bool areNormal_(const T& t, const DenseHashSet& seen, InternalErrorReporter& ice) +static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) { int count = 0; auto isNormal = [&](TypeId ty) { @@ -262,18 +261,19 @@ static bool areNormal_(const T& t, const DenseHashSet& seen, InternalErro if (count >= FInt::LuauNormalizeIterationLimit) ice.ice("Luau::areNormal hit iteration limit"); - return ty->normal || seen.find(asMutable(ty)); + // The follow is here because a bound type may not be normal, but the bound type is normal. + return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end(); }; return std::all_of(begin(t), end(t), isNormal); } -static bool areNormal(const std::vector& types, const DenseHashSet& seen, InternalErrorReporter& ice) +static bool areNormal(const std::vector& types, const std::unordered_set& seen, InternalErrorReporter& ice) { return areNormal_(types, seen, ice); } -static bool areNormal(TypePackId tp, const DenseHashSet& seen, InternalErrorReporter& ice) +static bool areNormal(TypePackId tp, const std::unordered_set& seen, InternalErrorReporter& ice) { tp = follow(tp); if (get(tp)) @@ -288,7 +288,7 @@ static bool areNormal(TypePackId tp, const DenseHashSet& seen, InternalEr return true; if (auto vtp = get(*tail)) - return vtp->ty->normal || seen.find(asMutable(vtp->ty)); + return vtp->ty->normal || follow(vtp->ty)->normal || seen.find(asMutable(vtp->ty)) != seen.end(); return true; } @@ -335,9 +335,14 @@ struct Normalize return false; } - bool operator()(TypeId ty, const BoundTypeVar& btv) + bool operator()(TypeId ty, const BoundTypeVar& btv, std::unordered_set& seen) { - // It should never be the case that this TypeVar is normal, but is bound to a non-normal type. + // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. + // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. + if (seen.find(asMutable(btv.boundTo)) != seen.end()) + return false; + + // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); asMutable(ty)->normal = btv.boundTo->normal; @@ -365,7 +370,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, DenseHashSet& seen) + bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -391,8 +396,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const FunctionTypeVar& ftv) = delete; - bool operator()(TypeId ty, const FunctionTypeVar& ftv, DenseHashSet& seen) + bool operator()(TypeId ty, const FunctionTypeVar& ftv, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -407,7 +411,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const TableTypeVar& ttv, DenseHashSet& seen) + bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -419,7 +423,7 @@ struct Normalize auto checkNormal = [&](TypeId t) { // if t is on the stack, it is possible that this type is normal. // If t is not normal and it is not on the stack, this type is definitely not normal. - if (!t->normal && !seen.find(asMutable(t))) + if (!t->normal && seen.find(asMutable(t)) == seen.end()) normal = false; }; @@ -449,7 +453,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const MetatableTypeVar& mtv, DenseHashSet& seen) + bool operator()(TypeId ty, const MetatableTypeVar& mtv, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -477,7 +481,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const UnionTypeVar& utvRef, DenseHashSet& seen) + bool operator()(TypeId ty, const UnionTypeVar& utvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -507,7 +511,7 @@ struct Normalize return false; } - bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, DenseHashSet& seen) + bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -775,8 +779,8 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo (void)clone(ty, arena, state); Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; - DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, n, seen); + std::unordered_set seen; + visitTypeVar(ty, n, seen); return {ty, !n.limitExceeded}; } @@ -800,8 +804,8 @@ std::pair normalize(TypePackId tp, TypeArena& arena, InternalE (void)clone(tp, arena, state); Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; - DenseHashSet seen{nullptr}; - visitTypeVarOnce(tp, n, seen); + std::unordered_set seen; + visitTypeVar(tp, n, seen); return {tp, !n.limitExceeded}; } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 8648b21..1b51fa3 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -11,6 +11,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) +LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowPossibleMutations, false) namespace Luau { @@ -106,7 +107,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - if (FFlag::LuauTypecheckOptPass) + if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) LUAU_ASSERT(ty == log->follow(ty)); else ty = log->follow(ty); @@ -127,7 +128,7 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass) + if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) LUAU_ASSERT(tp == log->follow(tp)); else tp = log->follow(tp); @@ -148,7 +149,8 @@ std::pair Tarjan::indexify(TypePackId tp) void Tarjan::visitChild(TypeId ty) { - ty = log->follow(ty); + if (!FFlag::LuauSubstituteFollowPossibleMutations) + ty = log->follow(ty); edgesTy.push_back(ty); edgesTp.push_back(nullptr); @@ -156,7 +158,8 @@ void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypePackId tp) { - tp = log->follow(tp); + if (!FFlag::LuauSubstituteFollowPossibleMutations) + tp = log->follow(tp); edgesTy.push_back(nullptr); edgesTp.push_back(tp); @@ -471,7 +474,7 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - if (FFlag::LuauTypecheckOptPass) + if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) LUAU_ASSERT(ty == log->follow(ty)); else ty = log->follow(ty); @@ -484,7 +487,7 @@ void Substitution::foundDirty(TypeId ty) void Substitution::foundDirty(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass) + if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) LUAU_ASSERT(tp == log->follow(tp)); else tp = log->follow(tp); diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index cb54bfc..9b396c8 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -327,7 +327,7 @@ void StateDot::visitChildren(TypePackId tp, int index) } else if (const VariadicTypePack* vtp = get(tp)) { - formatAppend(result, "VariadicTypePack %d", index); + formatAppend(result, "VariadicTypePack %s%d", vtp->hidden ? "hidden " : "", index); finishNodeLabel(tp); finishNode(); diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 92ed241..1577bd6 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1025,31 +1025,42 @@ struct Printer } else if (const auto& a = typeAnnotation.as()) { - CommaSeparatorInserter comma(writer); + AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as() : nullptr; - writer.symbol("{"); - - for (std::size_t i = 0; i < a->props.size; ++i) + if (a->props.size == 0 && indexType && indexType->name == "number") { - comma(); - advance(a->props.data[i].location.begin); - writer.identifier(a->props.data[i].name.value); - if (a->props.data[i].type) - { - writer.symbol(":"); - visualizeTypeAnnotation(*a->props.data[i].type); - } - } - if (a->indexer) - { - comma(); - writer.symbol("["); - visualizeTypeAnnotation(*a->indexer->indexType); - writer.symbol("]"); - writer.symbol(":"); + writer.symbol("{"); visualizeTypeAnnotation(*a->indexer->resultType); + writer.symbol("}"); + } + else + { + CommaSeparatorInserter comma(writer); + + writer.symbol("{"); + + for (std::size_t i = 0; i < a->props.size; ++i) + { + comma(); + advance(a->props.data[i].location.begin); + writer.identifier(a->props.data[i].name.value); + if (a->props.data[i].type) + { + writer.symbol(":"); + visualizeTypeAnnotation(*a->props.data[i].type); + } + } + if (a->indexer) + { + comma(); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + writer.symbol("]"); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + } + writer.symbol("}"); } - writer.symbol("}"); } else if (auto a = typeAnnotation.as()) { diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index bc8d0d4..0f4534b 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -479,6 +479,20 @@ public: { return visitLocal(al->local); } + + virtual bool visit(AstStatFor* stat) override + { + visitLocal(stat->var); + return true; + } + + virtual bool visit(AstStatForIn* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + visitLocal(stat->vars.data[i]); + return true; + } + virtual bool visit(AstExprFunction* fn) override { // TODO: add generics if the inferred type of the function is generic CLI-39908 diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index af42a4e..6411e2a 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" +#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" @@ -47,7 +48,6 @@ LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. @@ -61,7 +61,9 @@ LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) +LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); +LUAU_FASTFLAG(LuauLosslessClone) namespace Luau { @@ -376,7 +378,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo prepareErrorsForDisplay(currentModule->errors); bool encounteredFreeType = currentModule->clonePublicInterface(*iceHandler); - if (encounteredFreeType) + if (!FFlag::LuauLosslessClone && encounteredFreeType) { reportError(TypeError{module.root->location, GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); @@ -785,7 +787,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; - if (useConstrainedIntersections()) + if (FFlag::LuauReturnTypeInferenceInNonstrict ? FFlag::LuauLowerBoundsCalculation : useConstrainedIntersections()) { unifyLowerBound(retPack, scope->returnType, return_.location); return; @@ -1241,7 +1243,12 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type // in case this function has a differing signature. The signature discrepancy will be caught in checkBlock. if (previouslyDefined) + { + if (FFlag::LuauReturnTypeInferenceInNonstrict && FFlag::LuauLowerBoundsCalculation) + quantify(funScope, ty, exprName->location); + globalBindings[name] = oldBinding; + } else globalBindings[name] = {quantify(funScope, ty, exprName->location), exprName->location}; @@ -1555,7 +1562,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar Name className(declaredClass.name.value); - TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {})); + TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); ClassTypeVar* ctv = getMutable(classTy); TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); @@ -3284,7 +3291,7 @@ std::pair TypeChecker::checkFunctionSignature( TypePackId retPack; if (expr.returnAnnotation) retPack = resolveTypePack(funScope, *expr.returnAnnotation); - else if (isNonstrictMode()) + else if (FFlag::LuauReturnTypeInferenceInNonstrict ? (!FFlag::LuauLowerBoundsCalculation && isNonstrictMode()) : isNonstrictMode()) retPack = anyTypePack; else if (expectedFunctionType) { @@ -5328,19 +5335,9 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (const auto& indexer = table->indexer) tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); - if (FFlag::LuauTypeMismatchModuleName) - { - TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; - ttv.definitionModuleName = currentModuleName; - return addType(std::move(ttv)); - } - else - { - return addType(TableTypeVar{ - props, tableIndexer, scope->level, - TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe - }); - } + TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; + ttv.definitionModuleName = currentModuleName; + return addType(std::move(ttv)); } else if (const auto& func = annotation.as()) { @@ -5602,9 +5599,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; - - if (FFlag::LuauTypeMismatchModuleName) - ttv->definitionModuleName = currentModuleName; + ttv->definitionModuleName = currentModuleName; } return instantiated; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 0fbfdbf..4d42573 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) +LUAU_FASTFLAGVARIABLE(LuauClassDefinitionModuleInError, false) namespace Luau { @@ -304,6 +305,11 @@ std::optional getDefinitionModuleName(TypeId type) if (ftv->definition) return ftv->definition->definitionModuleName; } + else if (auto ctv = get(type); ctv && FFlag::LuauClassDefinitionModuleInError) + { + if (!ctv->definitionModuleName.empty()) + return ctv->definitionModuleName; + } return std::nullopt; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index f9ea58c..9862d7b 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,7 +17,6 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); @@ -28,7 +27,6 @@ LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) -LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAG(LuauTypecheckOptPass) @@ -474,32 +472,21 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(subTy)) return tryUnifyWithAny(superTy, subTy); - bool cacheEnabled; auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before - if (FFlag::LuauUnifierCacheErrors) + bool cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; + + if (cacheEnabled) { - cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; - - if (cacheEnabled) - { - if (cache.contains({subTy, superTy})) - return; - - if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) - { - reportError(TypeError{location, *error}); - return; - } - } - } - else - { - cacheEnabled = !isFunctionCall && !isIntersection; - - if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + if (cache.contains({subTy, superTy})) return; + + if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) + { + reportError(TypeError{location, *error}); + return; + } } // If we have seen this pair of types before, we are currently recursing into cyclic types. @@ -543,12 +530,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) { tryUnifyTables(subTy, superTy, isIntersection); - - if (!FFlag::LuauUnifierCacheErrors) - { - if (cacheEnabled && errors.empty()) - cacheResult_DEPRECATED(subTy, superTy); - } } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. @@ -568,7 +549,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - if (FFlag::LuauUnifierCacheErrors && cacheEnabled) + if (cacheEnabled) cacheResult(subTy, superTy, errorCount); log.popSeen(superTy, subTy); @@ -705,21 +686,10 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { TypeId type = uv->options[i]; - if (FFlag::LuauUnifierCacheErrors) + if (cache.contains({subTy, type})) { - if (cache.contains({subTy, type})) - { - startIndex = i; - break; - } - } - else - { - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) - { - startIndex = i; - break; - } + startIndex = i; + break; } } } @@ -807,21 +777,10 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV { TypeId type = uv->parts[i]; - if (FFlag::LuauUnifierCacheErrors) + if (cache.contains({type, superTy})) { - if (cache.contains({type, superTy})) - { - startIndex = i; - break; - } - } - else - { - if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) - { - startIndex = i; - break; - } + startIndex = i; + break; } } } @@ -896,19 +855,6 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount) } } -void Unifier::cacheResult_DEPRECATED(TypeId subTy, TypeId superTy) -{ - LUAU_ASSERT(!FFlag::LuauUnifierCacheErrors); - - if (!canCacheResult(subTy, superTy)) - return; - - sharedState.cachedUnify.insert({superTy, subTy}); - - if (variance == Invariant) - sharedState.cachedUnify.insert({subTy, superTy}); -} - struct WeirdIter { TypePackId packId; @@ -1650,24 +1596,16 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) Unifier innerState = makeChildUnifier(); - if (FFlag::LuauExtendedIndexerError) - { - innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); - bool reported = !innerState.errors.empty(); + bool reported = !innerState.errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); - innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); - if (!reported) - checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); - } - else - { - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - } + if (!reported) + checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -2225,7 +2163,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) { - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2 || !FFlag::LuauExtendedIndexerError); + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); tryUnify_(subIndexer.indexType, superIndexer.indexType); tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 6ecf060..6ae9e97 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -19,6 +19,7 @@ std::string format(const char* fmt, ...) LUAU_PRINTF_ATTR(1, 2); std::string vformat(const char* fmt, va_list args); void formatAppend(std::string& str, const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); +void vformatAppend(std::string& ret, const char* fmt, va_list args); std::string join(const std::vector& segments, std::string_view delimiter); std::string join(const std::vector& segments, std::string_view delimiter); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index badd3fd..31ff3f7 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -167,6 +167,7 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc Function top; top.vararg = true; + functionStack.reserve(8); functionStack.push_back(top); nameSelf = names.addStatic("self"); @@ -186,6 +187,13 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc // all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode hotcommentHeader = false; + + // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays + localStack.reserve(16); + scratchStat.reserve(16); + scratchExpr.reserve(16); + scratchLocal.reserve(16); + scratchBinding.reserve(16); } bool Parser::blockFollow(const Lexeme& l) diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 9c7fed3..0dc3f3f 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -11,7 +11,7 @@ namespace Luau { -static void vformatAppend(std::string& ret, const char* fmt, va_list args) +void vformatAppend(std::string& ret, const char* fmt, va_list args) { va_list argscopy; va_copy(argscopy, args); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5fd6d34..345cb7a 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -579,7 +579,8 @@ static bool compileFile(const char* name, CompileFormat format) if (format == CompileFormat::Text) { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals); + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | + Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } @@ -636,13 +637,60 @@ static int assertionHandler(const char* expr, const char* file, int line, const return 1; } +static void setLuauFlags(bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = state; + } +} + +static void setFlag(std::string_view name, bool state) +{ + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (name == flag->name) + { + flag->value = state; + return; + } + } + + fprintf(stderr, "Warning: --fflag unrecognized flag '%.*s'.\n\n", int(name.length()), name.data()); +} + +static void applyFlagKeyValue(std::string_view element) +{ + if (size_t separator = element.find('='); separator != std::string_view::npos) + { + std::string_view key = element.substr(0, separator); + std::string_view value = element.substr(separator + 1); + + if (value == "true") + setFlag(key, true); + else if (value == "false") + setFlag(key, false); + else + fprintf(stderr, "Warning: --fflag unrecognized value '%.*s' for flag '%.*s'.\n\n", int(value.length()), value.data(), int(key.length()), + key.data()); + } + else + { + if (element == "true") + setLuauFlags(true); + else if (element == "false") + setLuauFlags(false); + else + setFlag(element, true); + } +} + int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; + setLuauFlags(true); CliMode mode = CliMode::Unknown; CompileFormat compileFormat{}; @@ -727,6 +775,22 @@ int replMain(int argc, char** argv) return 1; #endif } + else if (strncmp(argv[i], "--fflags=", 9) == 0) + { + std::string_view list = argv[i] + 9; + + while (!list.empty()) + { + size_t ending = list.find(","); + + applyFlagKeyValue(list.substr(0, ending)); + + if (ending != std::string_view::npos) + list.remove_prefix(ending + 1); + else + break; + } + } else if (argv[i][0] == '-') { fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); diff --git a/CMakeLists.txt b/CMakeLists.txt index c6ccebc..af03b33 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,6 +73,12 @@ else() list(APPEND LUAU_OPTIONS -Wall) # All warnings endif() +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + # Some gcc versions treat var in `if (type var = val)` as unused + # Some gcc versions treat variables used in constexpr if blocks as unused + list(APPEND LUAU_OPTIONS -Wno-unused) +endif() + # Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere if(LUAU_WERROR) if(MSVC) diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 287bf4e..67b9302 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -3,6 +3,7 @@ #include "Luau/Bytecode.h" #include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" #include @@ -80,6 +81,8 @@ public: void pushDebugUpval(StringRef name); uint32_t getDebugPC() const; + void addDebugRemark(const char* format, ...) LUAU_PRINTF_ATTR(2, 3); + void finalize(); enum DumpFlags @@ -88,6 +91,7 @@ public: Dump_Lines = 1 << 1, Dump_Source = 1 << 2, Dump_Locals = 1 << 3, + Dump_Remarks = 1 << 4, }; void setDumpFlags(uint32_t flags) @@ -228,6 +232,9 @@ private: DenseHashMap stringTable; + DenseHashMap debugRemarks; + std::string debugRemarkBuffer; + BytecodeEncoder* encoder = nullptr; std::string bytecode; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 6944de0..6c6f122 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -181,9 +181,17 @@ BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder) : constantMap({Constant::Type_Nil, ~0ull}) , tableShapeMap(TableShape()) , stringTable({nullptr, 0}) + , debugRemarks(~0u) , encoder(encoder) { LUAU_ASSERT(stringTable.find(StringRef{"", 0}) == nullptr); + + // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays + insns.reserve(32); + lines.reserve(32); + constants.reserve(16); + protos.reserve(16); + functions.reserve(8); } uint32_t BytecodeBuilder::beginFunction(uint8_t numparams, bool isvararg) @@ -219,8 +227,8 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) validate(); #endif - // very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants - func.data.reserve(insns.size() * 7); + // very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants plus overhead + func.data.reserve(32 + insns.size() * 7); writeFunction(func.data, currentFunction); @@ -242,6 +250,9 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) constantMap.clear(); tableShapeMap.clear(); + + debugRemarks.clear(); + debugRemarkBuffer.clear(); } void BytecodeBuilder::setMainFunction(uint32_t fid) @@ -505,9 +516,40 @@ uint32_t BytecodeBuilder::getDebugPC() const return uint32_t(insns.size()); } +void BytecodeBuilder::addDebugRemark(const char* format, ...) +{ + if ((dumpFlags & Dump_Remarks) == 0) + return; + + size_t offset = debugRemarkBuffer.size(); + + va_list args; + va_start(args, format); + vformatAppend(debugRemarkBuffer, format, args); + va_end(args); + + // we null-terminate all remarks to avoid storing remark length + debugRemarkBuffer += '\0'; + + debugRemarks[uint32_t(insns.size())] = uint32_t(offset); +} + void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); + + // preallocate space for bytecode blob + size_t capacity = 16; + + for (auto& p : stringTable) + capacity += p.first.length + 2; + + for (const Function& func : functions) + capacity += func.data.size(); + + bytecode.reserve(capacity); + + // assemble final bytecode blob bytecode = char(LBC_VERSION); writeStringTable(bytecode); @@ -663,6 +705,8 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const void BytecodeBuilder::writeLineInfo(std::string& ss) const { + LUAU_ASSERT(!lines.empty()); + // this function encodes lines inside each span as a 8-bit delta to span baseline // span is always a power of two; depending on the line info input, it may need to be as low as 1 int span = 1 << 24; @@ -693,7 +737,17 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const } // second pass: compute span base - std::vector baseline((lines.size() - 1) / span + 1); + int baselineOne = 0; + std::vector baselineScratch; + int* baseline = &baselineOne; + size_t baselineSize = (lines.size() - 1) / span + 1; + + if (baselineSize > 1) + { + // avoid heap allocation for single-element baseline which is most functions (<256 lines) + baselineScratch.resize(baselineSize); + baseline = baselineScratch.data(); + } for (size_t offset = 0; offset < lines.size(); offset += span) { @@ -725,7 +779,7 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const int lastLine = 0; - for (size_t i = 0; i < baseline.size(); ++i) + for (size_t i = 0; i < baselineSize; ++i) { writeInt(ss, baseline[i] - lastLine); lastLine = baseline[i]; @@ -1695,6 +1749,14 @@ std::string BytecodeBuilder::dumpCurrentFunction() const continue; } + if (dumpFlags & Dump_Remarks) + { + const uint32_t* remark = debugRemarks.find(uint32_t(code - insns.data())); + + if (remark) + formatAppend(result, "REMARK %s\n", debugRemarkBuffer.c_str() + *remark); + } + if (dumpFlags & Dump_Source) { int line = lines[code - insns.data()]; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8ef69e7..810caae 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -8,12 +8,17 @@ #include "Builtins.h" #include "ConstantFolding.h" +#include "CostModel.h" #include "TableShape.h" #include "ValueTracking.h" #include #include #include +#include + +LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) +LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThresholdMaxBoost, 300) namespace Luau { @@ -77,8 +82,12 @@ struct Compiler , globals(AstName()) , variables(nullptr) , constants(nullptr) + , locstants(nullptr) , tableShapes(nullptr) { + // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays + localStack.reserve(16); + upvals.reserve(16); } uint8_t getLocal(AstLocal* local) @@ -209,7 +218,9 @@ struct Compiler Function& f = functions[func]; f.id = fid; - f.upvals = std::move(upvals); + f.upvals = upvals; + + upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes return fid; } @@ -2133,10 +2144,119 @@ struct Compiler pushLocal(stat->vars.data[i], uint8_t(vars + i)); } + int getConstantShort(AstExpr* expr) + { + const Constant* c = constants.find(expr); + + if (c && c->type == Constant::Type_Number) + { + double n = c->valueNumber; + + if (n >= -32767 && n <= 32767 && double(int(n)) == n) + return int(n); + } + + return INT_MIN; + } + + bool canUnrollForBody(AstStatFor* stat) + { + struct CanUnrollVisitor : AstVisitor + { + bool result = true; + + bool visit(AstExpr* node) override + { + // functions may capture loop variable, and our upval handling doesn't handle elided variables (constant) + result = result && !node->is(); + return result; + } + + bool visit(AstStat* node) override + { + // while we can easily unroll nested loops, our cost model doesn't take unrolling into account so this can result in code explosion + // we also avoid continue/break since they introduce control flow across iterations + result = result && !node->is() && !node->is() && !node->is(); + return result; + } + }; + + CanUnrollVisitor canUnroll; + stat->body->visit(&canUnroll); + + return canUnroll.result; + } + + bool tryCompileUnrolledFor(AstStatFor* stat, int thresholdBase, int thresholdMaxBoost) + { + int from = getConstantShort(stat->from); + int to = getConstantShort(stat->to); + int step = stat->step ? getConstantShort(stat->step) : 1; + + // check that limits are reasonably small and trip count can be computed + if (from == INT_MIN || to == INT_MIN || step == INT_MIN || step == 0 || (step < 0 && to > from) || (step > 0 && to < from)) + { + bytecode.addDebugRemark("loop unroll failed: invalid iteration count"); + return false; + } + + if (!canUnrollForBody(stat)) + { + bytecode.addDebugRemark("loop unroll failed: unsupported loop body"); + return false; + } + + int tripCount = (to - from) / step + 1; + + if (tripCount > thresholdBase * thresholdMaxBoost / 100) + { + bytecode.addDebugRemark("loop unroll failed: too many iterations (%d)", tripCount); + return false; + } + + AstLocal* var = stat->var; + uint64_t costModel = modelCost(stat->body, &var, 1); + + // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to unrolling + bool varc = true; + int unrolledCost = computeCost(costModel, &varc, 1) * tripCount; + int baselineCost = (computeCost(costModel, nullptr, 0) + 1) * tripCount; + int unrollProfit = (unrolledCost == 0) ? thresholdMaxBoost : std::min(thresholdMaxBoost, 100 * baselineCost / unrolledCost); + + int threshold = thresholdBase * unrollProfit / 100; + + if (unrolledCost > threshold) + { + bytecode.addDebugRemark( + "loop unroll failed: too expensive (iterations %d, cost %d, profit %.2fx)", tripCount, unrolledCost, double(unrollProfit) / 100); + return false; + } + + bytecode.addDebugRemark("loop unroll succeeded (iterations %d, cost %d, profit %.2fx)", tripCount, unrolledCost, double(unrollProfit) / 100); + + for (int i = from; step > 0 ? i <= to : i >= to; i += step) + { + // we need to re-fold constants in the loop body with the new value; this reuses computed constant values elsewhere in the tree + locstants[var].type = Constant::Type_Number; + locstants[var].valueNumber = i; + + foldConstants(constants, variables, locstants, stat); + + compileStat(stat->body); + } + + return true; + } + void compileStatFor(AstStatFor* stat) { RegScope rs(this); + // Optimization: small loops can be unrolled when it is profitable + if (options.optimizationLevel >= 2 && isConstant(stat->to) && isConstant(stat->from) && (!stat->step || isConstant(stat->step))) + if (tryCompileUnrolledFor(stat, FInt::LuauCompileLoopUnrollThreshold, FInt::LuauCompileLoopUnrollThresholdMaxBoost)) + return; + size_t oldLocals = localStack.size(); size_t oldJumps = loopJumps.size(); @@ -2826,6 +2946,8 @@ struct Compiler : self(self) , functions(functions) { + // preallocate the result; this works around std::vector's inefficient growth policy for small arrays + functions.reserve(16); } bool visit(AstExprFunction* node) override @@ -2979,6 +3101,7 @@ struct Compiler DenseHashMap globals; DenseHashMap variables; DenseHashMap constants; + DenseHashMap locstants; DenseHashMap tableShapes; unsigned int regTop = 0; @@ -3008,7 +3131,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName if (options.optimizationLevel >= 1) { // this pass analyzes constantness of expressions - foldConstants(compiler.constants, compiler.variables, root); + foldConstants(compiler.constants, compiler.variables, compiler.locstants, root); // this pass analyzes table assignments to estimate table shapes for initially empty tables predictTableShapes(compiler.tableShapes, root); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 35ea0bf..7ad91d4 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -191,13 +191,13 @@ struct ConstantVisitor : AstVisitor { DenseHashMap& constants; DenseHashMap& variables; + DenseHashMap& locals; - DenseHashMap locals; - - ConstantVisitor(DenseHashMap& constants, DenseHashMap& variables) + ConstantVisitor( + DenseHashMap& constants, DenseHashMap& variables, DenseHashMap& locals) : constants(constants) , variables(variables) - , locals(nullptr) + , locals(locals) { } @@ -385,9 +385,10 @@ struct ConstantVisitor : AstVisitor } }; -void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root) +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, + DenseHashMap& locals, AstNode* root) { - ConstantVisitor visitor{constants, variables}; + ConstantVisitor visitor{constants, variables, locals}; root->visit(&visitor); } diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h index c0e6353..0a995d7 100644 --- a/Compiler/src/ConstantFolding.h +++ b/Compiler/src/ConstantFolding.h @@ -42,7 +42,8 @@ struct Constant } }; -void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root); +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, + DenseHashMap& locals, AstNode* root); } // namespace Compile } // namespace Luau diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 46b1093..431f7e5 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,6 +14,8 @@ #include +LUAU_FASTFLAG(LuauGcWorkTrackFix) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1050,6 +1052,7 @@ int lua_gc(lua_State* L, int what, int data) { size_t prevthreshold = g->GCthreshold; size_t amount = (cast_to(size_t, data) << 10); + ptrdiff_t oldcredit = g->gcstate == GCSpause ? 0 : g->GCthreshold - g->totalbytes; // temporarily adjust the threshold so that we can perform GC work if (amount <= g->totalbytes) @@ -1069,9 +1072,9 @@ int lua_gc(lua_State* L, int what, int data) while (g->GCthreshold <= g->totalbytes) { - luaC_step(L, false); + size_t stepsize = luaC_step(L, false); - actualwork += g->gcstepsize; + actualwork += FFlag::LuauGcWorkTrackFix ? stepsize : g->gcstepsize; if (g->gcstate == GCSpause) { /* end of cycle? */ @@ -1107,11 +1110,20 @@ int lua_gc(lua_State* L, int what, int data) // if cycle hasn't finished, advance threshold forward for the amount of extra work performed if (g->gcstate != GCSpause) { - // if a new cycle was triggered by explicit step, we ignore old threshold as that shows an incorrect 'credit' of GC work - if (waspaused) - g->GCthreshold = g->totalbytes + actualwork; + if (FFlag::LuauGcWorkTrackFix) + { + // if a new cycle was triggered by explicit step, old 'credit' of GC work is 0 + ptrdiff_t newthreshold = g->totalbytes + actualwork + oldcredit; + g->GCthreshold = newthreshold < 0 ? 0 : newthreshold; + } else - g->GCthreshold = prevthreshold + actualwork; + { + // if a new cycle was triggered by explicit step, we ignore old threshold as that shows an incorrect 'credit' of GC work + if (waspaused) + g->GCthreshold = g->totalbytes + actualwork; + else + g->GCthreshold = prevthreshold + actualwork; + } } break; } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 8fc930d..e7b73fe 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,9 +13,10 @@ #include -#define GC_SWEEPMAX 40 -#define GC_SWEEPCOST 10 -#define GC_SWEEPPAGESTEPCOST 4 +LUAU_FASTFLAGVARIABLE(LuauGcWorkTrackFix, false) +LUAU_FASTFLAGVARIABLE(LuauGcSweepCostFix, false) + +#define GC_SWEEPPAGESTEPCOST (FFlag::LuauGcSweepCostFix ? 16 : 4) #define GC_INTERRUPT(state) \ { \ @@ -64,7 +65,7 @@ static void recordGcStateStep(global_State* g, int startgcstate, double seconds, case GCSpropagate: case GCSpropagateagain: g->gcmetrics.currcycle.marktime += seconds; - g->gcmetrics.currcycle.markrequests += g->gcstepsize; + g->gcmetrics.currcycle.markwork += work; if (assist) g->gcmetrics.currcycle.markassisttime += seconds; @@ -74,7 +75,7 @@ static void recordGcStateStep(global_State* g, int startgcstate, double seconds, break; case GCSsweep: g->gcmetrics.currcycle.sweeptime += seconds; - g->gcmetrics.currcycle.sweeprequests += g->gcstepsize; + g->gcmetrics.currcycle.sweepwork += work; if (assist) g->gcmetrics.currcycle.sweepassisttime += seconds; @@ -87,13 +88,11 @@ static void recordGcStateStep(global_State* g, int startgcstate, double seconds, { g->gcmetrics.stepassisttimeacc += seconds; g->gcmetrics.currcycle.assistwork += work; - g->gcmetrics.currcycle.assistrequests += g->gcstepsize; } else { g->gcmetrics.stepexplicittimeacc += seconds; g->gcmetrics.currcycle.explicitwork += work; - g->gcmetrics.currcycle.explicitrequests += g->gcstepsize; } } @@ -878,11 +877,11 @@ static size_t getheaptrigger(global_State* g, size_t heapgoal) return heaptrigger < int64_t(g->totalbytes) ? g->totalbytes : (heaptrigger > int64_t(heapgoal) ? heapgoal : size_t(heaptrigger)); } -void luaC_step(lua_State* L, bool assist) +size_t luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ + int lim = FFlag::LuauGcWorkTrackFix ? g->gcstepsize * g->gcstepmul / 100 : (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -902,12 +901,13 @@ void luaC_step(lua_State* L, bool assist) int lastgcstate = g->gcstate; size_t work = gcstep(L, lim); - (void)work; #ifdef LUAI_GCMETRICS recordGcStateStep(g, lastgcstate, lua_clock() - lasttimestamp, assist, work); #endif + size_t actualstepsize = work * 100 / g->gcstepmul; + // at the end of the last cycle if (g->gcstate == GCSpause) { @@ -927,14 +927,16 @@ void luaC_step(lua_State* L, bool assist) } else { - g->GCthreshold = g->totalbytes + g->gcstepsize; + g->GCthreshold = g->totalbytes + (FFlag::LuauGcWorkTrackFix ? actualstepsize : g->gcstepsize); // compensate if GC is "behind schedule" (has some debt to pay) - if (g->GCthreshold > debt) + if (FFlag::LuauGcWorkTrackFix ? g->GCthreshold >= debt : g->GCthreshold > debt) g->GCthreshold -= debt; } GC_INTERRUPT(lastgcstate); + + return actualstepsize; } void luaC_fullgc(lua_State* L) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index dcd070b..08d1ff5 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -133,7 +133,7 @@ #define luaC_init(L, o, tt) luaC_initobj(L, cast_to(GCObject*, (o)), tt) LUAI_FUNC void luaC_freeall(lua_State* L); -LUAI_FUNC void luaC_step(lua_State* L, bool assist); +LUAI_FUNC size_t luaC_step(lua_State* L, bool assist); LUAI_FUNC void luaC_fullgc(lua_State* L); LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); LUAI_FUNC void luaC_initupval(lua_State* L, UpVal* uv); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index e7c3737..45d9ba2 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -106,7 +106,7 @@ struct GCCycleMetrics double markassisttime = 0.0; double markmaxexplicittime = 0.0; size_t markexplicitsteps = 0; - size_t markrequests = 0; + size_t markwork = 0; double atomicstarttimestamp = 0.0; size_t atomicstarttotalsizebytes = 0; @@ -122,10 +122,7 @@ struct GCCycleMetrics double sweepassisttime = 0.0; double sweepmaxexplicittime = 0.0; size_t sweepexplicitsteps = 0; - size_t sweeprequests = 0; - - size_t assistrequests = 0; - size_t explicitrequests = 0; + size_t sweepwork = 0; size_t assistwork = 0; size_t explicitwork = 0; diff --git a/bench/bench.py b/bench/bench.py index 39f219f..67fc8cf 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -814,13 +814,12 @@ def run(args, argsubcb): analyzeResult('', mainResult, compareResults) else: - for subdir, dirs, files in os.walk(arguments.folder): - for filename in files: - filepath = subdir + os.sep + filename - - if filename.endswith(".lua"): - if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): - runTest(subdir, filename, filepath) + all_files = [subdir + os.sep + filename for subdir, dirs, files in os.walk(arguments.folder) for filename in files] + for filepath in sorted(all_files): + subdir, filename = os.path.split(filepath) + if filename.endswith(".lua"): + if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): + runTest(subdir, filename, filepath) if arguments.sort and len(plotValueLists) > 1: rearrange(rearrangeSortKeyForComparison) diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 1022831..a48f068 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -103,7 +103,7 @@ int registerTypes(Luau::TypeChecker& env) // Vector3 stub TypeId vector3MetaType = arena.addType(TableTypeVar{}); - TypeId vector3InstanceType = arena.addType(ClassTypeVar{"Vector3", {}, nullopt, vector3MetaType, {}, {}}); + TypeId vector3InstanceType = arena.addType(ClassTypeVar{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); getMutable(vector3InstanceType)->props = { {"X", {env.numberType}}, {"Y", {env.numberType}}, @@ -117,7 +117,7 @@ int registerTypes(Luau::TypeChecker& env) env.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; // Instance stub - TypeId instanceType = arena.addType(ClassTypeVar{"Instance", {}, nullopt, nullopt, {}, {}}); + TypeId instanceType = arena.addType(ClassTypeVar{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(instanceType)->props = { {"Name", {env.stringType}}, }; @@ -125,7 +125,7 @@ int registerTypes(Luau::TypeChecker& env) env.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; // Part stub - TypeId partType = arena.addType(ClassTypeVar{"Part", {}, instanceType, nullopt, {}, {}}); + TypeId partType = arena.addType(ClassTypeVar{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); getMutable(partType)->props = { {"Position", {vector3InstanceType}}, }; @@ -173,7 +173,7 @@ struct FuzzConfigResolver : Luau::ConfigResolver { FuzzConfigResolver() { - defaultConfig.mode = Luau::Mode::Nonstrict; // typecheckTwice option will cover Strict mode + defaultConfig.mode = Luau::Mode::Nonstrict; defaultConfig.enabledLint.warningMask = ~0ull; defaultConfig.parseOptions.captureComments = true; } @@ -275,6 +275,11 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) // lint (note that we need access to types so we need to do this with typeck in scope) if (kFuzzLinter && result.errors.empty()) frontend.lint(name, std::nullopt); + + // Second pass in strict mode (forced by auto-complete) + Luau::FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(name, opts); } catch (std::exception&) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 2e7902f..f66e23e 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3034,4 +3034,40 @@ string:@1 CHECK(ac.entryMap["sub"].wrongIndexType == true); } +TEST_CASE_FIXTURE(ACFixture, "source_module_preservation_and_invalidation") +{ + check(R"( +local a = { x = 2, y = 4 } +a.@1 + )"); + + frontend.clear(); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.check("MainModule", {}); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.markDirty("MainModule", nullptr); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + frontend.check("MainModule", {}); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 83dad72..f3e6069 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -17,11 +17,13 @@ std::string rep(const std::string& s, size_t n); using namespace Luau; -static std::string compileFunction(const char* source, uint32_t id) +static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1) { Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); - Luau::compileOrThrow(bcb, source); + Luau::CompileOptions options; + options.optimizationLevel = optimizationLevel; + Luau::compileOrThrow(bcb, source, options); return bcb.dumpFunction(id); } @@ -2689,6 +2691,27 @@ local 8: reg 3, start pc 34 line 21, end pc 34 line 21 )"); } +TEST_CASE("DebugRemarks") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Remarks); + + uint32_t fid = bcb.beginFunction(0); + + bcb.addDebugRemark("test remark #%d", 42); + bcb.emitABC(LOP_RETURN, 0, 1, 0); + + bcb.endFunction(0, 0); + + bcb.setMainFunction(fid); + bcb.finalize(); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +REMARK test remark #42 +RETURN R0 0 +)"); +} + TEST_CASE("AssignmentConflict") { // assignments are left to right @@ -4076,4 +4099,336 @@ RETURN R1 6 )"); } +TEST_CASE("LoopUnrollBasic") +{ + // forward loops + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,2 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); + + // backward loops + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=2,1,-1 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +RETURN R0 1 +)"); + + // loops with step that doesn't divide to-from + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,4,2 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 3 +SETTABLEN R1 R0 3 +RETURN R0 1 +)"); +} + +TEST_CASE("LoopUnrollUnsupported") +{ + // can't unroll loops with non-constant bounds + CHECK_EQ("\n" + compileFunction(R"( +for i=x,y,z do +end +)", + 0, 2), + R"( +GETIMPORT R2 1 +GETIMPORT R0 3 +GETIMPORT R1 5 +FORNPREP R0 +1 +FORNLOOP R0 -1 +RETURN R0 0 +)"); + + // can't unroll loops with bounds where we can't compute trip count + CHECK_EQ("\n" + compileFunction(R"( +for i=2,1 do +end +)", + 0, 2), + R"( +LOADN R2 2 +LOADN R0 1 +LOADN R1 1 +FORNPREP R0 +1 +FORNLOOP R0 -1 +RETURN R0 0 +)"); + + // can't unroll loops with bounds that might be imprecise (non-integer) + CHECK_EQ("\n" + compileFunction(R"( +for i=1,2,0.1 do +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 2 +LOADK R1 K0 +FORNPREP R0 +1 +FORNLOOP R0 -1 +RETURN R0 0 +)"); + + // can't unroll loops if the bounds are too large, as it might overflow trip count math + CHECK_EQ("\n" + compileFunction(R"( +for i=4294967295,4294967296 do +end +)", + 0, 2), + R"( +LOADK R2 K0 +LOADK R0 K1 +LOADN R1 1 +FORNPREP R0 +1 +FORNLOOP R0 -1 +RETURN R0 0 +)"); + + // can't unroll loops if the body has loop control flow or nested loops + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1 do + for j=1,1 do + if i == 1 then + continue + else + break + end + end +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 1 +LOADN R1 1 +FORNPREP R0 +11 +LOADN R5 1 +LOADN R3 1 +LOADN R4 1 +FORNPREP R3 +6 +JUMPIFNOTEQK R2 K0 +5 +JUMP +2 +JUMP +1 +JUMP +1 +FORNLOOP R3 -6 +FORNLOOP R0 -11 +RETURN R0 0 +)"); + + // can't unroll loops if the body has functions that refer to loop variables + CHECK_EQ("\n" + compileFunction(R"( +for i=1,1 do + local x = function() return i end +end +)", + 1, 2), + R"( +LOADN R2 1 +LOADN R0 1 +LOADN R1 1 +FORNPREP R0 +3 +NEWCLOSURE R3 P0 +CAPTURE VAL R2 +FORNLOOP R0 -3 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopUnrollCost") +{ + ScopedFastInt sfis[] = { + {"LuauCompileLoopUnrollThreshold", 25}, + {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, + }; + + // loops with short body + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,10 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 10 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 3 +SETTABLEN R1 R0 3 +LOADN R1 4 +SETTABLEN R1 R0 4 +LOADN R1 5 +SETTABLEN R1 R0 5 +LOADN R1 6 +SETTABLEN R1 R0 6 +LOADN R1 7 +SETTABLEN R1 R0 7 +LOADN R1 8 +SETTABLEN R1 R0 8 +LOADN R1 9 +SETTABLEN R1 R0 9 +LOADN R1 10 +SETTABLEN R1 R0 10 +RETURN R0 1 +)"); + + // loops with body that's too long + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,100 do + t[i] = i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R3 1 +LOADN R1 100 +LOADN R2 1 +FORNPREP R1 +2 +SETTABLE R3 R0 R3 +FORNLOOP R1 -2 +RETURN R0 1 +)"); + + // loops with body that's long but has a high boost factor due to constant folding + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,30 do + t[i] = i * i * i +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 0 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 8 +SETTABLEN R1 R0 2 +LOADN R1 27 +SETTABLEN R1 R0 3 +LOADN R1 64 +SETTABLEN R1 R0 4 +LOADN R1 125 +SETTABLEN R1 R0 5 +LOADN R1 216 +SETTABLEN R1 R0 6 +LOADN R1 343 +SETTABLEN R1 R0 7 +LOADN R1 512 +SETTABLEN R1 R0 8 +LOADN R1 729 +SETTABLEN R1 R0 9 +LOADN R1 1000 +SETTABLEN R1 R0 10 +LOADN R1 1331 +SETTABLEN R1 R0 11 +LOADN R1 1728 +SETTABLEN R1 R0 12 +LOADN R1 2197 +SETTABLEN R1 R0 13 +LOADN R1 2744 +SETTABLEN R1 R0 14 +LOADN R1 3375 +SETTABLEN R1 R0 15 +LOADN R1 4096 +SETTABLEN R1 R0 16 +LOADN R1 4913 +SETTABLEN R1 R0 17 +LOADN R1 5832 +SETTABLEN R1 R0 18 +LOADN R1 6859 +SETTABLEN R1 R0 19 +LOADN R1 8000 +SETTABLEN R1 R0 20 +LOADN R1 9261 +SETTABLEN R1 R0 21 +LOADN R1 10648 +SETTABLEN R1 R0 22 +LOADN R1 12167 +SETTABLEN R1 R0 23 +LOADN R1 13824 +SETTABLEN R1 R0 24 +LOADN R1 15625 +SETTABLEN R1 R0 25 +LOADN R1 17576 +SETTABLEN R1 R0 26 +LOADN R1 19683 +SETTABLEN R1 R0 27 +LOADN R1 21952 +SETTABLEN R1 R0 28 +LOADN R1 24389 +SETTABLEN R1 R0 29 +LOADN R1 27000 +SETTABLEN R1 R0 30 +RETURN R0 1 +)"); + + // loops with body that's long and doesn't have a high boost factor + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +for i=1,10 do + t[i] = math.abs(math.sin(i)) +end +return t +)", + 0, 2), + R"( +NEWTABLE R0 0 10 +LOADN R3 1 +LOADN R1 10 +LOADN R2 1 +FORNPREP R1 +11 +FASTCALL1 24 R3 +3 +MOVE R6 R3 +GETIMPORT R5 2 +CALL R5 1 -1 +FASTCALL 2 +2 +GETIMPORT R4 4 +CALL R4 -1 1 +SETTABLE R4 R0 R3 +FORNLOOP R1 -11 +RETURN R0 1 +)"); +} + TEST_SUITE_END(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index ec04932..aa5b728 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -98,4 +98,129 @@ end CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); } +TEST_CASE("ImportCall") +{ + uint64_t model = modelFunction(R"( +function test(a) + return Instance.new(a) +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("FastCall") +{ + uint64_t model = modelFunction(R"( +function test(a) + return math.abs(a + 1) +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + // note: we currently don't treat fast calls differently from cost model perspective + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(5, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("ControlFlow") +{ + uint64_t model = modelFunction(R"( +function test(a) + while a < 0 do + a += 1 + end + for i=1,2 do + a += 1 + end + for i in pairs({}) do + a += 1 + if a % 2 == 0 then continue end + end + repeat + a += 1 + if a % 2 == 0 then break end + until a > 10 + return a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(38, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(37, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("Conditional") +{ + uint64_t model = modelFunction(R"( +function test(a) + return if a < 0 then -a else a +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); +} + +TEST_CASE("VarArgs") +{ + uint64_t model = modelFunction(R"( +function test(...) + return select('#', ...) :: number +end +)"); + + CHECK_EQ(8, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("TablesFunctions") +{ + uint64_t model = modelFunction(R"( +function test() + return { 42, op = function() end } +end +)"); + + CHECK_EQ(22, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("CostOverflow") +{ + uint64_t model = modelFunction(R"( +function test() + return {{{{{{{{{{{{{{{}}}}}}}}}}}}}}} +end +)"); + + CHECK_EQ(127, Luau::Compile::computeCost(model, nullptr, 0)); +} + +TEST_CASE("TableAssign") +{ + uint64_t model = modelFunction(R"( +function test(a) + for i=1,#a do + a[i] = i + end +end +)"); + + const bool args1[] = {false}; + const bool args2[] = {true}; + + CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); +} + TEST_SUITE_END(); diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index 1d2ad64..6f1cebc 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -9,6 +9,46 @@ using namespace Luau; +struct JsonEncoderFixture +{ + Allocator allocator; + AstNameTable names{allocator}; + + ParseResult parse(std::string_view src) + { + ParseOptions opts; + opts.allowDeclarationSyntax = true; + return Parser::parse(src.data(), src.size(), names, allocator, opts); + } + + AstStatBlock* expectParse(std::string_view src) + { + ParseResult res = parse(src); + REQUIRE(res.errors.size() == 0); + return res.root; + } + + AstStat* expectParseStatement(std::string_view src) + { + AstStatBlock* root = expectParse(src); + REQUIRE(1 == root->body.size); + return root->body.data[0]; + } + + AstExpr* expectParseExpr(std::string_view src) + { + std::string s = "a = "; + s.append(src); + AstStatBlock* root = expectParse(s); + + AstStatAssign* statAssign = root->body.data[0]->as(); + REQUIRE(statAssign != nullptr); + REQUIRE(statAssign->values.size == 1); + + return statAssign->values.data[0]; + } +}; + TEST_SUITE_BEGIN("JsonEncoderTests"); TEST_CASE("encode_constants") @@ -51,7 +91,7 @@ TEST_CASE("encode_AstStatBlock") toJson(&block)); } -TEST_CASE("encode_tables") +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_tables") { std::string src = R"( local x: { @@ -61,16 +101,294 @@ TEST_CASE("encode_tables") } )"; - Allocator allocator; - AstNameTable names(allocator); - ParseResult parseResult = Parser::parse(src.c_str(), src.length(), names, allocator); - - REQUIRE(parseResult.errors.size() == 0); - std::string json = toJson(parseResult.root); + AstStatBlock* root = expectParse(src); + std::string json = toJson(root); CHECK( json == R"({"type":"AstStatBlock","location":"0,0 - 6,4","body":[{"type":"AstStatLocal","location":"1,8 - 5,9","vars":[{"type":{"type":"AstTypeTable","location":"1,17 - 3,9","props":[{"name":"foo","location":"2,12 - 2,15","type":{"type":"AstTypeReference","location":"2,17 - 2,23","name":"number","parameters":[]}}],"indexer":false},"name":"x","location":"1,14 - 1,15"}],"values":[{"type":"AstExprTable","location":"3,12 - 5,9","items":[{"kind":"record","key":{"type":"AstExprConstantString","location":"4,12 - 4,15","value":"foo"},"value":{"type":"AstExprConstantNumber","location":"4,18 - 4,21","value":123}}]}]}]})"); } +TEST_CASE("encode_AstExprGroup") +{ + AstExprConstantNumber number{Location{}, 5.0}; + AstExprGroup group{Location{}, &number}; + + std::string json = toJson(&group); + + const std::string expected = R"({"type":"AstExprGroup","location":"0,0 - 0,0","expr":{"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":5}})"; + + CHECK(json == expected); +} + +TEST_CASE("encode_AstExprGlobal") +{ + AstExprGlobal global{Location{}, AstName{"print"}}; + + std::string json = toJson(&global); + std::string expected = R"({"type":"AstExprGlobal","location":"0,0 - 0,0","global":"print"})"; + + CHECK(json == expected); +} + +TEST_CASE("encode_AstExprLocal") +{ + AstLocal local{AstName{"foo"}, Location{}, nullptr, 0, 0, nullptr}; + AstExprLocal exprLocal{Location{}, &local, false}; + + CHECK(toJson(&exprLocal) == R"({"type":"AstExprLocal","location":"0,0 - 0,0","local":{"type":null,"name":"foo","location":"0,0 - 0,0"}})"); +} + +TEST_CASE("encode_AstExprVarargs") +{ + AstExprVarargs varargs{Location{}}; + + CHECK(toJson(&varargs) == R"({"type":"AstExprVarargs","location":"0,0 - 0,0"})"); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprCall") +{ + AstExpr* expr = expectParseExpr("foo(1, 2, 3)"); + std::string_view expected = R"({"type":"AstExprCall","location":"0,4 - 0,16","func":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"args":[{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},{"type":"AstExprConstantNumber","location":"0,11 - 0,12","value":2},{"type":"AstExprConstantNumber","location":"0,14 - 0,15","value":3}],"self":false,"argLocation":"0,8 - 0,16"})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIndexName") +{ + AstExpr* expr = expectParseExpr("foo.bar"); + + std::string_view expected = R"({"type":"AstExprIndexName","location":"0,4 - 0,11","expr":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"index":"bar","indexLocation":"0,8 - 0,11","op":"."})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIndexExpr") +{ + AstExpr* expr = expectParseExpr("foo['bar']"); + + std::string_view expected = R"({"type":"AstExprIndexExpr","location":"0,4 - 0,14","expr":{"type":"AstExprGlobal","location":"0,4 - 0,7","global":"foo"},"index":{"type":"AstExprConstantString","location":"0,8 - 0,13","value":"bar"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprFunction") +{ + AstExpr* expr = expectParseExpr("function (a) return a end"); + + std::string_view expected = R"({"type":"AstExprFunction","location":"0,4 - 0,29","generics":[],"genericPacks":[],"args":[{"type":null,"name":"a","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"type":null,"name":"a","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":"","hasEnd":true})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTable") +{ + AstExpr* expr = expectParseExpr("{true, key=true, [key2]=true}"); + + std::string_view expected = R"({"type":"AstExprTable","location":"0,4 - 0,33","items":[{"kind":"item","value":{"type":"AstExprConstantBool","location":"0,5 - 0,9","value":true}},{"kind":"record","key":{"type":"AstExprConstantString","location":"0,11 - 0,14","value":"key"},"value":{"type":"AstExprConstantBool","location":"0,15 - 0,19","value":true}},{"kind":"general","key":{"type":"AstExprGlobal","location":"0,22 - 0,26","global":"key2"},"value":{"type":"AstExprConstantBool","location":"0,28 - 0,32","value":true}}]})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprUnary") +{ + AstExpr* expr = expectParseExpr("-b"); + + std::string_view expected = R"({"type":"AstExprUnary","location":"0,4 - 0,6","op":"minus","expr":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprBinary") +{ + AstExpr* expr = expectParseExpr("b + c"); + + std::string_view expected = R"({"type":"AstExprBinary","location":"0,4 - 0,9","op":"Add","left":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"right":{"type":"AstExprGlobal","location":"0,8 - 0,9","global":"c"}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprTypeAssertion") +{ + AstExpr* expr = expectParseExpr("b :: any"); + + std::string_view expected = R"({"type":"AstExprTypeAssertion","location":"0,4 - 0,12","expr":{"type":"AstExprGlobal","location":"0,4 - 0,5","global":"b"},"annotation":{"type":"AstTypeReference","location":"0,9 - 0,12","name":"any","parameters":[]}})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprError") +{ + std::string_view src = "a = "; + ParseResult parseResult = Parser::parse(src.data(), src.size(), names, allocator); + + REQUIRE(1 == parseResult.root->body.size); + + AstStatAssign* statAssign = parseResult.root->body.data[0]->as(); + REQUIRE(statAssign != nullptr); + REQUIRE(1 == statAssign->values.size); + + AstExpr* expr = statAssign->values.data[0]; + + std::string_view expected = R"({"type":"AstExprError","location":"0,4 - 0,4","expressions":[],"messageIndex":0})"; + + CHECK(toJson(expr) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatIf") +{ + AstStat* statement = expectParseStatement("if true then else end"); + + std::string_view expected = R"({"type":"AstStatIf","location":"0,0 - 0,21","condition":{"type":"AstExprConstantBool","location":"0,3 - 0,7","value":true},"thenbody":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"elsebody":{"type":"AstStatBlock","location":"0,17 - 0,18","body":[]},"hasThen":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatWhile") +{ + AstStat* statement = expectParseStatement("while true do end"); + + std::string_view expected = R"({"type":"AtStatWhile","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatRepeat") +{ + AstStat* statement = expectParseStatement("repeat until true"); + + std::string_view expected = R"({"type":"AstStatRepeat","location":"0,0 - 0,17","condition":{"type":"AstExprConstantBool","location":"0,13 - 0,17","value":true},"body":{"type":"AstStatBlock","location":"0,6 - 0,7","body":[]},"hasUntil":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatBreak") +{ + AstStat* statement = expectParseStatement("while true do break end"); + + std::string_view expected = R"({"type":"AtStatWhile","location":"0,0 - 0,23","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,20","body":[{"type":"AstStatBreak","location":"0,14 - 0,19"}]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatContinue") +{ + AstStat* statement = expectParseStatement("while true do continue end"); + + std::string_view expected = R"({"type":"AtStatWhile","location":"0,0 - 0,26","condition":{"type":"AstExprConstantBool","location":"0,6 - 0,10","value":true},"body":{"type":"AstStatBlock","location":"0,13 - 0,23","body":[{"type":"AstStatContinue","location":"0,14 - 0,22"}]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatFor") +{ + AstStat* statement = expectParseStatement("for a=0,1 do end"); + + std::string_view expected = R"({"type":"AstStatFor","location":"0,0 - 0,16","var":{"type":null,"name":"a","location":"0,4 - 0,5"},"from":{"type":"AstExprConstantNumber","location":"0,6 - 0,7","value":0},"to":{"type":"AstExprConstantNumber","location":"0,8 - 0,9","value":1},"body":{"type":"AstStatBlock","location":"0,12 - 0,13","body":[]},"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatForIn") +{ + AstStat* statement = expectParseStatement("for a in b do end"); + + std::string_view expected = R"({"type":"AstStatForIn","location":"0,0 - 0,17","vars":[{"type":null,"name":"a","location":"0,4 - 0,5"}],"values":[{"type":"AstExprGlobal","location":"0,9 - 0,10","global":"b"}],"body":{"type":"AstStatBlock","location":"0,13 - 0,14","body":[]},"hasIn":true,"hasDo":true,"hasEnd":true})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatCompoundAssign") +{ + AstStat* statement = expectParseStatement("a += b"); + + std::string_view expected = R"({"type":"AstStatCompoundAssign","location":"0,0 - 0,6","op":"Add","var":{"type":"AstExprGlobal","location":"0,0 - 0,1","global":"a"},"value":{"type":"AstExprGlobal","location":"0,5 - 0,6","global":"b"}})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatLocalFunction") +{ + AstStat* statement = expectParseStatement("local function a(b) return end"); + + std::string_view expected = R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"type":null,"name":"a","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","generics":[],"genericPacks":[],"args":[{"type":null,"name":"b","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a","hasEnd":true}})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatTypeAlias") +{ + AstStat* statement = expectParseStatement("type A = B"); + + std::string_view expected = R"({"type":"AstStatTypeAlias","location":"0,0 - 0,10","name":"A","generics":[],"genericPacks":[],"type":{"type":"AstTypeReference","location":"0,9 - 0,10","name":"B","parameters":[]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction") +{ + AstStat* statement = expectParseStatement("declare function foo(x: number): string"); + + std::string_view expected = R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","params":{"types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","parameters":[]}]},"retTypes":{"types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","parameters":[]}]},"generics":[],"genericPacks":[]})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") +{ + AstStatBlock* root = expectParse(R"( + declare class Foo + prop: number + function method(self, foo: number): string + end + + declare class Bar extends Foo + prop2: string + end + )"); + + REQUIRE(2 == root->body.size); + + std::string_view expected1 = R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","type":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","parameters":[]}},{"name":"method","type":{"type":"AstTypeFunction","location":"3,21 - 4,11","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","parameters":[]}]},"returnTypes":{"types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","parameters":[]}]}}}]})"; + CHECK(toJson(root->body.data[0]) == expected1); + + std::string_view expected2 = R"({"type":"AstStatDeclareClass","location":"6,22 - 8,11","name":"Bar","superName":"Foo","props":[{"name":"prop2","type":{"type":"AstTypeReference","location":"7,19 - 7,25","name":"string","parameters":[]}}]})"; + CHECK(toJson(root->body.data[1]) == expected2); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") +{ + AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); + + std::string_view expected = R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,35","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","parameters":[]}]},"returnTypes":{"types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","parameters":[]}]},"returnTypes":{"types":[]}}]},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypeError") +{ + ParseResult parseResult = parse("type T = "); + REQUIRE(1 == parseResult.root->body.size); + + AstStat* statement = parseResult.root->body.data[0]; + + std::string_view expected = R"({"type":"AstStatTypeAlias","location":"0,0 - 0,9","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeError","location":"0,8 - 0,9","types":[],"messageIndex":0},"exported":false})"; + + CHECK(toJson(statement) == expected); +} + +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypePackExplicit") +{ + AstStatBlock* root = expectParse(R"( + type A = () -> T... + local a: A<(number, string)> + )"); + + CHECK(2 == root->body.size); + + std::string_view expected = R"({"type":"AstStatLocal","location":"2,8 - 2,36","vars":[{"type":{"type":"AstTypeReference","location":"2,17 - 2,36","name":"A","parameters":[{"type":"AstTypePackExplicit","location":"2,19 - 2,20","typeList":{"types":[{"type":"AstTypeReference","location":"2,20 - 2,26","name":"number","parameters":[]},{"type":"AstTypeReference","location":"2,28 - 2,34","name":"string","parameters":[]}]}}]},"name":"a","location":"2,14 - 2,15"}],"values":[]})"; + + CHECK(toJson(root->body.data[1]) == expected); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 05ee9a7..6649cb7 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1436,7 +1436,7 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") { unfreeze(typeChecker.globalTypes); - TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}}); + TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 738893d..af7d76d 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -173,13 +173,13 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") { {"__add", {typeChecker.anyType}}, }, - std::nullopt, std::nullopt, {}, {}}}; + std::nullopt, std::nullopt, {}, {}, "Test"}}; TypeVar exampleClass{ClassTypeVar{"ExampleClass", { {"PropOne", {typeChecker.numberType}}, {"PropTwo", {typeChecker.stringType}}, }, - std::nullopt, &exampleMetaClass, {}, {}}}; + std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; TypeArena dest; CloneState cloneState; @@ -196,9 +196,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") CHECK_EQ("ExampleClassMeta", metatable->name); } -TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") +TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + ScopedFastFlag sff[]{ + {"LuauErrorRecoveryType", true}, + {"LuauLosslessClone", true}, + }; TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -207,17 +210,17 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") CloneState cloneState; TypeId clonedTy = clone(&freeTy, dest, cloneState); - CHECK_EQ("any", toString(clonedTy)); - CHECK(cloneState.encounteredFreeType); + CHECK(get(clonedTy)); cloneState = {}; TypePackId clonedTp = clone(&freeTp, dest, cloneState); - CHECK_EQ("...any", toString(clonedTp)); - CHECK(cloneState.encounteredFreeType); + CHECK(get(clonedTp)); } -TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") +TEST_CASE_FIXTURE(Fixture, "clone_free_tables") { + ScopedFastFlag sff{"LuauLosslessClone", true}; + TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->state = TableState::Free; @@ -227,8 +230,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") TypeId cloned = clone(&tableTy, dest, cloneState); const TableTypeVar* clonedTtv = get(cloned); - CHECK_EQ(clonedTtv->state, TableState::Sealed); - CHECK(cloneState.encounteredFreeType); + CHECK_EQ(clonedTtv->state, TableState::Free); } TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index a8a12b6..9748eb2 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,6 +13,29 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); +TEST_CASE_FIXTURE(Fixture, "function_returns_number_or_string") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true} + }; + + CheckResult result = check(R"( + --!nonstrict + local function f() + if math.random() > 0.5 then + return 5 + else + return "hi" + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("() -> number | string" == toString(requireType("f"))); +} + TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") { CheckResult result = check(R"( @@ -35,8 +58,13 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") REQUIRE_EQ(0, rets.size()); } -TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_could_return") +TEST_CASE_FIXTURE(Fixture, "first_return_type_dictates_number_of_return_types") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict function getMinCardCountForWidth(width) @@ -51,22 +79,18 @@ TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_coul TypeId t = requireType("getMinCardCountForWidth"); REQUIRE(t); - REQUIRE_EQ("(any) -> (...any)", toString(t)); + REQUIRE_EQ("(any) -> number", toString(t)); } -#if 0 -// Maybe we want this? TEST_CASE_FIXTURE(Fixture, "return_annotation_is_still_checked") { CheckResult result = check(R"( + --!nonstrict function foo(x): number return 'hello' end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE_NE(*typeChecker.anyType, *requireType("foo")); } -#endif TEST_CASE_FIXTURE(Fixture, "function_parameters_are_any") { @@ -256,6 +280,12 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + {"LuauSealExports", true}, + }; + CheckResult result = check(R"( --!nonstrict @@ -272,7 +302,7 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); + REQUIRE_EQ("((any) -> string) | {| foo: any |}", toString(getMainModule()->getModuleScope()->returnType)); } TEST_CASE_FIXTURE(Fixture, "returning_insufficient_return_values") diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 5a84201..d3778f6 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -21,7 +21,7 @@ void createSomeClasses(TypeChecker& typeChecker) unfreeze(arena); - TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId parentType = arena.addType(ClassTypeVar{"Parent", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); ClassTypeVar* parentClass = getMutable(parentType); parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; @@ -31,7 +31,7 @@ void createSomeClasses(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "Parent", {parentType}); typeChecker.globalScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; - TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr}); + TypeId childType = arena.addType(ClassTypeVar{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); ClassTypeVar* childClass = getMutable(childType); childClass->props["virtual_method"] = {makeFunction(arena, childType, {}, {})}; @@ -39,7 +39,7 @@ void createSomeClasses(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "Child", {childType}); typeChecker.globalScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; - TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId unrelatedType = arena.addType(ClassTypeVar{"Unrelated", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); addGlobalBinding(typeChecker, "Unrelated", {unrelatedType}); typeChecker.globalScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; @@ -400,7 +400,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_with_table_prop") CHECK_EQ("{| x: {| y: number & string |} |}", toString(requireType("a"))); } -#if 0 TEST_CASE_FIXTURE(NormalizeFixture, "tables") { check(R"( @@ -428,6 +427,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "tables") CHECK(!isSubtype(b, d)); } +#if 0 TEST_CASE_FIXTURE(NormalizeFixture, "table_indexers_are_invariant") { check(R"( @@ -619,6 +619,7 @@ TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauReturnTypeInferenceInNonstrict", true}, }; check(R"( @@ -639,7 +640,7 @@ TEST_CASE_FIXTURE(Fixture, "normalize_module_return_type") end )"); - CHECK_EQ("(any, any) -> (...any)", toString(getMainModule()->getModuleScope()->returnType)); + CHECK_EQ("(any, any) -> (any, any) -> any", toString(getMainModule()->getModuleScope()->returnType)); } TEST_CASE_FIXTURE(Fixture, "return_type_is_not_a_constrained_intersection") @@ -950,6 +951,27 @@ TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "visiting_a_type_twice_is_not_considered_normal") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + --!strict + function f(a, b) + local function g() + if math.random() > 0.5 then + return a() + else + return b + end + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(() -> a, a) -> ()", toString(requireType("f"))); +} + TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") { ScopedFastFlag flags[] = { @@ -964,4 +986,16 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "fuzz_failure_bound_type_is_normal_but_not_its_bounded_to") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type t252 = ((t0)|(any))|(any) + type t0 = t252,t24...> + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index f3fda54..332a4b2 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -21,13 +21,13 @@ struct ToDotClassFixture : Fixture TypeId baseClassMetaType = arena.addType(TableTypeVar{}); - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}}); + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { {"BaseField", {typeChecker.numberType}}, }; typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}}); + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { {"ChildField", {typeChecker.stringType}}, }; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 0c324cd..b02a52b 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -661,4 +661,21 @@ type t4 = false CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_array_types") +{ + std::string code = R"( +type t1 = {number} +type t2 = {[string]: number} + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") +{ + std::string code = "for k:string,v:boolean in next,{}do end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 8e3629e..5a6e403 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -19,13 +19,13 @@ struct ClassFixture : Fixture unfreeze(arena); - TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, {"BaseField", {numberType}}, }; - TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); + TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(baseClassType)->props = { {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, @@ -34,39 +34,39 @@ struct ClassFixture : Fixture typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; addGlobalBinding(typeChecker, "BaseClass", baseClassType, "@test"); - TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}}); + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}}); + TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(childClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; addGlobalBinding(typeChecker, "ChildClass", childClassType, "@test"); - TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}}); + TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(grandChildInstanceType)->props = { {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}}); + TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(grandChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; addGlobalBinding(typeChecker, "GrandChild", childClassType, "@test"); - TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}}); + TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(anotherChildInstanceType)->props = { {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, }; - TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}}); + TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(anotherChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; @@ -75,13 +75,13 @@ struct ClassFixture : Fixture TypeId vector2MetaType = arena.addType(TableTypeVar{}); - TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}}); + TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); getMutable(vector2InstanceType)->props = { {"X", {numberType}}, {"Y", {numberType}}, }; - TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}}); + TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(vector2Type)->props = { {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, }; @@ -468,4 +468,18 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") +{ + ScopedFastFlag luauClassDefinitionModuleInError{"LuauClassDefinitionModuleInError", true}; + + CheckResult result = check(R"( +local i = ChildClass.New() +type ChildClass = { x: number } +local a: ChildClass = i + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'ChildClass' from 'Test' could not be converted into 'ChildClass' from 'MainModule'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 898d890..4545b8d 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -295,8 +295,6 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type TEST_CASE_FIXTURE(Fixture, "single_class_type_identity_in_global_types") { - ScopedFastFlag luauCloneDeclaredGlobals{"LuauCloneDeclaredGlobals", true}; - loadDefinition(R"( declare class Cls end diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 6599368..7cd7bec 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -656,6 +656,11 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict @@ -664,7 +669,7 @@ TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") end return function() - return f():andThen() + return f() end )"); @@ -791,14 +796,18 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict - function Test(a) + function Test(a): ...any return 1, "" end - local tab = {} table.insert(tab, Test(1)); )"); @@ -1616,4 +1625,19 @@ TEST_CASE_FIXTURE(Fixture, "occurs_check_failure_in_function_return_type") CHECK(nullptr != get(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + local function f() return end + local g = function() return f() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index e5eeae3..fa1f519 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -307,8 +307,6 @@ type Rename = typeof(x.x) TEST_CASE_FIXTURE(Fixture, "module_type_conflict") { - ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; - fileResolver.source["game/A"] = R"( export type T = { x: number } return {} @@ -343,8 +341,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "module_type_conflict_instantiated") { - ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; - fileResolver.source["game/A"] = R"( export type Wrap = { x: T } return {} diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 8c8059d..4b5075d 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -584,20 +584,6 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") LUAU_REQUIRE_ERRORS(result); // Should not have any errors. } -TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") -{ - ScopedFastFlag sff[] = { - {"LuauLowerBoundsCalculation", false}, - }; - - CheckResult result = check(R"( - local function f() return end - local g = function() return f() end - )"); - - LUAU_REQUIRE_ERRORS(result); // Should not have any errors. -} - TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { ScopedFastFlag sff[] = { @@ -636,6 +622,10 @@ TEST_CASE_FIXTURE(Fixture, "lower_bounds_calculation_is_too_permissive_with_over // Once fixed, move this to Normalize.test.cpp TEST_CASE_FIXTURE(Fixture, "normalization_fails_on_certain_kinds_of_cyclic_tables") { +#if defined(_DEBUG) || defined(_NOOPT) + ScopedFastInt sfi("LuauNormalizeIterationLimit", 500); +#endif + ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, }; diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index ce22bcb..136ca00 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -44,7 +44,7 @@ struct RefinementClassFixture : Fixture TypeArena& arena = typeChecker.globalTypes; unfreeze(arena); - TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); getMutable(vec3)->props = { {"X", Property{typeChecker.numberType}}, {"Y", Property{typeChecker.numberType}}, @@ -52,7 +52,7 @@ struct RefinementClassFixture : Fixture }; normalize(vec3, arena, *typeChecker.iceHandler); - TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); + TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); @@ -66,9 +66,9 @@ struct RefinementClassFixture : Fixture }; normalize(inst, arena, *typeChecker.iceHandler); - TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); normalize(folder, arena, *typeChecker.iceHandler); - TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); + TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index ca1b8de..2a727bb 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2086,7 +2086,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") { ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; CheckResult result = check(R"( type A = { [number]: string } @@ -2105,7 +2104,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") { ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; CheckResult result = check(R"( type A = { [number]: number } diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6abd96b..a578b1c 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -86,16 +86,21 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nocheck function f(x) - return x + return 5 end -- we get type information even if there's type errors f(1, 2) )"); - CHECK_EQ("(any) -> (...any)", toString(requireType("f"))); + CHECK_EQ("(any) -> number", toString(requireType("f"))); LUAU_REQUIRE_NO_ERRORS(result); } @@ -363,6 +368,11 @@ TEST_CASE_FIXTURE(Fixture, "globals") TEST_CASE_FIXTURE(Fixture, "globals2") { + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + CheckResult result = check(R"( --!nonstrict foo = function() return 1 end @@ -373,9 +383,9 @@ TEST_CASE_FIXTURE(Fixture, "globals2") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("() -> (...any)", toString(tm->wantedType)); + CHECK_EQ("() -> number", toString(tm->wantedType)); CHECK_EQ("string", toString(tm->givenType)); - CHECK_EQ("() -> (...any)", toString(requireType("foo"))); + CHECK_EQ("() -> number", toString(requireType("foo"))); } TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index fd5f4db..d03bb03 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -275,7 +275,7 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); CHECK(Luau::hasTag(&base, "foo")); @@ -283,8 +283,8 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; - TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test"}}; CHECK(!Luau::hasTag(&base, "foo")); CHECK(!Luau::hasTag(&derived, "foo")); From 4d9ac7db1e49a3bfdc96ed623d19fa966aa1920b Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 28 Apr 2022 18:04:52 -0700 Subject: [PATCH 05/19] Sync to upstream/release/525 --- Analysis/src/Frontend.cpp | 2 +- Analysis/src/Linter.cpp | 6 +- Analysis/src/Substitution.cpp | 2 +- Analysis/src/TypeInfer.cpp | 211 ++++++------------ Analysis/src/TypeVar.cpp | 11 +- Analysis/src/Unifier.cpp | 89 +++----- Ast/src/Lexer.cpp | 4 +- CLI/FileUtils.cpp | 4 +- CLI/Repl.cpp | 9 +- Compiler/include/Luau/BytecodeBuilder.h | 2 +- Compiler/src/BytecodeBuilder.cpp | 24 ++- Compiler/src/Compiler.cpp | 8 +- Compiler/src/CostModel.cpp | 2 +- Sources.cmake | 3 +- VM/include/lua.h | 2 +- VM/src/lapi.cpp | 2 +- VM/src/lstate.h | 2 +- VM/src/ltable.cpp | 55 ++--- VM/src/ludata.cpp | 13 +- fuzz/proto.cpp | 27 ++- tests/Autocomplete.test.cpp | 1 - tests/Compiler.test.cpp | 50 +++-- tests/Conformance.test.cpp | 2 +- tests/Frontend.test.cpp | 18 +- tests/Module.test.cpp | 1 - tests/NonstrictMode.test.cpp | 1 - tests/Parser.test.cpp | 3 - tests/RuntimeLimits.test.cpp | 270 ++++++++++++++++++++++++ tests/TypeInfer.aliases.test.cpp | 19 +- tests/TypeInfer.annotations.test.cpp | 4 - tests/TypeInfer.functions.test.cpp | 7 - tests/TypeInfer.generics.test.cpp | 24 +-- tests/TypeInfer.loops.test.cpp | 18 ++ tests/TypeInfer.operators.test.cpp | 4 - tests/TypeInfer.primitives.test.cpp | 2 - tests/TypeInfer.provisional.test.cpp | 236 --------------------- tests/TypeInfer.singletons.test.cpp | 16 -- tests/TypeInfer.tables.test.cpp | 10 - tests/TypeInfer.tryUnify.test.cpp | 2 - tests/TypeVar.test.cpp | 2 - 40 files changed, 527 insertions(+), 641 deletions(-) create mode 100644 tests/RuntimeLimits.test.cpp diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 34ccdac..b8f7836 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -24,7 +24,7 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) LUAU_FASTFLAGVARIABLE(LuauDirtySourceModule, false) -LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 0) +LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) namespace Luau { diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 5608e4b..200b7d1 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -2653,12 +2653,12 @@ static void lintComments(LintContext& context, const std::vector& ho } else { - std::string::size_type space = hc.content.find_first_of(" \t"); + size_t space = hc.content.find_first_of(" \t"); std::string_view first = std::string_view(hc.content).substr(0, space); if (first == "nolint") { - std::string::size_type notspace = hc.content.find_first_not_of(" \t", space); + size_t notspace = hc.content.find_first_not_of(" \t", space); if (space == std::string::npos || notspace == std::string::npos) { @@ -2827,7 +2827,7 @@ uint64_t LintWarning::parseMask(const std::vector& hotcomments) if (hc.content.compare(0, 6, "nolint") != 0) continue; - std::string::size_type name = hc.content.find_first_not_of(" \t", 6); + size_t name = hc.content.find_first_not_of(" \t", 6); // --!nolint disables everything if (name == std::string::npos) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 1b51fa3..30d8574 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,7 +8,7 @@ #include LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowPossibleMutations, false) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6411e2a..ba91ae1 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -22,29 +22,25 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) -LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) +LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false) -LUAU_FASTFLAGVARIABLE(LuauSealExports, false) +LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) -LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) -LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) @@ -54,12 +50,9 @@ LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) -LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) -LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) -LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); @@ -1160,7 +1153,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) } else { - iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); + if (FFlag::LuauInstantiateFollows) + iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); + else + iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); } const FunctionTypeVar* iterFunc = get(iterTy); @@ -1172,7 +1168,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(varTy, var, forin.location); if (!get(iterTy) && !get(iterTy) && !get(iterTy)) - reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); + { + if (FFlag::LuauDoNotRelyOnNextBinding) + reportError(firstValue->location, CannotCallNonFunction{iterTy}); + else + reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); + } return check(loopScope, *forin.body); } @@ -1427,8 +1428,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ftv->forwardedTypeAlias = true; bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - if (FFlag::LuauFixIncorrectLineNumberDuplicateType) - scope->typeAliasLocations[name] = typealias.location; + scope->typeAliasLocations[name] = typealias.location; } } else @@ -2217,7 +2217,7 @@ TypeId TypeChecker::checkExprTable( if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) exprType = anyType; - if (FFlag::LuauPropertiesGetExpectedType && expectedTable) + if (expectedTable) { auto it = expectedTable->props.find(key->value.data); if (it != expectedTable->props.end()) @@ -2309,9 +2309,8 @@ ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprT } } } - else if (FFlag::LuauExpectedTypesOfProperties) - if (const UnionTypeVar* utv = get(follow(*expectedType))) - expectedUnion = utv; + else if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -2334,7 +2333,7 @@ ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprT if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; } - else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion) + else if (expectedUnion) { std::vector expectedResultTypes; for (TypeId expectedOption : expectedUnion) @@ -2713,8 +2712,6 @@ TypeId TypeChecker::checkBinaryOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); } } @@ -2754,7 +2751,7 @@ TypeId TypeChecker::checkBinaryOperation( reportErrors(state.errors); bool hasErrors = !state.errors.empty(); - if (FFlag::LuauErrorRecoveryType && hasErrors) + if (hasErrors) { // If there are unification errors, the return type may still be unknown // so we loosen the argument types to see if that helps. @@ -2768,8 +2765,7 @@ TypeId TypeChecker::checkBinaryOperation( if (state.errors.empty()) state.log.commit(); } - - if (!hasErrors) + else { state.log.commit(); } @@ -3196,16 +3192,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T } else { - if (!ttv) - { - if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) - // This error now gets reported when we check the function body. - reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - - return errorRecoveryType(scope); - } - - if (lhsType->persistent || ttv->state == TableState::Sealed) + if (!ttv || lhsType->persistent || ttv->state == TableState::Sealed) return errorRecoveryType(scope); } @@ -3532,32 +3519,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } // Returns the minimum number of arguments the argument list can accept. -static size_t getMinParameterCount_DEPRECATED(TypePackId tp) -{ - size_t minCount = 0; - size_t optionalCount = 0; - - auto it = begin(tp); - auto endIter = end(tp); - - while (it != endIter) - { - TypeId ty = *it; - if (isOptional(ty)) - ++optionalCount; - else - { - minCount += optionalCount; - optionalCount = 0; - minCount++; - } - - ++it; - } - - return minCount; -} - static size_t getMinParameterCount(TxnLog* log, TypePackId tp) { size_t minCount = 0; @@ -3597,19 +3558,14 @@ void TypeChecker::checkArgumentList( size_t paramIndex = 0; - size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack); - - auto reportCountMismatchError = [&state, &argLocations, minParams, paramPack, argPack]() { + auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack]() { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - size_t mp = minParams; - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - mp = getMinParameterCount(&state.log, paramPack); - - state.reportError(TypeError{location, CountMismatch{mp, std::distance(begin(argPack), end(argPack))}}); + size_t minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); }; while (true) @@ -3707,16 +3663,10 @@ void TypeChecker::checkArgumentList( } // ok else { - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - minParams = getMinParameterCount(&state.log, paramPack); + size_t minParams = getMinParameterCount(&state.log, paramPack); - bool isVariadic = false; - if (FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic) - { - std::optional tail = flatten(paramPack, state.log).second; - if (tail) - isVariadic = Luau::isVariadic(*tail); - } + std::optional tail = flatten(paramPack, state.log).second; + bool isVariadic = tail && Luau::isVariadic(*tail); state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); return; @@ -3863,7 +3813,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = instantiate(scope, functionType, expr.func->location); } - actualFunctionType = follow(actualFunctionType); + if (!FFlag::LuauInstantiateFollows) + actualFunctionType = follow(actualFunctionType); TypePackId retPack; if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) @@ -3930,16 +3881,13 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); - if (FFlag::LuauErrorRecoveryType) - { - const FunctionTypeVar* overload = nullptr; - if (!overloadsThatMatchArgCount.empty()) - overload = get(overloadsThatMatchArgCount[0]); - if (!overload && !overloadsThatDont.empty()) - overload = get(overloadsThatDont[0]); - if (overload) - return {errorRecoveryTypePack(overload->retType)}; - } + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retType)}; return {errorRecoveryTypePack(retPack)}; } @@ -4129,7 +4077,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); - else if (FFlag::LuauErrorRecoveryType) + else overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); @@ -4715,7 +4663,7 @@ bool Anyification::isDirty(TypeId ty) return false; if (const TableTypeVar* ttv = log->getMutable(ty)) - return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); + return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed); else if (log->getMutable(ty)) return true; else if (get(ty)) @@ -4743,12 +4691,9 @@ TypeId Anyification::clean(TypeId ty) TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; - if (FFlag::LuauSealExports) - { - clone.name = ttv->name; - clone.syntheticName = ttv->syntheticName; - clone.tags = ttv->tags; - } + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.tags = ttv->tags; TypeId res = addType(std::move(clone)); asMutable(res)->normal = ty->normal; return res; @@ -4791,9 +4736,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { + if (FFlag::LuauInstantiateFollows) + ty = follow(ty); + if (FFlag::LuauTypecheckOptPass) { - const FunctionTypeVar* ftv = get(follow(ty)); + const FunctionTypeVar* ftv = get(FFlag::LuauInstantiateFollows ? ty : follow(ty)); if (ftv && ftv->hasNoGenerics) return ty; } @@ -5175,8 +5123,6 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); parameterCountErrorReported = true; - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); } } @@ -5294,33 +5240,25 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation reportError( TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - if (FFlag::LuauErrorRecoveryType) - { - // Pad the types out with error recovery types - while (typeParams.size() < tf->typeParams.size()) - typeParams.push_back(errorRecoveryType(scope)); - while (typePackParams.size() < tf->typePackParams.size()) - typePackParams.push_back(errorRecoveryTypePack(scope)); - } - else - return errorRecoveryType(scope); + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); } - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { - return itp == tp.ty; + bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal( + typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; }); - bool sameTps = std::equal( - typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { - return itpp == tpp.tp; - }); - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - if (sameTys && sameTps) - return tf->type; - } + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + if (sameTys && sameTps) + return tf->type; return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); } @@ -5483,7 +5421,7 @@ bool ApplyTypeFunction::isDirty(TypeId ty) return true; else if (const FreeTypeVar* ftv = get(ty)) { - if (FFlag::LuauRecursiveTypeParameterRestriction && ftv->forwardedTypeAlias) + if (ftv->forwardedTypeAlias) encounteredForwardedType = true; return false; } @@ -5562,7 +5500,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, reportError(location, UnificationTooComplex{}); return errorRecoveryType(scope); } - if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) + if (applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); return errorRecoveryType(scope); @@ -5632,7 +5570,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && (!FFlag::LuauGenericFunctionsDontCacheTypeParams || useCache)) + if (useCache) { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) @@ -5667,21 +5605,12 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypePackId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; - if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - g = cached; - } - else - { - g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - } + TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; + if (!cached) + cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); - genericPacks.push_back({g, defaultValue}); - scope->privateTypePackBindings[n] = g; + genericPacks.push_back({cached, defaultValue}); + scope->privateTypePackBindings[n] = cached; } return {generics, genericPacks}; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4d42573..463b465 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,7 +23,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) @@ -775,18 +774,12 @@ TypePackId SingletonTypes::errorRecoveryTypePack() TypeId SingletonTypes::errorRecoveryType(TypeId guess) { - if (FFlag::LuauErrorRecoveryType) - return guess; - else - return &errorType_; + return guess; } TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) { - if (FFlag::LuauErrorRecoveryType) - return guess; - else - return &errorTypePack_; + return guess; } SingletonTypes& getSingletonTypes() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 9862d7b..334806c 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -23,10 +23,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAG(LuauTypecheckOptPass) @@ -1021,7 +1018,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (superTp == subTp) return; - if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) + if (log.haveSeen(superTp, subTp)) return; if (log.getMutable(superTp)) @@ -1265,12 +1262,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal log.pushSeen(superFunction->generics[i], subFunction->generics[i]); } - if (FFlag::LuauTxnLogSeesTypePacks2) + for (size_t i = 0; i < numGenericPacks; i++) { - for (size_t i = 0; i < numGenericPacks; i++) - { - log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - } + log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } CountMismatch::Context context = ctx; @@ -1330,12 +1324,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal ctx = context; - if (FFlag::LuauTxnLogSeesTypePacks2) + for (int i = int(numGenericPacks) - 1; 0 <= i; i--) { - for (int i = int(numGenericPacks) - 1; 0 <= i; i--) - { - log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - } + log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } for (int i = int(numGenerics) - 1; 0 <= i; i--) @@ -1499,20 +1490,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else missingProperties.push_back(name); - if (FFlag::LuauTxnLogCheckForInvalidation) + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; } } @@ -1570,20 +1558,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else extraProperties.push_back(name); - if (FFlag::LuauTxnLogCheckForInvalidation) + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; } } @@ -1630,27 +1615,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } - if (FFlag::LuauTxnLogDontRetryForIndexers) - { - // Changing the indexer can invalidate the table pointers. - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } - else if (FFlag::LuauTxnLogCheckForInvalidation) - { - // Recursive unification can change the txn log, and invalidate the old - // table. If we detect that this has happened, we start over, with the updated - // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); - if (superTable != newSuperTable || subTable != newSubTable) - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } - } + // Changing the indexer can invalidate the table pointers. + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); if (!missingProperties.empty()) { diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 5dd4f04..a1f1d46 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkip, false) - namespace Luau { @@ -361,7 +359,7 @@ const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) while (isSpace(peekch())) consume(); - if (!FFlag::LuauParseLocationIgnoreCommentSkip || updatePrevLocation) + if (updatePrevLocation) prevLocation = lexeme.location; lexeme = readNext(); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index fb6ac37..39a14ec 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -240,7 +240,7 @@ std::optional getParentPath(const std::string& path) return std::nullopt; #endif - std::string::size_type slash = path.find_last_of("\\/", path.size() - 1); + size_t slash = path.find_last_of("\\/", path.size() - 1); if (slash == 0) return "/"; @@ -253,7 +253,7 @@ std::optional getParentPath(const std::string& path) static std::string getExtension(const std::string& path) { - std::string::size_type dot = path.find_last_of(".\\/"); + size_t dot = path.find_last_of(".\\/"); if (dot == std::string::npos || path[dot] != '.') return ""; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 345cb7a..4cb2234 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -34,7 +34,8 @@ enum class CliMode enum class CompileFormat { Text, - Binary + Binary, + Null }; constexpr int MaxTraversalLimit = 50; @@ -594,6 +595,8 @@ static bool compileFile(const char* name, CompileFormat format) case CompileFormat::Binary: fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); break; + case CompileFormat::Null: + break; } return true; @@ -716,6 +719,10 @@ int replMain(int argc, char** argv) { compileFormat = CompileFormat::Text; } + else if (strcmp(argv[1], "--compile=null") == 0) + { + compileFormat = CompileFormat::Null; + } else { fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 67b9302..b00440a 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -232,7 +232,7 @@ private: DenseHashMap stringTable; - DenseHashMap debugRemarks; + std::vector> debugRemarks; std::string debugRemarkBuffer; BytecodeEncoder* encoder = nullptr; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 6c6f122..871a148 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -181,7 +181,6 @@ BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder) : constantMap({Constant::Type_Nil, ~0ull}) , tableShapeMap(TableShape()) , stringTable({nullptr, 0}) - , debugRemarks(~0u) , encoder(encoder) { LUAU_ASSERT(stringTable.find(StringRef{"", 0}) == nullptr); @@ -257,6 +256,8 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) void BytecodeBuilder::setMainFunction(uint32_t fid) { + LUAU_ASSERT(fid < functions.size()); + mainFunction = fid; } @@ -531,7 +532,7 @@ void BytecodeBuilder::addDebugRemark(const char* format, ...) // we null-terminate all remarks to avoid storing remark length debugRemarkBuffer += '\0'; - debugRemarks[uint32_t(insns.size())] = uint32_t(offset); + debugRemarks.emplace_back(uint32_t(insns.size()), uint32_t(offset)); } void BytecodeBuilder::finalize() @@ -1719,6 +1720,7 @@ std::string BytecodeBuilder::dumpCurrentFunction() const const uint32_t* codeEnd = insns.data() + insns.size(); int lastLine = -1; + size_t nextRemark = 0; std::string result; @@ -1741,6 +1743,7 @@ std::string BytecodeBuilder::dumpCurrentFunction() const while (code != codeEnd) { uint8_t op = LUAU_INSN_OP(*code); + uint32_t pc = uint32_t(code - insns.data()); if (op == LOP_PREPVARARGS) { @@ -1751,15 +1754,16 @@ std::string BytecodeBuilder::dumpCurrentFunction() const if (dumpFlags & Dump_Remarks) { - const uint32_t* remark = debugRemarks.find(uint32_t(code - insns.data())); - - if (remark) - formatAppend(result, "REMARK %s\n", debugRemarkBuffer.c_str() + *remark); + while (nextRemark < debugRemarks.size() && debugRemarks[nextRemark].first == pc) + { + formatAppend(result, "REMARK %s\n", debugRemarkBuffer.c_str() + debugRemarks[nextRemark].second); + nextRemark++; + } } if (dumpFlags & Dump_Source) { - int line = lines[code - insns.data()]; + int line = lines[pc]; if (line > 0 && line != lastLine) { @@ -1771,7 +1775,7 @@ std::string BytecodeBuilder::dumpCurrentFunction() const if (dumpFlags & Dump_Lines) { - formatAppend(result, "%d: ", lines[code - insns.data()]); + formatAppend(result, "%d: ", lines[pc]); } code = dumpInstruction(code, result); @@ -1784,11 +1788,11 @@ void BytecodeBuilder::setDumpSource(const std::string& source) { dumpSource.clear(); - std::string::size_type pos = 0; + size_t pos = 0; while (pos != std::string::npos) { - std::string::size_type next = source.find('\n', pos); + size_t next = source.find('\n', pos); if (next == std::string::npos) { diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 810caae..0f17ee0 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -2206,9 +2206,15 @@ struct Compiler return false; } + if (Variable* lv = variables.find(stat->var); lv && lv->written) + { + bytecode.addDebugRemark("loop unroll failed: mutable loop variable"); + return false; + } + int tripCount = (to - from) / step + 1; - if (tripCount > thresholdBase * thresholdMaxBoost / 100) + if (tripCount > thresholdBase) { bytecode.addDebugRemark("loop unroll failed: too many iterations (%d)", tripCount); return false; diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index d8511bd..9afd09f 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -249,7 +249,7 @@ int computeCost(uint64_t model, const bool* varsConst, size_t varCount) return cost; for (size_t i = 0; i < varCount && i < 7; ++i) - cost -= int((model >> (8 * i + 8)) & 0x7f) * varsConst[i]; + cost -= int((model >> (i * 8 + 8)) & 0x7f) * varsConst[i]; return cost; } diff --git a/Sources.cmake b/Sources.cmake index 60e5dfd..f9263b2 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -220,8 +220,8 @@ if(TARGET Luau.UnitTest) tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp - tests/CostModel.test.cpp tests/Config.test.cpp + tests/CostModel.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp tests/JsonEncoder.test.cpp @@ -232,6 +232,7 @@ if(TARGET Luau.UnitTest) tests/Normalize.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp + tests/RuntimeLimits.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp tests/ToDot.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index d08b73e..c3ebadb 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -299,7 +299,7 @@ LUA_API uintptr_t lua_encodepointer(lua_State* L, uintptr_t p); LUA_API double lua_clock(); -LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)); +LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)); LUA_API void lua_clonefunction(lua_State* L, int idx); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 431f7e5..1f3b094 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1323,7 +1323,7 @@ void lua_unref(lua_State* L, int ref) return; } -void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)) +void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)) { api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); L->global->udatagc[tag] = dtor; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 45d9ba2..423514a 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -200,7 +200,7 @@ typedef struct global_State uint64_t rngstate; /* PCG random number generator state */ uint64_t ptrenckey[4]; /* pointer encoding key for display */ - void (*udatagc[LUA_UTAG_LIMIT])(void*); /* for each userdata tag, a gc callback to be called immediately before freeing memory */ + void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); /* for each userdata tag, a gc callback to be called immediately before freeing memory */ lua_Callbacks cb; diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index dc40b6e..3dc3bd1 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -33,7 +33,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false) // max size of both array and hash part is 2^MAXBITS @@ -400,16 +399,9 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) { if (!ttisnil(&t->array[i])) { - if (FFlag::LuauTableRehashRework) - { - TValue ok; - setnvalue(&ok, cast_num(i + 1)); - setobjt2t(L, newkey(L, t, &ok), &t->array[i]); - } - else - { - setobjt2t(L, luaH_setnum(L, t, i + 1), &t->array[i]); - } + TValue ok; + setnvalue(&ok, cast_num(i + 1)); + setobjt2t(L, newkey(L, t, &ok), &t->array[i]); } } /* shrink array */ @@ -418,30 +410,14 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) /* used for the migration check at the end */ TValue* anew = t->array; /* re-insert elements from hash part */ - if (FFlag::LuauTableRehashRework) + for (int i = twoto(oldhsize) - 1; i >= 0; i--) { - for (int i = twoto(oldhsize) - 1; i >= 0; i--) + LuaNode* old = nold + i; + if (!ttisnil(gval(old))) { - LuaNode* old = nold + i; - if (!ttisnil(gval(old))) - { - TValue ok; - getnodekey(L, &ok, old); - setobjt2t(L, arrayornewkey(L, t, &ok), gval(old)); - } - } - } - else - { - for (int i = twoto(oldhsize) - 1; i >= 0; i--) - { - LuaNode* old = nold + i; - if (!ttisnil(gval(old))) - { - TValue ok; - getnodekey(L, &ok, old); - setobjt2t(L, luaH_set(L, t, &ok), gval(old)); - } + TValue ok; + getnodekey(L, &ok, old); + setobjt2t(L, arrayornewkey(L, t, &ok), gval(old)); } } @@ -559,7 +535,7 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) { rehash(L, t, key); /* grow table */ - // after rehash, numeric keys might be located in the new array part, but won't be found in the node part + /* after rehash, numeric keys might be located in the new array part, but won't be found in the node part */ return arrayornewkey(L, t, key); } @@ -571,15 +547,8 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) { /* cannot find a free place? */ rehash(L, t, key); /* grow table */ - if (!FFlag::LuauTableRehashRework) - { - return luaH_set(L, t, key); /* re-insert key into grown table */ - } - else - { - // after rehash, numeric keys might be located in the new array part, but won't be found in the node part - return arrayornewkey(L, t, key); - } + /* after rehash, numeric keys might be located in the new array part, but won't be found in the node part */ + return arrayornewkey(L, t, key); } LUAU_ASSERT(n != dummynode); TValue mk; diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 819d186..2815268 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -22,14 +22,21 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { - void (*dtor)(void*) = nullptr; if (u->tag < LUA_UTAG_LIMIT) + { + void (*dtor)(lua_State*, void*) = nullptr; dtor = L->global->udatagc[u->tag]; + if (dtor) + dtor(L, u->data); + } else if (u->tag == UTAG_IDTOR) + { + void (*dtor)(void*) = nullptr; memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); + if (dtor) + dtor(u->data); + } - if (dtor) - dtor(u->data); luaM_freegco(L, u, sizeudata(u->len), u->memcat, page); } diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index a48f068..22483f9 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -137,6 +137,21 @@ int registerTypes(Luau::TypeChecker& env) return 0; } + +static void setupFrontend(Luau::Frontend& frontend) +{ + registerTypes(frontend.typeChecker); + Luau::freeze(frontend.typeChecker.globalTypes); + + registerTypes(frontend.typeCheckerForAutocomplete); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + + frontend.iceHandler.onInternalError = [](const char* error) { + printf("ICE: %s\n", error); + LUAU_ASSERT(!"ICE"); + }; +} + struct FuzzFileResolver : Luau::FileResolver { std::optional readSource(const Luau::ModuleName& name) override @@ -238,19 +253,11 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) if (kFuzzTypeck) { static FuzzFileResolver fileResolver; - static Luau::NullConfigResolver configResolver; + static FuzzConfigResolver configResolver; static Luau::FrontendOptions options{true, true}; static Luau::Frontend frontend(&fileResolver, &configResolver, options); - static int once = registerTypes(frontend.typeChecker); - (void)once; - static int once2 = (Luau::freeze(frontend.typeChecker.globalTypes), 0); - (void)once2; - - frontend.iceHandler.onInternalError = [](const char* error) { - printf("ICE: %s\n", error); - LUAU_ASSERT(!"ICE"); - }; + static int once = (setupFrontend(frontend), 0); // restart frontend.clear(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index f66e23e..5b70481 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2761,7 +2761,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - ScopedFastFlag luauExpectedTypesOfProperties{"LuauExpectedTypesOfProperties", true}; check(R"( type tag = "cat" | "dog" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index f3e6069..7b4bfc7 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -2698,16 +2698,22 @@ TEST_CASE("DebugRemarks") uint32_t fid = bcb.beginFunction(0); - bcb.addDebugRemark("test remark #%d", 42); + bcb.addDebugRemark("test remark #%d", 1); + bcb.emitABC(LOP_LOADNIL, 0, 0, 0); + bcb.addDebugRemark("test remark #%d", 2); + bcb.addDebugRemark("test remark #%d", 3); bcb.emitABC(LOP_RETURN, 0, 1, 0); - bcb.endFunction(0, 0); + bcb.endFunction(1, 0); bcb.setMainFunction(fid); bcb.finalize(); CHECK_EQ("\n" + bcb.dumpFunction(0), R"( -REMARK test remark #42 +REMARK test remark #1 +LOADNIL R0 +REMARK test remark #2 +REMARK test remark #3 RETURN R0 0 )"); } @@ -4332,7 +4338,7 @@ RETURN R0 1 // loops with body that's long but has a high boost factor due to constant folding CHECK_EQ("\n" + compileFunction(R"( local t = {} -for i=1,30 do +for i=1,25 do t[i] = i * i * i end return t @@ -4390,16 +4396,6 @@ LOADN R1 13824 SETTABLEN R1 R0 24 LOADN R1 15625 SETTABLEN R1 R0 25 -LOADN R1 17576 -SETTABLEN R1 R0 26 -LOADN R1 19683 -SETTABLEN R1 R0 27 -LOADN R1 21952 -SETTABLEN R1 R0 28 -LOADN R1 24389 -SETTABLEN R1 R0 29 -LOADN R1 27000 -SETTABLEN R1 R0 30 RETURN R0 1 )"); @@ -4431,4 +4427,30 @@ RETURN R0 1 )"); } +TEST_CASE("LoopUnrollMutable") +{ + // can't unroll loops that mutate iteration variable + CHECK_EQ("\n" + compileFunction(R"( +for i=1,3 do + i = 3 + print(i) -- should print 3 three times in a row +end +)", + 0, 2), + R"( +LOADN R2 1 +LOADN R0 3 +LOADN R1 1 +FORNPREP R0 +7 +MOVE R3 R2 +LOADN R3 3 +GETIMPORT R4 1 +MOVE R5 R3 +CALL R4 1 0 +FORNLOOP R0 -7 +RETURN R0 0 +)"); +} + + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 0ed7dc4..6f136d3 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1056,7 +1056,7 @@ TEST_CASE("UserdataApi") lua_State* L = globalState.get(); // setup dtor for tag 42 (created later) - lua_setuserdatadtor(L, 42, [](void* data) { + lua_setuserdatadtor(L, 42, [](lua_State* l, void* data) { dtorhits += *(int*)data; }); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 9fc0a00..e771b6b 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -975,8 +975,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") TEST_CASE_FIXTURE(FrontendFixture, "imported_table_modification_2") { - ScopedFastFlag sffs("LuauSealExports", true); - frontend.options.retainFullTypeGraphs = false; fileResolver.source["Module/A"] = R"( @@ -1035,4 +1033,20 @@ return false; fix.frontend.check("Module/B"); } +TEST_CASE("check_without_builtin_next") +{ + ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; + + TestFileResolver fileResolver; + TestConfigResolver configResolver; + Frontend frontend(&fileResolver, &configResolver); + + fileResolver.source["Module/A"] = "for k,v in 2 do end"; + fileResolver.source["Module/B"] = "return next"; + + // We don't care about the result. That we haven't crashed is enough. + frontend.check("Module/A"); + frontend.check("Module/B"); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index af7d76d..44cc20a 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -199,7 +199,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_free_types") { ScopedFastFlag sff[]{ - {"LuauErrorRecoveryType", true}, {"LuauLosslessClone", true}, }; diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 9748eb2..feeaf2c 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -283,7 +283,6 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") ScopedFastFlag sff[]{ {"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}, - {"LuauSealExports", true}, }; CheckResult result = check(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index b941103..55eafe3 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1606,8 +1606,6 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_of_functions_unions_and_intersections") TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") { - ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; - AstStatBlock* block = parse(R"( type F = number --comment @@ -1620,7 +1618,6 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") { - ScopedFastFlag luauParseLocationIgnoreCommentSkip{"LuauParseLocationIgnoreCommentSkip", true}; ScopedFastFlag luauParseLocationIgnoreCommentSkipInCapture{"LuauParseLocationIgnoreCommentSkipInCapture", true}; // Same should hold when comments are captured diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp new file mode 100644 index 0000000..0b615e1 --- /dev/null +++ b/tests/RuntimeLimits.test.cpp @@ -0,0 +1,270 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +/* Tests in this source file are meant to be a bellwether to verify that the numeric limits we've set are sufficient for + * most real-world scripts. + * + * If a change breaks a test in this source file, please don't adjust the flag values set in the fixture. Instead, + * consider it a latent performance problem by default. + * + * We should periodically revisit this to retest the limits. + */ + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauLowerBoundsCalculation); + +struct LimitFixture : Fixture +{ +#if defined(_NOOPT) + ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 150}; +#endif + + ScopedFastFlag LuauJustOneCallFrameForHaveSeen{"LuauJustOneCallFrameForHaveSeen", true}; +}; + +template +bool hasError(const CheckResult& result, T* = nullptr) +{ + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& a) { + return nullptr != get(a); + }); + return it != result.errors.end(); +} + +TEST_SUITE_BEGIN("RuntimeLimitTests"); + +TEST_CASE_FIXTURE(LimitFixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) +{ + constexpr const char* src = R"LUA( + --!strict + local TS = _G[script] + local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet + local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit + local Iterator + lazyGet("Iterator", function(c) + Iterator = c + end) + local Option + lazyGet("Option", function(c) + Option = c + end) + local Vec + lazyGet("Vec", function(c) + Vec = c + end) + local Result + do + Result = setmetatable({}, { + __tostring = function() + return "Result" + end, + }) + Result.__index = Result + function Result.new(...) + local self = setmetatable({}, Result) + self:constructor(...) + return self + end + function Result:constructor(okValue, errValue) + self.okValue = okValue + self.errValue = errValue + end + function Result:ok(val) + return Result.new(val, nil) + end + function Result:err(val) + return Result.new(nil, val) + end + function Result:fromCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) + end + function Result:fromVoidCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) + end + Result.fromPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + return TS.TRY_RETURN, { Result:ok(TS.await(p)) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + Result.fromVoidPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + TS.await(p) + return TS.TRY_RETURN, { Result:ok(unit()) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + function Result:isOk() + return self.okValue ~= nil + end + function Result:isErr() + return self.errValue ~= nil + end + function Result:contains(x) + return self.okValue == x + end + function Result:containsErr(x) + return self.errValue == x + end + function Result:okOption() + return Option:wrap(self.okValue) + end + function Result:errOption() + return Option:wrap(self.errValue) + end + function Result:map(func) + return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) + end + function Result:mapOr(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def + end + return _0 + end + function Result:mapOrElse(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def(self.errValue) + end + return _0 + end + function Result:mapErr(func) + return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) + end + Result["and"] = function(self, other) + return self:isErr() and Result:err(self.errValue) or other + end + function Result:andThen(func) + return self:isErr() and Result:err(self.errValue) or func(self.okValue) + end + Result["or"] = function(self, other) + return self:isOk() and Result:ok(self.okValue) or other + end + function Result:orElse(other) + return self:isOk() and Result:ok(self.okValue) or other(self.errValue) + end + function Result:expect(msg) + if self:isOk() then + return self.okValue + else + error(msg) + end + end + function Result:unwrap() + return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) + end + function Result:unwrapOr(def) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = def + end + return _0 + end + function Result:unwrapOrElse(gen) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = gen(self.errValue) + end + return _0 + end + function Result:expectErr(msg) + if self:isErr() then + return self.errValue + else + error(msg) + end + end + function Result:unwrapErr() + return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) + end + function Result:transpose() + return self:isOk() and self.okValue:map(function(some) + return Result:ok(some) + end) or Option:some(Result:err(self.errValue)) + end + function Result:flatten() + return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) + end + function Result:match(ifOk, ifErr) + local _0 + if self:isOk() then + _0 = ifOk(self.okValue) + else + _0 = ifErr(self.errValue) + end + return _0 + end + function Result:asPtr() + local _0 = (self.okValue) + if _0 == nil then + _0 = (self.errValue) + end + return _0 + end + end + local resultMeta = Result + resultMeta.__eq = function(a, b) + return b:match(function(ok) + return a:contains(ok) + end, function(err) + return a:containsErr(err) + end) + end + resultMeta.__tostring = function(result) + return result:match(function(ok) + return "Result.ok(" .. tostring(ok) .. ")" + end, function(err) + return "Result.err(" .. tostring(err) .. ")" + end) + end + return { + Result = Result, + } + )LUA"; + + if (FFlag::LuauLowerBoundsCalculation) + (void)check(src); + else + CHECK_THROWS_AS(check(src), std::exception); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index b2e7605..b0eb31c 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -7,8 +7,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauFixIncorrectLineNumberDuplicateType) - TEST_SUITE_BEGIN("TypeAliases"); TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") @@ -257,11 +255,7 @@ TEST_CASE_FIXTURE(Fixture, "reported_location_is_correct_when_type_alias_are_dup auto dtd = get(result.errors[0]); REQUIRE(dtd); CHECK_EQ(dtd->name, "B"); - - if (FFlag::LuauFixIncorrectLineNumberDuplicateType) - CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); - else - CHECK_EQ(dtd->previousLocation.begin.line + 1, 1); + CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); } TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") @@ -495,8 +489,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- OK because forwarded types are used with their parameters. type Tree = { data: T, children: Forest } @@ -508,8 +500,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- Not OK because forwarded types are used with different types than their parameters. type Forest = {Tree<{T}>} @@ -531,8 +521,6 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( type Tree1 = { data: T, children: {Tree2} } type Tree2 = { data: U, children: {Tree1} } @@ -647,9 +635,6 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni { ScopedFastFlag sff[] = { {"LuauTwoPassAliasDefinitionFix", true}, - - // We also force this flag because it surfaced an unfortunate interaction. - {"LuauErrorRecoveryType", true}, }; CheckResult result = check(R"( @@ -687,8 +672,6 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") { - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - CheckResult result = check(R"( -- this would be an infinite type if we allowed it type Tree = { data: T, children: {Tree<{T}>} } diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index e2971ad..7f1c757 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -221,8 +221,6 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") { - ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( local a = 55 :: string )"); @@ -407,8 +405,6 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( type A = B type B = A diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 7cd7bec..0e07121 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -951,8 +951,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( type Overload = ((string) -> string) & ((number, number) -> number) local abc: Overload @@ -1538,7 +1536,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") { - ScopedFastFlag sff{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) end @@ -1560,8 +1557,6 @@ TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic") { - ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; - ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) return 1 @@ -1587,8 +1582,6 @@ wrapper(test) TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic2") { - ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; - ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; CheckResult result = check(R"( function test(a: number, b: string, ...) return 1 diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 49d31fc..91be2c1 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -11,8 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauFixArgumentCountMismatchAmountWithGenericTypes) - TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -679,8 +677,6 @@ local d: D = c TEST_CASE_FIXTURE(Fixture, "generic_functions_dont_cache_type_parameters") { - ScopedFastFlag sff{"LuauGenericFunctionsDontCacheTypeParams", true}; - CheckResult result = check(R"( -- See https://github.com/Roblox/luau/issues/332 -- This function has a type parameter with the same name as clones, @@ -707,8 +703,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") ScopedFastFlag sffs[] = { {"LuauTableSubtypingVariance2", true}, {"LuauUnsealedTableLiteral", true}, - {"LuauPropertiesGetExpectedType", true}, - {"LuauRecursiveTypeParameterRestriction", true}, }; CheckResult result = check(R"( @@ -733,8 +727,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -753,8 +745,6 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification2") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -773,8 +763,6 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification3") { - ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; - CheckResult result = check(R"( --!strict type Dispatcher = { @@ -805,11 +793,7 @@ wrapper(test) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 1 is specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") @@ -826,11 +810,7 @@ wrapper(test2, 1, "", 3) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 4 are specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_function") diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 30df717..960c6ed 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -78,6 +78,8 @@ TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") { + ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; + CheckResult result = check(R"( local foo = "bar" for i, v in foo do @@ -85,6 +87,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Cannot call non-function string", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") @@ -470,4 +473,19 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") LUAU_REQUIRE_ERROR_COUNT(2, result); } +TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") +{ + ScopedFastFlag luauInstantiateFollows{"LuauInstantiateFollows", true}; + + // Just check that this doesn't assert + check(R"( + --!nonstrict + function _(l0:number) + return _ + end + for _ in _(8) do + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 5f2e240..a2787ca 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -142,8 +142,6 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( --!strict local Vec3 = {} @@ -178,8 +176,6 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( --!strict local Vec3 = {} diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 3ddf981..e1684df 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -85,8 +85,6 @@ TEST_CASE_FIXTURE(Fixture, "string_function_other") TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( local x: number = 9999 function x:y(z: number) diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 4b5075d..2ef7741 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -268,242 +268,6 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doct } } -TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) -{ - ScopedFastInt sffi{"LuauTarjanChildLimit", 400}; - - CheckResult result = check(R"LUA( - --!strict - local TS = _G[script] - local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet - local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit - local Iterator - lazyGet("Iterator", function(c) - Iterator = c - end) - local Option - lazyGet("Option", function(c) - Option = c - end) - local Vec - lazyGet("Vec", function(c) - Vec = c - end) - local Result - do - Result = setmetatable({}, { - __tostring = function() - return "Result" - end, - }) - Result.__index = Result - function Result.new(...) - local self = setmetatable({}, Result) - self:constructor(...) - return self - end - function Result:constructor(okValue, errValue) - self.okValue = okValue - self.errValue = errValue - end - function Result:ok(val) - return Result.new(val, nil) - end - function Result:err(val) - return Result.new(nil, val) - end - function Result:fromCallback(c) - local _0 = c - local _1, _2 = pcall(_0) - local result = _1 and { - success = true, - value = _2, - } or { - success = false, - error = _2, - } - return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) - end - function Result:fromVoidCallback(c) - local _0 = c - local _1, _2 = pcall(_0) - local result = _1 and { - success = true, - value = _2, - } or { - success = false, - error = _2, - } - return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) - end - Result.fromPromise = TS.async(function(self, p) - local _0, _1 = TS.try(function() - return TS.TRY_RETURN, { Result:ok(TS.await(p)) } - end, function(e) - return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } - end) - if _0 then - return unpack(_1) - end - end) - Result.fromVoidPromise = TS.async(function(self, p) - local _0, _1 = TS.try(function() - TS.await(p) - return TS.TRY_RETURN, { Result:ok(unit()) } - end, function(e) - return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } - end) - if _0 then - return unpack(_1) - end - end) - function Result:isOk() - return self.okValue ~= nil - end - function Result:isErr() - return self.errValue ~= nil - end - function Result:contains(x) - return self.okValue == x - end - function Result:containsErr(x) - return self.errValue == x - end - function Result:okOption() - return Option:wrap(self.okValue) - end - function Result:errOption() - return Option:wrap(self.errValue) - end - function Result:map(func) - return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) - end - function Result:mapOr(def, func) - local _0 - if self:isOk() then - _0 = func(self.okValue) - else - _0 = def - end - return _0 - end - function Result:mapOrElse(def, func) - local _0 - if self:isOk() then - _0 = func(self.okValue) - else - _0 = def(self.errValue) - end - return _0 - end - function Result:mapErr(func) - return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) - end - Result["and"] = function(self, other) - return self:isErr() and Result:err(self.errValue) or other - end - function Result:andThen(func) - return self:isErr() and Result:err(self.errValue) or func(self.okValue) - end - Result["or"] = function(self, other) - return self:isOk() and Result:ok(self.okValue) or other - end - function Result:orElse(other) - return self:isOk() and Result:ok(self.okValue) or other(self.errValue) - end - function Result:expect(msg) - if self:isOk() then - return self.okValue - else - error(msg) - end - end - function Result:unwrap() - return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) - end - function Result:unwrapOr(def) - local _0 - if self:isOk() then - _0 = self.okValue - else - _0 = def - end - return _0 - end - function Result:unwrapOrElse(gen) - local _0 - if self:isOk() then - _0 = self.okValue - else - _0 = gen(self.errValue) - end - return _0 - end - function Result:expectErr(msg) - if self:isErr() then - return self.errValue - else - error(msg) - end - end - function Result:unwrapErr() - return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) - end - function Result:transpose() - return self:isOk() and self.okValue:map(function(some) - return Result:ok(some) - end) or Option:some(Result:err(self.errValue)) - end - function Result:flatten() - return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) - end - function Result:match(ifOk, ifErr) - local _0 - if self:isOk() then - _0 = ifOk(self.okValue) - else - _0 = ifErr(self.errValue) - end - return _0 - end - function Result:asPtr() - local _0 = (self.okValue) - if _0 == nil then - _0 = (self.errValue) - end - return _0 - end - end - local resultMeta = Result - resultMeta.__eq = function(a, b) - return b:match(function(ok) - return a:contains(ok) - end, function(err) - return a:containsErr(err) - end) - end - resultMeta.__tostring = function(result) - return result:match(function(ok) - return "Result.ok(" .. tostring(ok) .. ")" - end, function(err) - return "Result.err(" .. tostring(err) .. ")" - end) - end - return { - Result = Result, - } - )LUA"); - - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& a) { - return nullptr != get(a); - }); - if (it == result.errors.end()) - { - dumpErrors(result); - FAIL("Expected a UnificationTooComplex error"); - } -} - // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 2b01c29..8d6682b 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -164,10 +164,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Dog = { tag: "Dog", howls: boolean } type Cat = { tag: "Cat", meows: boolean } @@ -281,10 +277,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Cat = { tag: 'cat', catfood: string } type Dog = { tag: 'dog', dogfood: string } @@ -302,10 +294,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Good = { success: true, result: string } type Bad = { success: false, error: string } @@ -323,10 +311,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") { - ScopedFastFlag sffs[] = { - {"LuauExpectedTypesOfProperties", true}, - }; - CheckResult result = check(R"( type Cat = { tag: 'cat', catfood: string } type Dog = { tag: 'dog', dogfood: string } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2a727bb..5bd522a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2122,8 +2122,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, }; @@ -2143,8 +2141,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") { ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, {"LuauUnsealedTableLiteral", true}, }; @@ -2171,8 +2167,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { ScopedFastFlag sffs[]{ - {"LuauPropertiesGetExpectedType", true}, - {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, }; @@ -2377,8 +2371,6 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf1") { - ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; - CheckResult result = check(R"( -- This example produced a UAF at one point, caused by pointers to table types becoming -- invalidated by child unifiers. (Calling log.concat can cause pointers to become invalid.) @@ -2409,8 +2401,6 @@ end TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf2") { - ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; - CheckResult result = check(R"( -- Another example that UAFd, this time found by fuzzing. local _ diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index c21e162..b6e9326 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -126,8 +126,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") { - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - CheckResult result = check(R"( function f(arg: number) return arg end local a diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index d03bb03..e033fe2 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -184,8 +184,6 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { - ScopedFastFlag sff{"LuauSealExports", true}; - TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; TypePackVar tp24{TypePack{{&ftv11}}}; From 0d6481b9df2472c1bd6267d3f6ea7cc2a0a38cab Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 28 Apr 2022 18:10:31 -0700 Subject: [PATCH 06/19] Fix tests in debug --- tests/RuntimeLimits.test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 0b615e1..d7e2246 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -19,7 +19,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); struct LimitFixture : Fixture { -#if defined(_NOOPT) +#if defined(_NOOPT) || defined(_DEBUG) ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 150}; #endif From 51ae97c211f6818aa59c1b9ab06e97b22817c388 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 28 Apr 2022 18:15:04 -0700 Subject: [PATCH 07/19] We also need to lower the limit --- tests/RuntimeLimits.test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index d7e2246..dcbf0b6 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -20,7 +20,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); struct LimitFixture : Fixture { #if defined(_NOOPT) || defined(_DEBUG) - ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 150}; + ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; #endif ScopedFastFlag LuauJustOneCallFrameForHaveSeen{"LuauJustOneCallFrameForHaveSeen", true}; From bb57bf96035b11a6afdf5bfb5055616aece55203 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 5 May 2022 16:52:48 -0700 Subject: [PATCH 08/19] Sync to upstream/release/526 --- Analysis/include/Luau/Frontend.h | 1 - Analysis/include/Luau/VisitTypeVar.h | 319 +++++++- Analysis/src/Autocomplete.cpp | 54 +- Analysis/src/Frontend.cpp | 34 +- Analysis/src/Normalize.cpp | 329 ++++++-- Analysis/src/Quantify.cpp | 50 +- Analysis/src/ToString.cpp | 50 +- Analysis/src/TxnLog.cpp | 34 +- Analysis/src/TypeInfer.cpp | 125 ++-- Analysis/src/Unifier.cpp | 134 +++- Ast/include/Luau/Ast.h | 2 +- Ast/src/Parser.cpp | 5 +- CLI/Repl.cpp | 5 + Compiler/include/Luau/Bytecode.h | 5 + Compiler/src/BytecodeBuilder.cpp | 10 + Compiler/src/Compiler.cpp | 338 ++++++++- Compiler/src/ConstantFolding.cpp | 53 +- Sources.cmake | 1 + VM/src/lapi.cpp | 2 +- VM/src/lbuiltins.cpp | 4 +- VM/src/lgc.h | 2 +- VM/src/ltable.cpp | 48 +- VM/src/ltm.cpp | 4 +- VM/src/ltm.h | 3 +- VM/src/lvmexecute.cpp | 196 ++++- .../test_LargeTableSum_loop_iter.lua | 17 + bench/tests/sunspider/3d-cube.lua | 30 +- bench/tests/sunspider/3d-morph.lua | 2 +- bench/tests/sunspider/3d-raytrace.lua | 44 +- bench/tests/sunspider/access-binary-trees.lua | 69 -- .../tests/sunspider/controlflow-recursive.lua | 8 +- bench/tests/sunspider/crypto-aes.lua | 148 ++-- bench/tests/sunspider/math-cordic.lua | 10 +- bench/tests/sunspider/math-partial-sums.lua | 2 +- bench/tests/sunspider/math-spectral-norm.lua | 72 -- tests/Autocomplete.test.cpp | 8 - tests/Compiler.test.cpp | 708 +++++++++++++++++- tests/Conformance.test.cpp | 12 + tests/Frontend.test.cpp | 4 - tests/Parser.test.cpp | 4 - tests/RuntimeLimits.test.cpp | 13 +- tests/TypeInfer.loops.test.cpp | 67 ++ tests/TypeInfer.modules.test.cpp | 1 - tests/TypeInfer.tables.test.cpp | 4 +- tests/TypeInfer.test.cpp | 41 + tests/TypeInfer.tryUnify.test.cpp | 4 - tests/TypeVar.test.cpp | 22 +- tests/VisitTypeVar.test.cpp | 48 ++ tests/conformance/iter.lua | 196 +++++ tests/conformance/nextvar.lua | 55 +- tools/lldb_formatters.py | 2 +- 51 files changed, 2670 insertions(+), 729 deletions(-) create mode 100644 bench/micro_tests/test_LargeTableSum_loop_iter.lua delete mode 100644 bench/tests/sunspider/access-binary-trees.lua delete mode 100644 bench/tests/sunspider/math-spectral-norm.lua create mode 100644 tests/VisitTypeVar.test.cpp create mode 100644 tests/conformance/iter.lua diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 5912547..37e3cfd 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -145,7 +145,6 @@ struct Frontend */ std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); - CheckResult check(const SourceModule& module); // OLD. TODO KILL LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 045190e..67fce5e 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -1,9 +1,15 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include + #include "Luau/DenseHash.h" -#include "Luau/TypeVar.h" +#include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" +#include "Luau/TypeVar.h" + +LUAU_FASTFLAG(LuauUseVisitRecursionLimit) +LUAU_FASTINT(LuauVisitRecursionLimit) namespace Luau { @@ -219,24 +225,321 @@ void visit(TypePackId tp, F& f, Set& seen) } // namespace visit_detail +template +struct GenericTypeVarVisitor +{ + using Set = S; + + Set seen; + int recursionCounter = 0; + + GenericTypeVarVisitor() = default; + + explicit GenericTypeVarVisitor(Set seen) + : seen(std::move(seen)) + { + } + + virtual void cycle(TypeId) {} + virtual void cycle(TypePackId) {} + + virtual bool visit(TypeId ty) + { + return true; + } + virtual bool visit(TypeId ty, const BoundTypeVar& btv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const FreeTypeVar& ftv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const GenericTypeVar& gtv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ErrorTypeVar& etv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const FunctionTypeVar& ftv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const MetatableTypeVar& mtv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const ClassTypeVar& ctv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const AnyTypeVar& atv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const UnionTypeVar& utv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const IntersectionTypeVar& itv) + { + return visit(ty); + } + + virtual bool visit(TypePackId tp) + { + return true; + } + virtual bool visit(TypePackId tp, const BoundTypePack& btp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const GenericTypePack& gtp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const Unifiable::Error& etp) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const TypePack& pack) + { + return visit(tp); + } + virtual bool visit(TypePackId tp, const VariadicTypePack& vtp) + { + return visit(tp); + } + + void traverse(TypeId ty) + { + RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit, "TypeVarVisitor"}; + + if (visit_detail::hasSeen(seen, ty)) + { + cycle(ty); + return; + } + + if (auto btv = get(ty)) + { + if (visit(ty, *btv)) + traverse(btv->boundTo); + } + + else if (auto ftv = get(ty)) + visit(ty, *ftv); + + else if (auto gtv = get(ty)) + visit(ty, *gtv); + + else if (auto etv = get(ty)) + visit(ty, *etv); + + else if (auto ctv = get(ty)) + { + if (visit(ty, *ctv)) + { + for (TypeId part : ctv->parts) + traverse(part); + } + } + + else if (auto ptv = get(ty)) + visit(ty, *ptv); + + else if (auto ftv = get(ty)) + { + if (visit(ty, *ftv)) + { + traverse(ftv->argTypes); + traverse(ftv->retType); + } + } + + else if (auto ttv = get(ty)) + { + // Some visitors want to see bound tables, that's why we traverse the original type + if (visit(ty, *ttv)) + { + if (ttv->boundTo) + { + traverse(*ttv->boundTo); + } + else + { + for (auto& [_name, prop] : ttv->props) + traverse(prop.type); + + if (ttv->indexer) + { + traverse(ttv->indexer->indexType); + traverse(ttv->indexer->indexResultType); + } + } + } + } + + else if (auto mtv = get(ty)) + { + if (visit(ty, *mtv)) + { + traverse(mtv->table); + traverse(mtv->metatable); + } + } + + else if (auto ctv = get(ty)) + { + if (visit(ty, *ctv)) + { + for (const auto& [name, prop] : ctv->props) + traverse(prop.type); + + if (ctv->parent) + traverse(*ctv->parent); + + if (ctv->metatable) + traverse(*ctv->metatable); + } + } + + else if (auto atv = get(ty)) + visit(ty, *atv); + + else if (auto utv = get(ty)) + { + if (visit(ty, *utv)) + { + for (TypeId optTy : utv->options) + traverse(optTy); + } + } + + else if (auto itv = get(ty)) + { + if (visit(ty, *itv)) + { + for (TypeId partTy : itv->parts) + traverse(partTy); + } + } + + visit_detail::unsee(seen, ty); + } + + void traverse(TypePackId tp) + { + if (visit_detail::hasSeen(seen, tp)) + { + cycle(tp); + return; + } + + if (auto btv = get(tp)) + { + if (visit(tp, *btv)) + traverse(btv->boundTo); + } + + else if (auto ftv = get(tp)) + visit(tp, *ftv); + + else if (auto gtv = get(tp)) + visit(tp, *gtv); + + else if (auto etv = get(tp)) + visit(tp, *etv); + + else if (auto pack = get(tp)) + { + visit(tp, *pack); + + for (TypeId ty : pack->head) + traverse(ty); + + if (pack->tail) + traverse(*pack->tail); + } + else if (auto pack = get(tp)) + { + visit(tp, *pack); + traverse(pack->ty); + } + else + LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); + + visit_detail::unsee(seen, tp); + } +}; + +/** Visit each type under a given type. Skips over cycles and keeps recursion depth under control. + * + * The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use + * TypeVarOnceVisitor. + */ +struct TypeVarVisitor : GenericTypeVarVisitor> +{ +}; + +/// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. +struct TypeVarOnceVisitor : GenericTypeVarVisitor> +{ + TypeVarOnceVisitor() + : GenericTypeVarVisitor{DenseHashSet{nullptr}} + { + } +}; + +// Clip with FFlagLuauUseVisitRecursionLimit template -void visitTypeVar(TID ty, F& f, std::unordered_set& seen) +void DEPRECATED_visitTypeVar(TID ty, F& f, std::unordered_set& seen) { visit_detail::visit(ty, f, seen); } +// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit template -void visitTypeVar(TID ty, F& f) +void DEPRECATED_visitTypeVar(TID ty, F& f) { - std::unordered_set seen; - visit_detail::visit(ty, f, seen); + if (FFlag::LuauUseVisitRecursionLimit) + f.traverse(ty); + else + { + std::unordered_set seen; + visit_detail::visit(ty, f, seen); + } } +// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit template -void visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) +void DEPRECATED_visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) { - seen.clear(); - visit_detail::visit(ty, f, seen); + if (FFlag::LuauUseVisitRecursionLimit) + f.traverse(ty); + else + { + seen.clear(); + visit_detail::visit(ty, f, seen); + } } } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index dec12d0..19d06cf 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,7 +14,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteSingletonTypes, false); LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) @@ -1341,38 +1340,21 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - if (FFlag::LuauAutocompleteSingletonTypes) - { - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); - TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); - TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, node, position, getSingletonTypes().falseType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse}; - result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; - if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, result); - } - else - { - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; - result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; - } + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, result); } } @@ -1680,11 +1662,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { AutocompleteEntryMap result; - if (FFlag::LuauAutocompleteSingletonTypes) - { - if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, result); - } + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, result); if (finder.ancestry.size() >= 2) { @@ -1693,8 +1672,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto it = module->astTypes.find(idxExpr->expr)) autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); } - else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as(); - binExpr && FFlag::LuauAutocompleteSingletonTypes) + else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b8f7836..56c0ac2 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) -LUAU_FASTFLAG(LuauCyclicModuleTypeSurface) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) @@ -433,8 +432,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional Frontend::lintFragment(std::string_view sour return {std::move(sourceModule), classifyLints(warnings, config)}; } -CheckResult Frontend::check(const SourceModule& module) -{ - LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); - LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); - - const Config& config = configResolver->getConfig(module.name); - - Mode mode = module.mode.value_or(config.mode); - - double timestamp = getTimestamp(); - - ModulePtr checkedModule = typeChecker.check(module, mode); - - stats.timeCheck += getTimestamp() - timestamp; - stats.filesStrict += mode == Mode::Strict; - stats.filesNonstrict += mode == Mode::Nonstrict; - - if (checkedModule == nullptr) - throw std::runtime_error("Frontend::check produced a nullptr module for module " + module.name); - moduleResolver.modules[module.name] = checkedModule; - - return CheckResult{checkedModule->errors}; -} - LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 043526e..d8c1138 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -304,37 +304,23 @@ static bool areNormal(TypePackId tp, const std::unordered_set& seen, Inte ++iterationLimit; \ } while (false) -struct Normalize +struct Normalize final : TypeVarVisitor { + using TypeVarVisitor::Set; + + Normalize(TypeArena& arena, InternalErrorReporter& ice) + : arena(arena) + , ice(ice) + { + } + TypeArena& arena; InternalErrorReporter& ice; - // Debug data. Types being normalized are invalidated but trying to see what's going on is painful. - // To actually see the original type, read it by using the pointer of the type being normalized. - // e.g. in lldb, `e dump(originalTys[ty])`. - SeenTypes originalTys; - SeenTypePacks originalTps; - int iterationLimit = 0; bool limitExceeded = false; - template - bool operator()(TypePackId, const T&) - { - return true; - } - - template - void cycle(TID) - { - } - - bool operator()(TypeId ty, const FreeTypeVar&) - { - LUAU_ASSERT(!ty->normal); - return false; - } - + // TODO: Clip with FFlag::LuauUseVisitRecursionLimit bool operator()(TypeId ty, const BoundTypeVar& btv, std::unordered_set& seen) { // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. @@ -349,27 +335,22 @@ struct Normalize return !ty->normal; } - bool operator()(TypeId ty, const PrimitiveTypeVar&) + bool operator()(TypeId ty, const FreeTypeVar& ftv) { - LUAU_ASSERT(ty->normal); - return false; + return visit(ty, ftv); } - - bool operator()(TypeId ty, const GenericTypeVar&) + bool operator()(TypeId ty, const PrimitiveTypeVar& ptv) { - if (!ty->normal) - asMutable(ty)->normal = true; - - return false; + return visit(ty, ptv); } - - bool operator()(TypeId ty, const ErrorTypeVar&) + bool operator()(TypeId ty, const GenericTypeVar& gtv) { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; + return visit(ty, gtv); + } + bool operator()(TypeId ty, const ErrorTypeVar& etv) + { + return visit(ty, etv); } - bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -470,17 +451,12 @@ struct Normalize bool operator()(TypeId ty, const ClassTypeVar& ctv) { - if (!ty->normal) - asMutable(ty)->normal = true; - return false; + return visit(ty, ctv); } - - bool operator()(TypeId ty, const AnyTypeVar&) + bool operator()(TypeId ty, const AnyTypeVar& atv) { - LUAU_ASSERT(ty->normal); - return false; + return visit(ty, atv); } - bool operator()(TypeId ty, const UnionTypeVar& utvRef, std::unordered_set& seen) { CHECK_ITERATION_LIMIT(false); @@ -570,8 +546,257 @@ struct Normalize return false; } - bool operator()(TypeId ty, const LazyTypeVar&) + // TODO: Clip with FFlag::LuauUseVisitRecursionLimit + template + bool operator()(TypePackId, const T&) { + return true; + } + + // TODO: Clip with FFlag::LuauUseVisitRecursionLimit + template + void cycle(TID) + { + } + + bool visit(TypeId ty, const FreeTypeVar&) override + { + LUAU_ASSERT(!ty->normal); + return false; + } + + bool visit(TypeId ty, const BoundTypeVar& btv) override + { + // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. + // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. + if (seen.find(asMutable(btv.boundTo)) != seen.end()) + return false; + + // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. + LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); + + asMutable(ty)->normal = btv.boundTo->normal; + return !ty->normal; + } + + bool visit(TypeId ty, const PrimitiveTypeVar&) override + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool visit(TypeId ty, const GenericTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + + return false; + } + + bool visit(TypeId ty, const ErrorTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override + { + CHECK_ITERATION_LIMIT(false); + + ConstrainedTypeVar* ctv = const_cast(&ctvRef); + + std::vector parts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId part : parts) + traverse(part); + + std::vector newParts = normalizeUnion(parts); + + const bool normal = areNormal(newParts, seen, ice); + + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + + return false; + } + + bool visit(TypeId ty, const FunctionTypeVar& ftv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + traverse(ftv.argTypes); + traverse(ftv.retType); + + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); + + return false; + } + + bool visit(TypeId ty, const TableTypeVar& ttv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + bool normal = true; + + auto checkNormal = [&](TypeId t) { + // if t is on the stack, it is possible that this type is normal. + // If t is not normal and it is not on the stack, this type is definitely not normal. + if (!t->normal && seen.find(asMutable(t)) == seen.end()) + normal = false; + }; + + if (ttv.boundTo) + { + traverse(*ttv.boundTo); + asMutable(ty)->normal = (*ttv.boundTo)->normal; + return false; + } + + for (const auto& [_name, prop] : ttv.props) + { + traverse(prop.type); + checkNormal(prop.type); + } + + if (ttv.indexer) + { + traverse(ttv.indexer->indexType); + checkNormal(ttv.indexer->indexType); + traverse(ttv.indexer->indexResultType); + checkNormal(ttv.indexer->indexResultType); + } + + asMutable(ty)->normal = normal; + + return false; + } + + bool visit(TypeId ty, const MetatableTypeVar& mtv) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + traverse(mtv.table); + traverse(mtv.metatable); + + asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; + + return false; + } + + bool visit(TypeId ty, const ClassTypeVar& ctv) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const AnyTypeVar&) override + { + LUAU_ASSERT(ty->normal); + return false; + } + + bool visit(TypeId ty, const UnionTypeVar& utvRef) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + UnionTypeVar* utv = &const_cast(utvRef); + std::vector options = std::move(utv->options); + + // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar + for (TypeId option : options) + traverse(option); + + std::vector newOptions = normalizeUnion(options); + + const bool normal = areNormal(newOptions, seen, ice); + + LUAU_ASSERT(!newOptions.empty()); + + if (newOptions.size() == 1) + *asMutable(ty) = BoundTypeVar{newOptions[0]}; + else + utv->options = std::move(newOptions); + + asMutable(ty)->normal = normal; + + return false; + } + + bool visit(TypeId ty, const IntersectionTypeVar& itvRef) override + { + CHECK_ITERATION_LIMIT(false); + + if (ty->normal) + return false; + + IntersectionTypeVar* itv = &const_cast(itvRef); + + std::vector oldParts = std::move(itv->parts); + + for (TypeId part : oldParts) + traverse(part); + + std::vector tables; + for (TypeId part : oldParts) + { + part = follow(part); + if (get(part)) + tables.push_back(part); + else + { + Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD + combineIntoIntersection(replacer, itv, part); + } + } + + // Don't allocate a new table if there's just one in the intersection. + if (tables.size() == 1) + itv->parts.push_back(tables[0]); + else if (!tables.empty()) + { + const TableTypeVar* first = get(tables[0]); + LUAU_ASSERT(first); + + TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); + TableTypeVar* ttv = getMutable(newTable); + for (TypeId part : tables) + { + // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need + // to be rewritten to point at 'newTable' in the clone. + Replacer replacer{&arena, part, newTable}; + combineIntoTable(replacer, ttv, part); + } + + itv->parts.push_back(newTable); + } + + asMutable(ty)->normal = areNormal(itv->parts, seen, ice); + + if (itv->parts.size() == 1) + { + TypeId part = itv->parts[0]; + *asMutable(ty) = BoundTypeVar{part}; + } + return false; } @@ -778,9 +1003,9 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo if (FFlag::DebugLuauCopyBeforeNormalizing) (void)clone(ty, arena, state); - Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + Normalize n{arena, ice}; std::unordered_set seen; - visitTypeVar(ty, n, seen); + DEPRECATED_visitTypeVar(ty, n, seen); return {ty, !n.limitExceeded}; } @@ -803,9 +1028,9 @@ std::pair normalize(TypePackId tp, TypeArena& arena, InternalE if (FFlag::DebugLuauCopyBeforeNormalizing) (void)clone(tp, arena, state); - Normalize n{arena, ice, std::move(state.seenTypes), std::move(state.seenTypePacks)}; + Normalize n{arena, ice}; std::unordered_set seen; - visitTypeVar(tp, n, seen); + DEPRECATED_visitTypeVar(tp, n, seen); return {tp, !n.limitExceeded}; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 305f83c..4f3e446 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -9,7 +9,7 @@ LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { -struct Quantifier +struct Quantifier final : TypeVarOnceVisitor { TypeLevel level; std::vector generics; @@ -17,26 +17,17 @@ struct Quantifier bool seenGenericType = false; bool seenMutableType = false; - Quantifier(TypeLevel level) + explicit Quantifier(TypeLevel level) : level(level) { } - void cycle(TypeId) {} - void cycle(TypePackId) {} + void cycle(TypeId) override {} + void cycle(TypePackId) override {} bool operator()(TypeId ty, const FreeTypeVar& ftv) { - if (FFlag::LuauTypecheckOptPass) - seenMutableType = true; - - if (!level.subsumes(ftv.level)) - return false; - - *asMutable(ty) = GenericTypeVar{level}; - generics.push_back(ty); - - return false; + return visit(ty, ftv); } template @@ -56,8 +47,33 @@ struct Quantifier return true; } - bool operator()(TypeId ty, const TableTypeVar&) + bool operator()(TypeId ty, const TableTypeVar& ttv) { + return visit(ty, ttv); + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp, ftp); + } + + bool visit(TypeId ty, const FreeTypeVar& ftv) override + { + if (FFlag::LuauTypecheckOptPass) + seenMutableType = true; + + if (!level.subsumes(ftv.level)) + return false; + + *asMutable(ty) = GenericTypeVar{level}; + generics.push_back(ty); + + return false; + } + + bool visit(TypeId ty, const TableTypeVar&) override + { + LUAU_ASSERT(getMutable(ty)); TableTypeVar& ttv = *getMutable(ty); if (FFlag::LuauTypecheckOptPass) @@ -93,7 +109,7 @@ struct Quantifier return true; } - bool operator()(TypePackId tp, const FreeTypePack& ftp) + bool visit(TypePackId tp, const FreeTypePack& ftp) override { if (FFlag::LuauTypecheckOptPass) seenMutableType = true; @@ -111,7 +127,7 @@ void quantify(TypeId ty, TypeLevel level) { Quantifier q{level}; DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, q, seen); + DEPRECATED_visitTypeVarOnce(ty, q, seen); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 610842d..b5d6a55 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -26,7 +26,7 @@ namespace Luau namespace { -struct FindCyclicTypes +struct FindCyclicTypes final : TypeVarVisitor { FindCyclicTypes() = default; FindCyclicTypes(const FindCyclicTypes&) = delete; @@ -38,20 +38,22 @@ struct FindCyclicTypes std::set cycles; std::set cycleTPs; - void cycle(TypeId ty) + void cycle(TypeId ty) override { cycles.insert(ty); } - void cycle(TypePackId tp) + void cycle(TypePackId tp) override { cycleTPs.insert(tp); } + // TODO: Clip all the operator()s when we clip FFlagLuauUseVisitRecursionLimit + template bool operator()(TypeId ty, const T&) { - return visited.insert(ty).second; + return visit(ty); } bool operator()(TypeId ty, const TableTypeVar& ttv) = delete; @@ -64,10 +66,10 @@ struct FindCyclicTypes if (ttv.name || ttv.syntheticName) { for (TypeId itp : ttv.instantiatedTypeParams) - visitTypeVar(itp, *this, seen); + DEPRECATED_visitTypeVar(itp, *this, seen); for (TypePackId itp : ttv.instantiatedTypePackParams) - visitTypeVar(itp, *this, seen); + DEPRECATED_visitTypeVar(itp, *this, seen); return exhaustive; } @@ -82,9 +84,43 @@ struct FindCyclicTypes template bool operator()(TypePackId tp, const T&) + { + return visit(tp); + } + + bool visit(TypeId ty) override + { + return visited.insert(ty).second; + } + + bool visit(TypePackId tp) override { return visitedPacks.insert(tp).second; } + + bool visit(TypeId ty, const TableTypeVar& ttv) override + { + if (!visited.insert(ty).second) + return false; + + if (ttv.name || ttv.syntheticName) + { + for (TypeId itp : ttv.instantiatedTypeParams) + traverse(itp); + + for (TypePackId itp : ttv.instantiatedTypePackParams) + traverse(itp); + + return exhaustive; + } + + return true; + } + + bool visit(TypeId ty, const ClassTypeVar&) override + { + return false; + } }; template @@ -92,7 +128,7 @@ void findCyclicTypes(std::set& cycles, std::set& cycleTPs, T { FindCyclicTypes fct; fct.exhaustive = exhaustive; - visitTypeVar(ty, fct); + DEPRECATED_visitTypeVar(ty, fct); cycles = std::move(fct.cycles); cycleTPs = std::move(fct.cycleTPs); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index a5f9d26..1fb5a61 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,7 +7,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauTxnLogPreserveOwner, false) LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false) namespace Luau @@ -81,31 +80,20 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { - if (FFlag::LuauTxnLogPreserveOwner) + for (auto& [ty, rep] : typeVarChanges) { - for (auto& [ty, rep] : typeVarChanges) - { - TypeArena* owningArena = ty->owningArena; - TypeVar* mtv = asMutable(ty); - *mtv = rep.get()->pending; - mtv->owningArena = owningArena; - } - - for (auto& [tp, rep] : typePackChanges) - { - TypeArena* owningArena = tp->owningArena; - TypePackVar* mpv = asMutable(tp); - *mpv = rep.get()->pending; - mpv->owningArena = owningArena; - } + TypeArena* owningArena = ty->owningArena; + TypeVar* mtv = asMutable(ty); + *mtv = rep.get()->pending; + mtv->owningArena = owningArena; } - else - { - for (auto& [ty, rep] : typeVarChanges) - *asMutable(ty) = rep.get()->pending; - for (auto& [tp, rep] : typePackChanges) - *asMutable(tp) = rep.get()->pending; + for (auto& [tp, rep] : typePackChanges) + { + TypeArena* owningArena = tp->owningArena; + TypePackVar* mpv = asMutable(tp); + *mpv = rep.get()->pending; + mpv->owningArena = owningArena; } clear(); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index ba91ae1..4466ede 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,11 +26,11 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) +LUAU_FASTFLAGVARIABLE(LuauUseVisitRecursionLimit, false) +LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTFLAG(LuauAutocompleteSingletonTypes) -LUAU_FASTFLAGVARIABLE(LuauCyclicModuleTypeSurface, false) LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. @@ -40,6 +40,7 @@ LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) +LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) @@ -57,6 +58,7 @@ LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAG(LuauLosslessClone) +LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); namespace Luau { @@ -1159,6 +1161,47 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); } + if (FFlag::LuauTypecheckIter) + { + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + { + // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions + // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments + // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + for (TypeId var : varTypes) + unify(anyType, var, forin.location); + + return check(loopScope, *forin.body); + } + else if (const TableTypeVar* iterTable = get(iterTy)) + { + // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer + // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting + if (iterTable->indexer) + { + if (varTypes.size() > 0) + unify(iterTable->indexer->indexType, varTypes[0], forin.location); + + if (varTypes.size() > 1) + unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); + + for (size_t i = 2; i < varTypes.size(); ++i) + unify(nilType, varTypes[i], forin.location); + } + else + { + TypeId varTy = errorRecoveryType(loopScope); + + for (TypeId var : varTypes) + unify(varTy, var, forin.location); + + reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); + } + + return check(loopScope, *forin.body); + } + } + const FunctionTypeVar* iterFunc = get(iterTy); if (!iterFunc) { @@ -2026,15 +2069,29 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) if (const UnionTypeVar* utv = get(t)) { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) + if (FFlag::LuauReduceUnionRecursion) { - ty = follow(ty); - if (get(ty) || get(ty)) - return {ty}; + for (TypeId ty : utv) + { + if (get(ty) || get(ty)) + return {ty}; - if (std::find(result.begin(), result.end(), ty) == result.end()) - result.push_back(ty); + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); + } + } + else + { + std::vector r = reduceUnion(utv->options); + for (TypeId ty : r) + { + ty = follow(ty); + if (get(ty) || get(ty)) + return {ty}; + + if (std::find(result.begin(), result.end(), ty) == result.end()) + result.push_back(ty); + } } } else if (std::find(result.begin(), result.end(), t) == result.end()) @@ -4372,17 +4429,12 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module } // Types of requires that transitively refer to current module have to be replaced with 'any' - std::string humanReadableName; + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - if (FFlag::LuauCyclicModuleTypeSurface) + for (const auto& [location, path] : requireCycles) { - humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - - for (const auto& [location, path] : requireCycles) - { - if (!path.empty() && path.front() == humanReadableName) - return anyType; - } + if (!path.empty() && path.front() == humanReadableName) + return anyType; } ModulePtr module = resolver->getModule(moduleInfo.name); @@ -4392,32 +4444,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) - { - if (FFlag::LuauCyclicModuleTypeSurface) - { - reportError(TypeError{location, UnknownRequire{humanReadableName}}); - } - else - { - std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(TypeError{location, UnknownRequire{reportedModulePath}}); - } - } + reportError(TypeError{location, UnknownRequire{humanReadableName}}); return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { - if (FFlag::LuauCyclicModuleTypeSurface) - { - reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - } - else - { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - } + reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); return errorRecoveryType(scope); } @@ -4429,15 +4463,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module std::optional moduleType = first(modulePack); if (!moduleType) { - if (FFlag::LuauCyclicModuleTypeSurface) - { - reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - } - else - { - std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); - reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - } + reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); return errorRecoveryType(scope); } @@ -4947,10 +4973,7 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { - if (FFlag::LuauAutocompleteSingletonTypes) - return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; - - return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value}))); + return value ? getSingletonTypes().trueType : getSingletonTypes().falseType; } TypeId TypeChecker::singletonType(std::string value) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 334806c..f5c1dde 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,7 +22,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) -LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) +LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) LUAU_FASTFLAG(LuauTypecheckOptPass) @@ -30,7 +30,7 @@ LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { -struct PromoteTypeLevels +struct PromoteTypeLevels final : TypeVarOnceVisitor { TxnLog& log; const TypeArena* typeArena = nullptr; @@ -53,13 +53,34 @@ struct PromoteTypeLevels } } + // TODO cycle and operator() need to be clipped when FFlagLuauUseVisitRecursionLimit is clipped template void cycle(TID) { } - template bool operator()(TID ty, const T&) + { + return visit(ty); + } + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + return visit(ty, ftv); + } + bool operator()(TypeId ty, const FunctionTypeVar& ftv) + { + return visit(ty, ftv); + } + bool operator()(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty, ttv); + } + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp, ftp); + } + + bool visit(TypeId ty) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -68,7 +89,16 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const FreeTypeVar&) + bool visit(TypePackId tp) override + { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (tp->owningArena != typeArena) + return false; + + return true; + } + + bool visit(TypeId ty, const FreeTypeVar&) override { // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. @@ -79,7 +109,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const FunctionTypeVar&) + bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -89,7 +119,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypeId ty, const TableTypeVar& ttv) + bool visit(TypeId ty, const TableTypeVar& ttv) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) @@ -102,7 +132,7 @@ struct PromoteTypeLevels return true; } - bool operator()(TypePackId tp, const FreeTypePack&) + bool visit(TypePackId tp, const FreeTypePack&) override { // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. @@ -122,7 +152,7 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel PromoteTypeLevels ptl{log, typeArena, minLevel}; DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, ptl, seen); + DEPRECATED_visitTypeVarOnce(ty, ptl, seen); } void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) @@ -133,10 +163,10 @@ void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLev PromoteTypeLevels ptl{log, typeArena, minLevel}; DenseHashSet seen{nullptr}; - visitTypeVarOnce(tp, ptl, seen); + DEPRECATED_visitTypeVarOnce(tp, ptl, seen); } -struct SkipCacheForType +struct SkipCacheForType final : TypeVarOnceVisitor { SkipCacheForType(const DenseHashMap& skipCacheForType, const TypeArena* typeArena) : skipCacheForType(skipCacheForType) @@ -144,28 +174,68 @@ struct SkipCacheForType { } - void cycle(TypeId) {} - void cycle(TypePackId) {} + // TODO cycle() and operator() can be clipped with FFlagLuauUseVisitRecursionLimit + void cycle(TypeId) override {} + void cycle(TypePackId) override {} bool operator()(TypeId ty, const FreeTypeVar& ftv) { - result = true; - return false; + return visit(ty, ftv); } - bool operator()(TypeId ty, const BoundTypeVar& btv) { - result = true; - return false; + return visit(ty, btv); + } + bool operator()(TypeId ty, const GenericTypeVar& gtv) + { + return visit(ty, gtv); + } + bool operator()(TypeId ty, const TableTypeVar& ttv) + { + return visit(ty, ttv); + } + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + return visit(tp, ftp); + } + bool operator()(TypePackId tp, const BoundTypePack& ftp) + { + return visit(tp, ftp); + } + bool operator()(TypePackId tp, const GenericTypePack& ftp) + { + return visit(tp, ftp); + } + template + bool operator()(TypeId ty, const T& t) + { + return visit(ty); + } + template + bool operator()(TypePackId tp, const T&) + { + return visit(tp); } - bool operator()(TypeId ty, const GenericTypeVar& btv) + bool visit(TypeId, const FreeTypeVar&) override { result = true; return false; } - bool operator()(TypeId ty, const TableTypeVar&) + bool visit(TypeId, const BoundTypeVar&) override + { + result = true; + return false; + } + + bool visit(TypeId, const GenericTypeVar&) override + { + result = true; + return false; + } + + bool visit(TypeId ty, const TableTypeVar&) override { // Types from other modules don't contain mutable elements and are ok to cache if (ty->owningArena != typeArena) @@ -188,8 +258,7 @@ struct SkipCacheForType return true; } - template - bool operator()(TypeId ty, const T& t) + bool visit(TypeId ty) override { // Types from other modules don't contain mutable elements and are ok to cache if (ty->owningArena != typeArena) @@ -206,8 +275,7 @@ struct SkipCacheForType return true; } - template - bool operator()(TypePackId tp, const T&) + bool visit(TypePackId tp) override { // Types from other modules don't contain mutable elements and are ok to cache if (tp->owningArena != typeArena) @@ -216,19 +284,19 @@ struct SkipCacheForType return true; } - bool operator()(TypePackId tp, const FreeTypePack& ftp) + bool visit(TypePackId tp, const FreeTypePack&) override { result = true; return false; } - bool operator()(TypePackId tp, const BoundTypePack& ftp) + bool visit(TypePackId tp, const BoundTypePack&) override { result = true; return false; } - bool operator()(TypePackId tp, const GenericTypePack& ftp) + bool visit(TypePackId tp, const GenericTypePack&) override { result = true; return false; @@ -578,7 +646,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId failed = true; } - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) + if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2) { } else @@ -593,7 +661,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId } // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) + if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2) { auto tryBind = [this, subTy](TypeId superOption) { superOption = log.follow(superOption); @@ -603,6 +671,14 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) return; + // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype + // test is successful. + if (auto subUnion = get(subTy)) + { + if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) + return; + } + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. if (log.haveSeen(subTy, superOption)) @@ -822,7 +898,7 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) auto skipCacheFor = [this](TypeId ty) { SkipCacheForType visitor{sharedState.skipCacheForType, types}; - visitTypeVarOnce(ty, visitor, sharedState.seenAny); + DEPRECATED_visitTypeVarOnce(ty, visitor, sharedState.seenAny); sharedState.skipCacheForType[ty] = visitor.result; diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 31cd01c..6f39e3f 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -313,7 +313,7 @@ template struct AstArray { T* data; - std::size_t size; + size_t size; const T* begin() const { diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 31ff3f7..91f5cd2 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParseRecoverUnexpectedPack, false) LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false) namespace Luau @@ -1430,7 +1429,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } - else if (FFlag::LuauParseRecoverUnexpectedPack && c == Lexeme::Dot3) + else if (c == Lexeme::Dot3) { report(lexer.current().location, "Unexpected '...' after type annotation"); nextLexeme(); @@ -1551,7 +1550,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) prefix = name.name; name = parseIndexName("field name", pointPosition); } - else if (FFlag::LuauParseRecoverUnexpectedPack && lexer.current().type == Lexeme::Dot3) + else if (lexer.current().type == Lexeme::Dot3) { report(lexer.current().location, "Unexpected '...' after type name; type pack is not allowed in this context"); nextLexeme(); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 4cb2234..83060f5 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -21,6 +21,8 @@ #include #endif +#include + LUAU_FASTFLAG(DebugLuauTimeTracing) enum class CliMode @@ -435,6 +437,9 @@ static void runReplImpl(lua_State* L) { ic_set_default_completer(completeRepl, L); + // Reset the locale to C + setlocale(LC_ALL, "C"); + // Make brace matching easier to see ic_style_def("ic-bracematch", "teal"); diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index c6e5a03..f71d893 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -353,6 +353,11 @@ enum LuauOpcode // AUX: constant index LOP_FASTCALL2K, + // FORGPREP: prepare loop variables for a generic for loop, jump to the loop backedge unconditionally + // A: target register; generic for loops assume a register layout [generator, state, index, variables...] + // D: jump offset (-32768..32767) + LOP_FORGPREP, + // Enum entry for number of opcodes, not a valid opcode by itself! LOP__COUNT }; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 871a148..fb70392 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -96,6 +96,7 @@ inline bool isJumpD(LuauOpcode op) case LOP_JUMPIFNOTLT: case LOP_FORNPREP: case LOP_FORNLOOP: + case LOP_FORGPREP: case LOP_FORGLOOP: case LOP_FORGPREP_INEXT: case LOP_FORGLOOP_INEXT: @@ -1269,6 +1270,11 @@ void BytecodeBuilder::validate() const VJUMP(LUAU_INSN_D(insn)); break; + case LOP_FORGPREP: + VREG(LUAU_INSN_A(insn) + 2 + 1); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VJUMP(LUAU_INSN_D(insn)); + break; + case LOP_FORGLOOP: VREG( LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables @@ -1622,6 +1628,10 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri formatAppend(result, "FORNLOOP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); break; + case LOP_FORGPREP: + formatAppend(result, "FORGPREP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + case LOP_FORGLOOP: formatAppend(result, "FORGLOOP R%d %+d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn), *code++); break; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 0f17ee0..4fe2622 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -17,9 +17,19 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauCompileSupportInlining, false) + +LUAU_FASTFLAGVARIABLE(LuauCompileIter, false) +LUAU_FASTFLAGVARIABLE(LuauCompileIterNoReserve, false) +LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) + LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThresholdMaxBoost, 300) +LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) +LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) +LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) + namespace Luau { @@ -147,6 +157,52 @@ struct Compiler } } + AstExprFunction* getFunctionExpr(AstExpr* node) + { + if (AstExprLocal* le = node->as()) + { + Variable* lv = variables.find(le->local); + + if (!lv || lv->written || !lv->init) + return nullptr; + + return getFunctionExpr(lv->init); + } + else if (AstExprGroup* ge = node->as()) + return getFunctionExpr(ge->expr); + else + return node->as(); + } + + bool canInlineFunctionBody(AstStat* stat) + { + struct CanInlineVisitor : AstVisitor + { + bool result = true; + + bool visit(AstExpr* node) override + { + // nested functions may capture function arguments, and our upval handling doesn't handle elided variables (constant) + // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues + // TODO: additionally we would need to change upvalue handling in compileExprFunction to handle upvalue->local migration + result = result && !node->is(); + return result; + } + + bool visit(AstStat* node) override + { + // loops may need to be unrolled which can result in cost amplification + result = result && !node->is(); + return result; + } + }; + + CanInlineVisitor canInline; + stat->visit(&canInline); + + return canInline.result; + } + uint32_t compileFunction(AstExprFunction* func) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -214,13 +270,21 @@ struct Compiler bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size())); - stackSize = 0; - Function& f = functions[func]; f.id = fid; f.upvals = upvals; + // record information for inlining + if (FFlag::LuauCompileSupportInlining && options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && + !getfenvUsed && !setfenvUsed) + { + f.canInline = true; + f.stackSize = stackSize; + f.costModel = modelCost(func->body, func->args.data, func->args.size); + } + upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes + stackSize = 0; return fid; } @@ -390,12 +454,183 @@ struct Compiler } } + bool tryCompileInlinedCall(AstExprCall* expr, AstExprFunction* func, uint8_t target, uint8_t targetCount, bool multRet, int thresholdBase, + int thresholdMaxBoost, int depthLimit) + { + Function* fi = functions.find(func); + LUAU_ASSERT(fi); + + // make sure we have enough register space + if (regTop > 128 || fi->stackSize > 32) + { + bytecode.addDebugRemark("inlining failed: high register pressure"); + return false; + } + + // we should ideally aggregate the costs during recursive inlining, but for now simply limit the depth + if (int(inlineFrames.size()) >= depthLimit) + { + bytecode.addDebugRemark("inlining failed: too many inlined frames"); + return false; + } + + // compiling recursive inlining is difficult because we share constant/variable state but need to bind variables to different registers + for (InlineFrame& frame : inlineFrames) + if (frame.func == func) + { + bytecode.addDebugRemark("inlining failed: can't inline recursive calls"); + return false; + } + + // TODO: we can compile multret functions if all returns of the function are multret as well + if (multRet) + { + bytecode.addDebugRemark("inlining failed: can't convert fixed returns to multret"); + return false; + } + + // TODO: we can compile functions with mismatching arity at call site but it's more annoying + if (func->args.size != expr->args.size) + { + bytecode.addDebugRemark("inlining failed: argument count mismatch (expected %d, got %d)", int(func->args.size), int(expr->args.size)); + return false; + } + + // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining + bool varc[8] = {}; + for (size_t i = 0; i < expr->args.size && i < 8; ++i) + varc[i] = isConstant(expr->args.data[i]); + + int inlinedCost = computeCost(fi->costModel, varc, std::min(int(expr->args.size), 8)); + int baselineCost = computeCost(fi->costModel, nullptr, 0) + 3; + int inlineProfit = (inlinedCost == 0) ? thresholdMaxBoost : std::min(thresholdMaxBoost, 100 * baselineCost / inlinedCost); + + int threshold = thresholdBase * inlineProfit / 100; + + if (inlinedCost > threshold) + { + bytecode.addDebugRemark("inlining failed: too expensive (cost %d, profit %.2fx)", inlinedCost, double(inlineProfit) / 100); + return false; + } + + bytecode.addDebugRemark( + "inlining succeeded (cost %d, profit %.2fx, depth %d)", inlinedCost, double(inlineProfit) / 100, int(inlineFrames.size())); + + compileInlinedCall(expr, func, target, targetCount); + return true; + } + + void compileInlinedCall(AstExprCall* expr, AstExprFunction* func, uint8_t target, uint8_t targetCount) + { + RegScope rs(this); + + size_t oldLocals = localStack.size(); + + // note that we push the frame early; this is needed to block recursive inline attempts + inlineFrames.push_back({func, target, targetCount}); + + // evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding) + for (size_t i = 0; i < func->args.size; ++i) + { + AstLocal* var = func->args.data[i]; + AstExpr* arg = expr->args.data[i]; + + if (Variable* vv = variables.find(var); vv && vv->written) + { + // if the argument is mutated, we need to allocate a fresh register even if it's a constant + uint8_t reg = allocReg(arg, 1); + compileExprTemp(arg, reg); + pushLocal(var, reg); + } + else if (const Constant* cv = constants.find(arg); cv && cv->type != Constant::Type_Unknown) + { + // since the argument is not mutated, we can simply fold the value into the expressions that need it + locstants[var] = *cv; + } + else + { + AstExprLocal* le = arg->as(); + Variable* lv = le ? variables.find(le->local) : nullptr; + + // if the argument is a local that isn't mutated, we will simply reuse the existing register + if (isExprLocalReg(arg) && (!lv || !lv->written)) + { + uint8_t reg = getLocal(le->local); + pushLocal(var, reg); + } + else + { + uint8_t reg = allocReg(arg, 1); + compileExprTemp(arg, reg); + pushLocal(var, reg); + } + } + } + + // fold constant values updated above into expressions in the function body + foldConstants(constants, variables, locstants, func->body); + + bool usedFallthrough = false; + + for (size_t i = 0; i < func->body->body.size; ++i) + { + AstStat* stat = func->body->body.data[i]; + + if (AstStatReturn* ret = stat->as()) + { + // Optimization: use fallthrough when compiling return at the end of the function to avoid an extra JUMP + compileInlineReturn(ret, /* fallthrough= */ true); + // TODO: This doesn't work when return is part of control flow; ideally we would track the state somehow and generalize this + usedFallthrough = true; + break; + } + else + compileStat(stat); + } + + // for the fallthrough path we need to ensure we clear out target registers + if (!usedFallthrough && !allPathsEndWithReturn(func->body)) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0); + } + + popLocals(oldLocals); + + size_t returnLabel = bytecode.emitLabel(); + patchJumps(expr, inlineFrames.back().returnJumps, returnLabel); + + inlineFrames.pop_back(); + + // clean up constant state for future inlining attempts + for (size_t i = 0; i < func->args.size; ++i) + if (Constant* var = locstants.find(func->args.data[i])) + var->type = Constant::Type_Unknown; + + foldConstants(constants, variables, locstants, func->body); + } + void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) { LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); setDebugLine(expr); // normally compileExpr sets up line info, but compileExprCall can be called directly + // try inlining the function + if (options.optimizationLevel >= 2 && !expr->self) + { + AstExprFunction* func = getFunctionExpr(expr->func); + Function* fi = func ? functions.find(func) : nullptr; + + if (fi && fi->canInline && + tryCompileInlinedCall(expr, func, target, targetCount, multRet, FInt::LuauCompileInlineThreshold, + FInt::LuauCompileInlineThresholdMaxBoost, FInt::LuauCompileInlineDepth)) + return; + + if (fi && !fi->canInline) + bytecode.addDebugRemark("inlining failed: complex constructs in function body"); + } + RegScope rs(this); unsigned int regCount = std::max(unsigned(1 + expr->self + expr->args.size), unsigned(targetCount)); @@ -760,7 +995,7 @@ struct Compiler { const Constant* c = constants.find(node); - if (!c) + if (!c || c->type == Constant::Type_Unknown) return -1; int cid = -1; @@ -1395,27 +1630,29 @@ struct Compiler { RegScope rs(this); + // note: cv may be invalidated by compileExpr* so we stop using it before calling compile recursively const Constant* cv = constants.find(expr->index); if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && double(int(cv->valueNumber)) == cv->valueNumber) { - uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); + uint8_t rt = compileExprAuto(expr->expr, rs); + setDebugLine(expr->index); bytecode.emitABC(LOP_GETTABLEN, target, rt, i); } else if (cv && cv->type == Constant::Type_String) { - uint8_t rt = compileExprAuto(expr->expr, rs); - BytecodeBuilder::StringRef iname = sref(cv->getString()); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + uint8_t rt = compileExprAuto(expr->expr, rs); + setDebugLine(expr->index); bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); @@ -1561,8 +1798,9 @@ struct Compiler } else if (AstExprLocal* expr = node->as()) { - if (expr->upvalue) + if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue) { + LUAU_ASSERT(expr->upvalue); uint8_t uid = getUpval(expr->local); bytecode.emitABC(LOP_GETUPVAL, target, uid, 0); @@ -1650,12 +1888,12 @@ struct Compiler // initializes target..target+targetCount-1 range using expressions from the list // if list has fewer expressions, and last expression is a call, we assume the call returns the rest of the values // if list has fewer expressions, and last expression isn't a call, we fill the rest with nil - // assumes target register range can be clobbered and is at the top of the register space - void compileExprListTop(const AstArray& list, uint8_t target, uint8_t targetCount) + // assumes target register range can be clobbered and is at the top of the register space if targetTop = true + void compileExprListTemp(const AstArray& list, uint8_t target, uint8_t targetCount, bool targetTop) { // we assume that target range is at the top of the register space and can be clobbered // this is what allows us to compile the last call expression - if it's a call - using targetTop=true - LUAU_ASSERT(unsigned(target + targetCount) == regTop); + LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); if (list.size == targetCount) { @@ -1683,7 +1921,7 @@ struct Compiler if (AstExprCall* expr = last->as()) { - compileExprCall(expr, uint8_t(target + list.size - 1), uint8_t(targetCount - (list.size - 1)), /* targetTop= */ true); + compileExprCall(expr, uint8_t(target + list.size - 1), uint8_t(targetCount - (list.size - 1)), targetTop); } else if (AstExprVarargs* expr = last->as()) { @@ -1765,8 +2003,10 @@ struct Compiler if (AstExprLocal* expr = node->as()) { - if (expr->upvalue) + if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue) { + LUAU_ASSERT(expr->upvalue); + LValue result = {LValue::Kind_Upvalue}; result.upval = getUpval(expr->local); result.location = node->location; @@ -1873,7 +2113,7 @@ struct Compiler bool isExprLocalReg(AstExpr* expr) { AstExprLocal* le = expr->as(); - if (!le || le->upvalue) + if (!le || (!FFlag::LuauCompileSupportInlining && le->upvalue)) return false; Local* l = locals.find(le->local); @@ -2080,6 +2320,23 @@ struct Compiler loops.pop_back(); } + void compileInlineReturn(AstStatReturn* stat, bool fallthrough) + { + setDebugLine(stat); // normally compileStat sets up line info, but compileInlineReturn can be called directly + + InlineFrame frame = inlineFrames.back(); + + compileExprListTemp(stat->list, frame.target, frame.targetCount, /* targetTop= */ false); + + if (!fallthrough) + { + size_t jumpLabel = bytecode.emitLabel(); + bytecode.emitAD(LOP_JUMP, 0, 0); + + inlineFrames.back().returnJumps.push_back(jumpLabel); + } + } + void compileStatReturn(AstStatReturn* stat) { RegScope rs(this); @@ -2138,7 +2395,7 @@ struct Compiler // note: allocReg in this case allocates into parent block register - note that we don't have RegScope here uint8_t vars = allocReg(stat, unsigned(stat->vars.size)); - compileExprListTop(stat->values, vars, uint8_t(stat->vars.size)); + compileExprListTemp(stat->values, vars, uint8_t(stat->vars.size), /* targetTop= */ true); for (size_t i = 0; i < stat->vars.size; ++i) pushLocal(stat->vars.data[i], uint8_t(vars + i)); @@ -2168,6 +2425,7 @@ struct Compiler bool visit(AstExpr* node) override { // functions may capture loop variable, and our upval handling doesn't handle elided variables (constant) + // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues result = result && !node->is(); return result; } @@ -2251,6 +2509,11 @@ struct Compiler compileStat(stat->body); } + // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again + locstants[var].type = Constant::Type_Unknown; + + foldConstants(constants, variables, locstants, stat); + return true; } @@ -2336,12 +2599,17 @@ struct Compiler uint8_t regs = allocReg(stat, 3); // this puts initial values of (generator, state, index) into the loop registers - compileExprListTop(stat->values, regs, 3); + compileExprListTemp(stat->values, regs, 3, /* targetTop= */ true); - // for the general case, we will execute a CALL for every iteration that needs to evaluate "variables... = generator(state, index)" - // this requires at least extra 3 stack slots after index - // note that these stack slots overlap with the variables so we only need to reserve them to make sure stack frame is large enough - reserveReg(stat, 3); + // we don't need this because the extra stack space is just for calling the function with a loop protocol which is similar to calling + // metamethods - it should fit into the extra stack reservation + if (!FFlag::LuauCompileIterNoReserve) + { + // for the general case, we will execute a CALL for every iteration that needs to evaluate "variables... = generator(state, index)" + // this requires at least extra 3 stack slots after index + // note that these stack slots overlap with the variables so we only need to reserve them to make sure stack frame is large enough + reserveReg(stat, 3); + } // note that we reserve at least 2 variables; this allows our fast path to assume that we need 2 variables instead of 1 or 2 uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); @@ -2350,7 +2618,7 @@ struct Compiler // Optimization: when we iterate through pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration // index These instructions dynamically check if generator is equal to next/inext and bail out They assume that the generator produces 2 // variables, which is why we allocate at least 2 above (see vars assignment) - LuauOpcode skipOp = LOP_JUMP; + LuauOpcode skipOp = FFlag::LuauCompileIter ? LOP_FORGPREP : LOP_JUMP; LuauOpcode loopOp = LOP_FORGLOOP; if (options.optimizationLevel >= 1 && stat->vars.size <= 2) @@ -2367,7 +2635,7 @@ struct Compiler else if (builtin.isGlobal("pairs")) // for .. in pairs(t) { skipOp = LOP_FORGPREP_NEXT; - loopOp = LOP_FORGLOOP_NEXT; + loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT; } } else if (stat->values.size == 2) @@ -2377,7 +2645,7 @@ struct Compiler if (builtin.isGlobal("next")) // for .. in next,t { skipOp = LOP_FORGPREP_NEXT; - loopOp = LOP_FORGLOOP_NEXT; + loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT; } } } @@ -2514,10 +2782,10 @@ struct Compiler // compute values into temporaries uint8_t regs = allocReg(stat, unsigned(stat->vars.size)); - compileExprListTop(stat->values, regs, uint8_t(stat->vars.size)); + compileExprListTemp(stat->values, regs, uint8_t(stat->vars.size), /* targetTop= */ true); - // assign variables that have associated values; note that if we have fewer values than variables, we'll assign nil because compileExprListTop - // will generate nils + // assign variables that have associated values; note that if we have fewer values than variables, we'll assign nil because + // compileExprListTemp will generate nils for (size_t i = 0; i < stat->vars.size; ++i) { setDebugLine(stat->vars.data[i]); @@ -2675,7 +2943,10 @@ struct Compiler } else if (AstStatReturn* stat = node->as()) { - compileStatReturn(stat); + if (options.optimizationLevel >= 2 && !inlineFrames.empty()) + compileInlineReturn(stat, /* fallthrough= */ false); + else + compileStatReturn(stat); } else if (AstStatExpr* stat = node->as()) { @@ -3069,6 +3340,10 @@ struct Compiler { uint32_t id; std::vector upvals; + + uint64_t costModel = 0; + unsigned int stackSize = 0; + bool canInline = false; }; struct Local @@ -3098,6 +3373,16 @@ struct Compiler AstExpr* untilCondition; }; + struct InlineFrame + { + AstExprFunction* func; + + uint8_t target; + uint8_t targetCount; + + std::vector returnJumps; + }; + BytecodeBuilder& bytecode; CompileOptions options; @@ -3120,6 +3405,7 @@ struct Compiler std::vector upvals; std::vector loopJumps; std::vector loops; + std::vector inlineFrames; }; void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 7ad91d4..52ece73 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -3,6 +3,8 @@ #include +LUAU_FASTFLAG(LuauCompileSupportInlining) + namespace Luau { namespace Compile @@ -314,12 +316,35 @@ struct ConstantVisitor : AstVisitor LUAU_ASSERT(!"Unknown expression type"); } - if (result.type != Constant::Type_Unknown) - constants[node] = result; + recordConstant(constants, node, result); return result; } + template + void recordConstant(DenseHashMap& map, T key, const Constant& value) + { + if (value.type != Constant::Type_Unknown) + map[key] = value; + else if (!FFlag::LuauCompileSupportInlining) + ; + else if (Constant* old = map.find(key)) + old->type = Constant::Type_Unknown; + } + + void recordValue(AstLocal* local, const Constant& value) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(local); + LUAU_ASSERT(v); + + if (!v->written) + { + v->constant = (value.type != Constant::Type_Unknown); + recordConstant(locals, local, value); + } + } + bool visit(AstExpr* node) override { // note: we short-circuit the visitor traversal through any expression trees by returning false @@ -336,18 +361,7 @@ struct ConstantVisitor : AstVisitor { Constant arg = analyze(node->values.data[i]); - if (arg.type != Constant::Type_Unknown) - { - // note: we rely on trackValues to have been run before us - Variable* v = variables.find(node->vars.data[i]); - LUAU_ASSERT(v); - - if (!v->written) - { - locals[node->vars.data[i]] = arg; - v->constant = true; - } - } + recordValue(node->vars.data[i], arg); } if (node->vars.size > node->values.size) @@ -361,15 +375,8 @@ struct ConstantVisitor : AstVisitor { for (size_t i = node->values.size; i < node->vars.size; ++i) { - // note: we rely on trackValues to have been run before us - Variable* v = variables.find(node->vars.data[i]); - LUAU_ASSERT(v); - - if (!v->written) - { - locals[node->vars.data[i]].type = Constant::Type_Nil; - v->constant = true; - } + Constant nil = {Constant::Type_Nil}; + recordValue(node->vars.data[i], nil); } } } diff --git a/Sources.cmake b/Sources.cmake index f9263b2..d2430cc 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -264,6 +264,7 @@ if(TARGET Luau.UnitTest) tests/TypePack.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp + tests/VisitTypeVar.test.cpp tests/main.cpp) endif() diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 1f3b094..f8baefa 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -1270,7 +1270,7 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) L->top--; setobj(L, val, L->top); luaC_barrier(L, clvalue(fi), L->top); - luaC_upvalbarrier(L, NULL, val); + luaC_upvalbarrier(L, cast_to(UpVal*, NULL), val); } return name; } diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 718d387..6014919 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -15,6 +15,8 @@ #include #endif +LUAU_FASTFLAGVARIABLE(LuauFixBuiltinsStackLimit, false) + // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -1003,7 +1005,7 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St else if (nparams == 3 && ttisnumber(args) && ttisnumber(args + 1) && nvalue(args) == 1.0) n = int(nvalue(args + 1)); - if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n) + if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n && (!FFlag::LuauFixBuiltinsStackLimit || n + nparams <= LUAI_MAXCSTACK)) { TValue* array = t->array; for (int i = 0; i < n; ++i) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 08d1ff5..797284a 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -120,7 +120,7 @@ #define luaC_upvalbarrier(L, uv, tv) \ { \ - if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || ((UpVal*)uv)->v != &((UpVal*)uv)->u.value)) \ + if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || (uv)->v != &(uv)->u.value)) \ luaC_barrierupval(L, gcvalue(tv)); \ } diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 3dc3bd1..8251b51 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -33,8 +33,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauTableNewBoundary2, false) - // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -431,7 +429,6 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) static int adjustasize(Table* t, int size, const TValue* ek) { - LUAU_ASSERT(FFlag::LuauTableNewBoundary2); bool tbound = t->node != dummynode || size < t->sizearray; int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; /* move the array size up until the boundary is guaranteed to be inside the array part */ @@ -443,7 +440,7 @@ static int adjustasize(Table* t, int size, const TValue* ek) void luaH_resizearray(lua_State* L, Table* t, int nasize) { int nsize = (t->node == dummynode) ? 0 : sizenode(t); - int asize = FFlag::LuauTableNewBoundary2 ? adjustasize(t, nasize, NULL) : nasize; + int asize = adjustasize(t, nasize, NULL); resize(L, t, asize, nsize); } @@ -468,8 +465,7 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) int na = computesizes(nums, &nasize); int nh = totaluse - na; /* enforce the boundary invariant; for performance, only do hash lookups if we must */ - if (FFlag::LuauTableNewBoundary2) - nasize = adjustasize(t, nasize, ek); + nasize = adjustasize(t, nasize, ek); /* resize the table to new computed sizes */ resize(L, t, nasize, nh); } @@ -531,7 +527,7 @@ static LuaNode* getfreepos(Table* t) static TValue* newkey(lua_State* L, Table* t, const TValue* key) { /* enforce boundary invariant */ - if (FFlag::LuauTableNewBoundary2 && ttisnumber(key) && nvalue(key) == t->sizearray + 1) + if (ttisnumber(key) && nvalue(key) == t->sizearray + 1) { rehash(L, t, key); /* grow table */ @@ -713,37 +709,6 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) } } -static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) -{ - LUAU_ASSERT(!FFlag::LuauTableNewBoundary2); - unsigned int i = j; /* i is zero or a present index */ - j++; - /* find `i' and `j' such that i is present and j is not */ - while (!ttisnil(luaH_getnum(t, j))) - { - i = j; - j *= 2; - if (j > cast_to(unsigned int, INT_MAX)) - { /* overflow? */ - /* table was built with bad purposes: resort to linear search */ - i = 1; - while (!ttisnil(luaH_getnum(t, i))) - i++; - return i - 1; - } - } - /* now do a binary search between them */ - while (j - i > 1) - { - unsigned int m = (i + j) / 2; - if (ttisnil(luaH_getnum(t, m))) - j = m; - else - i = m; - } - return i; -} - static int updateaboundary(Table* t, int boundary) { if (boundary < t->sizearray && ttisnil(&t->array[boundary - 1])) @@ -800,17 +765,12 @@ int luaH_getn(Table* t) maybesetaboundary(t, boundary); return boundary; } - else if (FFlag::LuauTableNewBoundary2) + else { /* validate boundary invariant */ LUAU_ASSERT(t->node == dummynode || ttisnil(luaH_getnum(t, j + 1))); return j; } - /* else must find a boundary in hash part */ - else if (t->node == dummynode) /* hash part is empty? */ - return j; /* that is easy... */ - else - return unbound_search(t, j); } Table* luaH_clone(lua_State* L, Table* tt) diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 106efb2..9b99506 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -37,6 +37,8 @@ const char* const luaT_eventname[] = { "__newindex", "__mode", "__namecall", + "__call", + "__iter", "__eq", @@ -54,13 +56,13 @@ const char* const luaT_eventname[] = { "__lt", "__le", "__concat", - "__call", "__type", }; // clang-format on static_assert(sizeof(luaT_typenames) / sizeof(luaT_typenames[0]) == LUA_T_COUNT, "luaT_typenames size mismatch"); static_assert(sizeof(luaT_eventname) / sizeof(luaT_eventname[0]) == TM_N, "luaT_eventname size mismatch"); +static_assert(TM_EQ < 8, "fasttm optimization stores a bitfield with metamethods in a byte"); void luaT_init(lua_State* L) { diff --git a/VM/src/ltm.h b/VM/src/ltm.h index 0e4e915..e1b95c2 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -16,6 +16,8 @@ typedef enum TM_NEWINDEX, TM_MODE, TM_NAMECALL, + TM_CALL, + TM_ITER, TM_EQ, /* last tag method with `fast' access */ @@ -33,7 +35,6 @@ typedef enum TM_LT, TM_LE, TM_CONCAT, - TM_CALL, TM_TYPE, TM_N /* number of elements in the enum */ diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 39c60ea..3c7c276 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,7 +16,10 @@ #include -LUAU_FASTFLAG(LuauTableNewBoundary2) +LUAU_FASTFLAGVARIABLE(LuauIter, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauIterCallTelemetry, false) + +void (*lua_iter_call_telemetry)(lua_State* L); // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -110,7 +113,7 @@ LUAU_FASTFLAG(LuauTableNewBoundary2) VM_DISPATCH_OP(LOP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ - VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), + VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), #if defined(__GNUC__) || defined(__clang__) #define VM_USE_CGOTO 1 @@ -150,8 +153,20 @@ LUAU_NOINLINE static void luau_prepareFORN(lua_State* L, StkId plimit, StkId pst LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) StkId ra = &L->base[a]; - LUAU_ASSERT(ra + 6 <= L->top); + LUAU_ASSERT(ra + 3 <= L->top); + + if (DFFlag::LuauIterCallTelemetry) + { + /* TODO: we might be able to stop supporting this depending on whether it's used in practice */ + void (*telemetrycb)(lua_State* L) = lua_iter_call_telemetry; + + if (telemetrycb && ttistable(ra) && fasttm(L, hvalue(ra)->metatable, TM_CALL)) + telemetrycb(L); + if (telemetrycb && ttisuserdata(ra) && fasttm(L, uvalue(ra)->metatable, TM_CALL)) + telemetrycb(L); + } setobjs2s(L, ra + 3 + 2, ra + 2); setobjs2s(L, ra + 3 + 1, ra + 1); @@ -2204,20 +2219,149 @@ static void luau_execute(lua_State* L) } } + VM_CASE(LOP_FORGPREP) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (ttisfunction(ra)) + { + /* will be called during FORGLOOP */ + } + else if (FFlag::LuauIter) + { + Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + + if (const TValue* fn = fasttm(L, mt, TM_ITER)) + { + setobj2s(L, ra + 1, ra); + setobj2s(L, ra, fn); + + L->top = ra + 2; /* func + self arg */ + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra, 3)); + L->top = L->ci->top; + } + else if (fasttm(L, mt, TM_CALL)) + { + /* table or userdata with __call, will be called during FORGLOOP */ + /* TODO: we might be able to stop supporting this depending on whether it's used in practice */ + } + else if (ttistable(ra)) + { + /* set up registers for builtin iteration */ + setobj2s(L, ra + 1, ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setnilvalue(ra); + } + else + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + VM_CASE(LOP_FORGLOOP) { VM_INTERRUPT(); Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); uint32_t aux = *pc; - // note: this is a slow generic path, fast-path is FORGLOOP_INEXT/NEXT - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); + if (!FFlag::LuauIter) + { + bool stop; + VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); - // note that we need to increment pc by 1 to exit the loop since we need to skip over aux - pc += stop ? 1 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); + // note that we need to increment pc by 1 to exit the loop since we need to skip over aux + pc += stop ? 1 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + // fast-path: builtin table iteration + if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) + { + Table* h = hvalue(ra + 1); + int index = int(reinterpret_cast(pvalue(ra + 2))); + + int sizearray = h->sizearray; + int sizenode = 1 << h->lsizenode; + + // clear extra variables since we might have more than two + if (LUAU_UNLIKELY(aux > 2)) + for (int i = 2; i < int(aux); ++i) + setnilvalue(ra + 3 + i); + + // first we advance index through the array portion + while (unsigned(index) < unsigned(sizearray)) + { + if (!ttisnil(&h->array[index])) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setnvalue(ra + 3, double(index + 1)); + setobj2s(L, ra + 4, &h->array[index]); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + index++; + } + + // then we advance index through the hash portion + while (unsigned(index - sizearray) < unsigned(sizenode)) + { + LuaNode* n = &h->node[index - sizearray]; + + if (!ttisnil(gval(n))) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + getnodekey(L, ra + 3, n); + setobj2s(L, ra + 4, gval(n)); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + index++; + } + + // fallthrough to exit + pc++; + VM_NEXT(); + } + else + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + setobjs2s(L, ra + 3 + 2, ra + 2); + setobjs2s(L, ra + 3 + 1, ra + 1); + setobjs2s(L, ra + 3, ra); + + L->top = ra + 3 + 3; /* func + 2 args (state and index) */ + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra + 3, aux)); + L->top = L->ci->top; + + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + + // copy first variable back into the iteration index + setobjs2s(L, ra + 2, ra + 3); + + // note that we need to increment pc by 1 to exit the loop since we need to skip over aux + pc += ttisnil(ra + 3) ? 1 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } } VM_CASE(LOP_FORGPREP_INEXT) @@ -2228,8 +2372,15 @@ static void luau_execute(lua_State* L) // fast-path: ipairs/inext if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) { + if (FFlag::LuauIter) + setnilvalue(ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } + else if (FFlag::LuauIter && !ttisfunction(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2268,23 +2419,9 @@ static void luau_execute(lua_State* L) VM_NEXT(); } } - else if (FFlag::LuauTableNewBoundary2 || (h->lsizenode == 0 && ttisnil(gval(h->node)))) - { - // fallthrough to exit - VM_NEXT(); - } else { - // the table has a hash part; index + 1 may appear in it in which case we need to iterate through the hash portion as well - const TValue* val = luaH_getnum(h, index + 1); - - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, val); - - // note that nil elements inside the array terminate the traversal - pc += ttisnil(ra + 4) ? 0 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + // fallthrough to exit VM_NEXT(); } } @@ -2308,8 +2445,15 @@ static void luau_execute(lua_State* L) // fast-path: pairs/next if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) { + if (FFlag::LuauIter) + setnilvalue(ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } + else if (FFlag::LuauIter && !ttisfunction(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); + } pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2704,7 +2848,7 @@ static void luau_execute(lua_State* L) { VM_PROTECT_PC(); - int n = f(L, ra, arg, nresults, nullptr, nparams); + int n = f(L, ra, arg, nresults, NULL, nparams); if (n >= 0) { diff --git a/bench/micro_tests/test_LargeTableSum_loop_iter.lua b/bench/micro_tests/test_LargeTableSum_loop_iter.lua new file mode 100644 index 0000000..057420f --- /dev/null +++ b/bench/micro_tests/test_LargeTableSum_loop_iter.lua @@ -0,0 +1,17 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local t = {} + + for i=1,1000000 do t[i] = i end + + local ts0 = os.clock() + local sum = 0 + for k,v in t do sum = sum + v end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "LargeTableSum: for k,v in {}") diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index 5d162ab..77fa085 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -25,7 +25,7 @@ local DisplArea = {} DisplArea.Width = 300; DisplArea.Height = 300; -function DrawLine(From, To) +local function DrawLine(From, To) local x1 = From.V[1]; local x2 = To.V[1]; local y1 = From.V[2]; @@ -81,7 +81,7 @@ function DrawLine(From, To) Q.LastPx = NumPix; end -function CalcCross(V0, V1) +local function CalcCross(V0, V1) local Cross = {}; Cross[1] = V0[2]*V1[3] - V0[3]*V1[2]; Cross[2] = V0[3]*V1[1] - V0[1]*V1[3]; @@ -89,7 +89,7 @@ function CalcCross(V0, V1) return Cross; end -function CalcNormal(V0, V1, V2) +local function CalcNormal(V0, V1, V2) local A = {}; local B = {}; for i = 1,3 do A[i] = V0[i] - V1[i]; @@ -102,14 +102,14 @@ function CalcNormal(V0, V1, V2) return A; end -function CreateP(X,Y,Z) +local function CreateP(X,Y,Z) local result = {} result.V = {X,Y,Z,1}; return result end -- multiplies two matrices -function MMulti(M1, M2) +local function MMulti(M1, M2) local M = {{},{},{},{}}; for i = 1,4 do for j = 1,4 do @@ -120,7 +120,7 @@ function MMulti(M1, M2) end -- multiplies matrix with vector -function VMulti(M, V) +local function VMulti(M, V) local Vect = {}; for i = 1,4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; @@ -128,7 +128,7 @@ function VMulti(M, V) return Vect; end -function VMulti2(M, V) +local function VMulti2(M, V) local Vect = {}; for i = 1,3 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; @@ -137,7 +137,7 @@ function VMulti2(M, V) end -- add to matrices -function MAdd(M1, M2) +local function MAdd(M1, M2) local M = {{},{},{},{}}; for i = 1,4 do for j = 1,4 do @@ -147,7 +147,7 @@ function MAdd(M1, M2) return M; end -function Translate(M, Dx, Dy, Dz) +local function Translate(M, Dx, Dy, Dz) local T = { {1,0,0,Dx}, {0,1,0,Dy}, @@ -157,7 +157,7 @@ function Translate(M, Dx, Dy, Dz) return MMulti(T, M); end -function RotateX(M, Phi) +local function RotateX(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -171,7 +171,7 @@ function RotateX(M, Phi) return MMulti(R, M); end -function RotateY(M, Phi) +local function RotateY(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -185,7 +185,7 @@ function RotateY(M, Phi) return MMulti(R, M); end -function RotateZ(M, Phi) +local function RotateZ(M, Phi) local a = Phi; a = a * math.pi / 180; local Cos = math.cos(a); @@ -199,7 +199,7 @@ function RotateZ(M, Phi) return MMulti(R, M); end -function DrawQube() +local function DrawQube() -- calc current normals local CurN = {}; local i = 5; @@ -245,7 +245,7 @@ function DrawQube() Q.LastPx = 0; end -function Loop() +local function Loop() if (Testing.LoopCount > Testing.LoopMax) then return; end local TestingStr = tostring(Testing.LoopCount); while (#TestingStr < 3) do TestingStr = "0" .. TestingStr; end @@ -265,7 +265,7 @@ function Loop() Loop(); end -function Init(CubeSize) +local function Init(CubeSize) -- init/reset vars Origin.V = {150,150,20,1}; Testing.LoopCount = 0; diff --git a/bench/tests/sunspider/3d-morph.lua b/bench/tests/sunspider/3d-morph.lua index f73f173..79e9141 100644 --- a/bench/tests/sunspider/3d-morph.lua +++ b/bench/tests/sunspider/3d-morph.lua @@ -31,7 +31,7 @@ local loops = 15 local nx = 120 local nz = 120 -function morph(a, f) +local function morph(a, f) local PI2nx = math.pi * 8/nx local sin = math.sin local f30 = -(50 * sin(f*math.pi*2)) diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index c8f6b5d..3d5276c 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -28,40 +28,40 @@ function test() local size = 30 -function createVector(x,y,z) +local function createVector(x,y,z) return { x,y,z }; end -function sqrLengthVector(self) +local function sqrLengthVector(self) return self[1] * self[1] + self[2] * self[2] + self[3] * self[3]; end -function lengthVector(self) +local function lengthVector(self) return math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); end -function addVector(self, v) +local function addVector(self, v) self[1] = self[1] + v[1]; self[2] = self[2] + v[2]; self[3] = self[3] + v[3]; return self; end -function subVector(self, v) +local function subVector(self, v) self[1] = self[1] - v[1]; self[2] = self[2] - v[2]; self[3] = self[3] - v[3]; return self; end -function scaleVector(self, scale) +local function scaleVector(self, scale) self[1] = self[1] * scale; self[2] = self[2] * scale; self[3] = self[3] * scale; return self; end -function normaliseVector(self) +local function normaliseVector(self) local len = math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); self[1] = self[1] / len; self[2] = self[2] / len; @@ -69,39 +69,39 @@ function normaliseVector(self) return self; end -function add(v1, v2) +local function add(v1, v2) return { v1[1] + v2[1], v1[2] + v2[2], v1[3] + v2[3] }; end -function sub(v1, v2) +local function sub(v1, v2) return { v1[1] - v2[1], v1[2] - v2[2], v1[3] - v2[3] }; end -function scalev(v1, v2) +local function scalev(v1, v2) return { v1[1] * v2[1], v1[2] * v2[2], v1[3] * v2[3] }; end -function dot(v1, v2) +local function dot(v1, v2) return v1[1] * v2[1] + v1[2] * v2[2] + v1[3] * v2[3]; end -function scale(v, scale) +local function scale(v, scale) return { v[1] * scale, v[2] * scale, v[3] * scale }; end -function cross(v1, v2) +local function cross(v1, v2) return { v1[2] * v2[3] - v1[3] * v2[2], v1[3] * v2[1] - v1[1] * v2[3], v1[1] * v2[2] - v1[2] * v2[1] }; end -function normalise(v) +local function normalise(v) local len = lengthVector(v); return { v[1] / len, v[2] / len, v[3] / len }; end -function transformMatrix(self, v) +local function transformMatrix(self, v) local vals = self; local x = vals[1] * v[1] + vals[2] * v[2] + vals[3] * v[3] + vals[4]; local y = vals[5] * v[1] + vals[6] * v[2] + vals[7] * v[3] + vals[8]; @@ -109,7 +109,7 @@ function transformMatrix(self, v) return { x, y, z }; end -function invertMatrix(self) +local function invertMatrix(self) local temp = {} local tx = -self[4]; local ty = -self[8]; @@ -131,7 +131,7 @@ function invertMatrix(self) end -- Triangle intersection using barycentric coord method -function Triangle(p1, p2, p3) +local function Triangle(p1, p2, p3) local this = {} local edge1 = sub(p3, p1); @@ -205,7 +205,7 @@ function Triangle(p1, p2, p3) return this end -function Scene(a_triangles) +local function Scene(a_triangles) local this = {} this.triangles = a_triangles; this.lights = {}; @@ -302,7 +302,7 @@ local zero = { 0,0,0 }; -- this camera code is from notes i made ages ago, it is from *somewhere* -- i cannot remember where -- that somewhere is -function Camera(origin, lookat, up) +local function Camera(origin, lookat, up) local this = {} local zaxis = normaliseVector(subVector(lookat, origin)); @@ -357,7 +357,7 @@ function Camera(origin, lookat, up) return this end -function raytraceScene() +local function raytraceScene() local startDate = 13154863; local numTriangles = 2 * 6; local triangles = {}; -- numTriangles); @@ -450,7 +450,7 @@ function raytraceScene() return pixels; end -function arrayToCanvasCommands(pixels) +local function arrayToCanvasCommands(pixels) local s = {}; table.insert(s, 'Test\nvar pixels = ['); for y = 0,size-1 do @@ -485,7 +485,7 @@ for (var y = 0; y < size; y++) {\n\ return table.concat(s); end -testOutput = arrayToCanvasCommands(raytraceScene()); +local testOutput = arrayToCanvasCommands(raytraceScene()); --local f = io.output("output.html") --f:write(testOutput) diff --git a/bench/tests/sunspider/access-binary-trees.lua b/bench/tests/sunspider/access-binary-trees.lua deleted file mode 100644 index 9eb9358..0000000 --- a/bench/tests/sunspider/access-binary-trees.lua +++ /dev/null @@ -1,69 +0,0 @@ ---[[ - The Great Computer Language Shootout - http://shootout.alioth.debian.org/ - contributed by Isaac Gouy -]] - -local bench = script and require(script.Parent.bench_support) or require("bench_support") - -function test() - -function TreeNode(left,right,item) - local this = {} - this.left = left; - this.right = right; - this.item = item; - - this.itemCheck = function(self) - if (self.left==nil) then return self.item; - else return self.item + self.left:itemCheck() - self.right:itemCheck(); end - end - - return this -end - -function bottomUpTree(item,depth) - if (depth>0) then - return TreeNode( - bottomUpTree(2*item-1, depth-1) - ,bottomUpTree(2*item, depth-1) - ,item - ); - else - return TreeNode(nil,nil,item); - end -end - -local ret = 0; - -for n = 4,7,1 do - local minDepth = 4; - local maxDepth = math.max(minDepth + 2, n); - local stretchDepth = maxDepth + 1; - - local check = bottomUpTree(0,stretchDepth):itemCheck(); - - local longLivedTree = bottomUpTree(0,maxDepth); - - for depth = minDepth,maxDepth,2 do - local iterations = 2.0 ^ (maxDepth - depth + minDepth - 1) -- 1 << (maxDepth - depth + minDepth); - - check = 0; - for i = 1,iterations do - check = check + bottomUpTree(i,depth):itemCheck(); - check = check + bottomUpTree(-i,depth):itemCheck(); - end - end - - ret = ret + longLivedTree:itemCheck(); -end - -local expected = -4; - -if (ret ~= expected) then - assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. ret); -end - -end - -bench.runCode(test, "access-binary-trees") diff --git a/bench/tests/sunspider/controlflow-recursive.lua b/bench/tests/sunspider/controlflow-recursive.lua index d079162..a2591b2 100644 --- a/bench/tests/sunspider/controlflow-recursive.lua +++ b/bench/tests/sunspider/controlflow-recursive.lua @@ -7,18 +7,18 @@ local bench = script and require(script.Parent.bench_support) or require("bench_ function test() -function ack(m,n) +local function ack(m,n) if (m==0) then return n+1; end if (n==0) then return ack(m-1,1); end return ack(m-1, ack(m,n-1) ); end -function fib(n) +local function fib(n) if (n < 2) then return 1; end return fib(n-2) + fib(n-1); end -function tak(x,y,z) +local function tak(x,y,z) if (y >= x) then return z; end return tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y)); end @@ -27,7 +27,7 @@ local result = 0; for i = 3,5 do result = result + ack(3,i); - result = result + fib(17.0+i); + result = result + fib(17+i); result = result + tak(3*i+3,2*i+2,i+1); end diff --git a/bench/tests/sunspider/crypto-aes.lua b/bench/tests/sunspider/crypto-aes.lua index 3b28972..8dd0cec 100644 --- a/bench/tests/sunspider/crypto-aes.lua +++ b/bench/tests/sunspider/crypto-aes.lua @@ -42,7 +42,68 @@ local Rcon = { { 0x00, 0x00, 0x00, 0x00 }, {0x1b, 0x00, 0x00, 0x00}, {0x36, 0x00, 0x00, 0x00} }; -function Cipher(input, w) -- main Cipher function [§5.1] +local function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] + for r = 0,3 do + for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end + end + return s; +end + + +local function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] + local t = {}; + for r = 1,3 do + for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy + for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back + end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): + return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf +end + + +local function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] + for c = 0,3 do + local a = {}; -- 'a' is a copy of the current column from 's' + local b = {}; -- 'b' is a•{02} in GF(2^8) + for i = 0,3 do + a[i + 1] = s[i + 1][c + 1]; + + if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then + b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); + else + b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); + end + end + -- a[n] ^ b[n] is a•{03} in GF(2^8) + s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3 + s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3 + s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3 + s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3 +end + return s; +end + + +local function SubWord(w) -- apply SBox to 4-byte word w + for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end + return w; +end + +local function RotWord(w) -- rotate 4-byte word w left by one byte + w[5] = w[1]; + for i = 0,3 do w[i + 1] = w[i + 2]; end + return w; +end + + + +local function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] + for r = 0,3 do + for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end + end + return state; +end + +local function Cipher(input, w) -- main Cipher function [§5.1] local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys @@ -69,56 +130,7 @@ function Cipher(input, w) -- main Cipher function [§5.1] end -function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] - for r = 0,3 do - for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end - end - return s; -end - - -function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] - local t = {}; - for r = 1,3 do - for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy - for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back - end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): - return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf -end - - -function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] - for c = 0,3 do - local a = {}; -- 'a' is a copy of the current column from 's' - local b = {}; -- 'b' is a•{02} in GF(2^8) - for i = 0,3 do - a[i + 1] = s[i + 1][c + 1]; - - if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then - b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); - else - b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); - end - end - -- a[n] ^ b[n] is a•{03} in GF(2^8) - s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3 - s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3 - s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3 - s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3 -end - return s; -end - - -function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] - for r = 0,3 do - for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end - end - return state; -end - - -function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] +local function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys @@ -146,17 +158,17 @@ function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from return w; end -function SubWord(w) -- apply SBox to 4-byte word w - for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end - return w; +local function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext + return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end); end -function RotWord(w) -- rotate 4-byte word w left by one byte - w[5] = w[1]; - for i = 0,3 do w[i + 1] = w[i + 2]; end - return w; -end +local function unescCtrlChars(str) -- unescape potentially problematic control characters + return string.gsub(str, "!%d%d?%d?!", function(c) + local sc = string.sub(c, 2,-2) + return string.char(tonumber(sc)); + end); +end --[[ * Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation @@ -166,7 +178,7 @@ end * - cipherblock = plaintext xor outputblock ]] -function AESEncryptCtr(plaintext, password, nBits) +local function AESEncryptCtr(plaintext, password, nBits) if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys -- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password; @@ -243,7 +255,7 @@ end * - cipherblock = plaintext xor outputblock ]] -function AESDecryptCtr(ciphertext, password, nBits) +local function AESDecryptCtr(ciphertext, password, nBits) if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys local nBytes = nBits/8; -- no bytes in key @@ -300,19 +312,7 @@ function AESDecryptCtr(ciphertext, password, nBits) return table.concat(plaintext) end -function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext - return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end); -end - -function unescCtrlChars(str) -- unescape potentially problematic control characters - return string.gsub(str, "!%d%d?%d?!", function(c) - local sc = string.sub(c, 2,-2) - - return string.char(tonumber(sc)); - end); -end - -function test() +local function test() local plainText = "ROMEO: But, soft! what light through yonder window breaks?\n\ It is the east, and Juliet is the sun.\n\ diff --git a/bench/tests/sunspider/math-cordic.lua b/bench/tests/sunspider/math-cordic.lua index 94a64f4..cdb10fa 100644 --- a/bench/tests/sunspider/math-cordic.lua +++ b/bench/tests/sunspider/math-cordic.lua @@ -31,15 +31,15 @@ function test() local AG_CONST = 0.6072529350; -function FIXED(X) +local function FIXED(X) return X * 65536.0; end -function FLOAT(X) +local function FLOAT(X) return X / 65536.0; end -function DEG2RAD(X) +local function DEG2RAD(X) return 0.017453 * (X); end @@ -52,7 +52,7 @@ local Angles = { local Target = 28.027; -function cordicsincos(Target) +local function cordicsincos(Target) local X; local Y; local TargetAngle; @@ -85,7 +85,7 @@ end local total = 0; -function cordic( runs ) +local function cordic( runs ) for i = 1,runs do total = total + cordicsincos(Target); end diff --git a/bench/tests/sunspider/math-partial-sums.lua b/bench/tests/sunspider/math-partial-sums.lua index 3c22287..9977cef 100644 --- a/bench/tests/sunspider/math-partial-sums.lua +++ b/bench/tests/sunspider/math-partial-sums.lua @@ -7,7 +7,7 @@ local bench = script and require(script.Parent.bench_support) or require("bench_ function test() -function partial(n) +local function partial(n) local a1, a2, a3, a4, a5, a6, a7, a8, a9 = 0, 0, 0, 0, 0, 0, 0, 0, 0; local twothirds = 2.0/3.0; local alt = -1.0; diff --git a/bench/tests/sunspider/math-spectral-norm.lua b/bench/tests/sunspider/math-spectral-norm.lua deleted file mode 100644 index 7d7ec16..0000000 --- a/bench/tests/sunspider/math-spectral-norm.lua +++ /dev/null @@ -1,72 +0,0 @@ ---[[ -The Great Computer Language Shootout -http://shootout.alioth.debian.org/ - -contributed by Ian Osgood -]] -local bench = script and require(script.Parent.bench_support) or require("bench_support") - -function test() - -function A(i,j) - return 1/((i+j)*(i+j+1)/2+i+1); -end - -function Au(u,v) - for i = 0,#u-1 do - local t = 0; - for j = 0,#u-1 do - t = t + A(i,j) * u[j + 1]; - end - v[i + 1] = t; - end -end - -function Atu(u,v) - for i = 0,#u-1 do - local t = 0; - for j = 0,#u-1 do - t = t + A(j,i) * u[j + 1]; - end - v[i + 1] = t; - end -end - -function AtAu(u,v,w) - Au(u,w); - Atu(w,v); -end - -function spectralnorm(n) - local u, v, w, vv, vBv = {}, {}, {}, 0, 0; - for i = 1,n do - u[i] = 1; v[i] = 0; w[i] = 0; - end - for i = 0,9 do - AtAu(u,v,w); - AtAu(v,u,w); - end - for i = 1,n do - vBv = vBv + u[i]*v[i]; - vv = vv + v[i]*v[i]; - end - return math.sqrt(vBv/vv); -end - -local total = 0; -local i = 6 - -while i <= 48 do - total = total + spectralnorm(i); - i = i * 2 -end - -local expected = 5.086694231303284; - -if (total ~= expected) then - assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. total) -end - -end - -bench.runCode(test, "math-spectral-norm") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 5b70481..1c284f1 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2760,8 +2760,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { - ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - check(R"( type tag = "cat" | "dog" local function f(a: tag) end @@ -2798,8 +2796,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") { - ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - check(R"( type tagged = {tag:"cat", fieldx:number} | {tag:"dog", fieldy:number} local x: tagged = {tag="cat", fieldx=2} @@ -2821,8 +2817,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") TEST_CASE_FIXTURE(ACFixture, "autocomplete_boolean_singleton") { - ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - check(R"( local function f(x: true) end f(@1) @@ -2838,8 +2832,6 @@ f(@1) TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") { - ScopedFastFlag luauAutocompleteSingletonTypes{"LuauAutocompleteSingletonTypes", true}; - check(R"( type tag = "strange\t\"cat\"" | 'nice\t"dog"' local function f(x: tag) end diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 7b4bfc7..f206438 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,6 +261,9 @@ RETURN R0 0 TEST_CASE("ForBytecode") { + ScopedFastFlag sff("LuauCompileIter", true); + ScopedFastFlag sff2("LuauCompileIterNoPairs", false); + // basic for loop: variable directly refers to internal iteration index (R2) CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( LOADN R2 1 @@ -295,7 +298,7 @@ GETIMPORT R0 2 LOADK R1 K3 LOADK R2 K4 CALL R0 2 3 -JUMP +4 +FORGPREP R0 +4 GETIMPORT R5 6 MOVE R6 R3 CALL R5 1 0 @@ -347,6 +350,8 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { + ScopedFastFlag sff("LuauCompileIter", true); + // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 @@ -385,7 +390,7 @@ GETIMPORT R0 3 MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 -JUMP +0 +FORGPREP R1 +0 FORGLOOP R1 -1 2 RETURN R0 0 )"); @@ -397,7 +402,7 @@ SETGLOBAL R0 K2 GETGLOBAL R0 K2 NEWTABLE R1 0 0 CALL R0 1 3 -JUMP +0 +FORGPREP R0 +0 FORGLOOP R0 -1 2 RETURN R0 0 )"); @@ -407,7 +412,7 @@ RETURN R0 0 GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 -JUMP +0 +FORGPREP R0 +0 FORGLOOP R0 -1 2 RETURN R0 0 )"); @@ -2260,6 +2265,8 @@ TEST_CASE("TypeAliasing") TEST_CASE("DebugLineInfo") { + ScopedFastFlag sff("LuauCompileIterNoPairs", false); + Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -2316,6 +2323,8 @@ return result TEST_CASE("DebugLineInfoFor") { + ScopedFastFlag sff("LuauCompileIter", true); + Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -2336,7 +2345,7 @@ end 5: LOADN R0 1 7: LOADN R1 2 9: LOADN R2 3 -9: JUMP +4 +9: FORGPREP R0 +4 11: GETIMPORT R5 1 11: MOVE R6 R3 11: CALL R5 1 0 @@ -2541,6 +2550,8 @@ a TEST_CASE("DebugSource") { + ScopedFastFlag sff("LuauCompileIterNoPairs", false); + const char* source = R"( local kSelectedBiomes = { ['Mountains'] = true, @@ -2616,6 +2627,8 @@ RETURN R1 1 TEST_CASE("DebugLocals") { + ScopedFastFlag sff("LuauCompileIterNoPairs", false); + const char* source = R"( function foo(e, f) local a = 1 @@ -3767,6 +3780,8 @@ RETURN R0 1 TEST_CASE("SharedClosure") { + ScopedFastFlag sff("LuauCompileIterNoPairs", false); + // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... @@ -4452,5 +4467,688 @@ RETURN R0 0 )"); } +TEST_CASE("InlineBasic") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // inline function that returns a constant + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return 42 +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // inline function that returns the argument + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // inline function that returns one of the two arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b, c) + if a then + return b + else + return c + end +end + +local x = foo(true, math.random(), 5) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 3 +CALL R2 0 1 +MOVE R1 R2 +RETURN R1 1 +RETURN R1 1 +)"); + + // inline function that returns one of the two arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b, c) + if a then + return b + else + return c + end +end + +local x = foo(true, 5, math.random()) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETIMPORT R2 3 +CALL R2 0 1 +LOADN R1 5 +RETURN R1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineMutate") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // if the argument is mutated, it gets a register even if the value is constant + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 5 + return a +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +ORK R2 R2 K1 +MOVE R1 R2 +RETURN R1 1 +)"); + + // if the argument is a local, it can be used directly + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = ... +local y = foo(x) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R2 R1 +RETURN R2 1 +)"); + + // ... but if it's mutated, we move it in case it is mutated through a capture during the inlined function + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = ... +x = nil +local y = foo(x) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +LOADNIL R1 +MOVE R3 R1 +MOVE R2 R3 +RETURN R2 1 +)"); + + // we also don't inline functions if they have been assigned to + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +foo = foo + +local x = foo(42) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R0 R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineUpval") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // if the argument is an upvalue, we naturally need to copy it to a local + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local b = ... + +function bar() + local x = foo(b) + return x +end +)", + 1, 2), + R"( +GETUPVAL R1 0 +MOVE R0 R1 +RETURN R0 1 +)"); + + // if the function uses an upvalue it's more complicated, because the lexical upvalue may become a local + CHECK_EQ("\n" + compileFunction(R"( +local b = ... + +local function foo(a) + return a + b +end + +local x = foo(42) +return x +)", + 1, 2), + R"( +GETVARARGS R0 1 +DUPCLOSURE R1 K0 +CAPTURE VAL R0 +LOADN R3 42 +ADD R2 R3 R0 +RETURN R2 1 +)"); + + // sometimes the lexical upvalue is deep enough that it's still an upvalue though + CHECK_EQ("\n" + compileFunction(R"( +local b = ... + +function bar() + local function foo(a) + return a + b + end + + local x = foo(42) + return x +end +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE UPVAL U0 +LOADN R2 42 +GETUPVAL R3 0 +ADD R1 R2 R3 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineFallthrough") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // if the function doesn't return, we still fill the results with nil + CHECK_EQ("\n" + compileFunction(R"( +local function foo() +end + +local a, b = foo() + +return a, b +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADNIL R2 +MOVE R3 R1 +MOVE R4 R2 +RETURN R3 2 +)"); + + // this happens even if the function returns conditionally + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if a then return 42 end +end + +local a, b = foo(false) + +return a, b +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADNIL R2 +MOVE R3 R1 +MOVE R4 R2 +RETURN R3 2 +)"); + + // note though that we can't inline a function like this in multret context + // this is because we don't have a SETTOP instruction + CHECK_EQ("\n" + compileFunction(R"( +local function foo() +end + +return foo() +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("InlineCapture") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // can't inline function with nested functions that capture locals because they might be constants + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + local function bar() + return a + end + return bar() +end +)", + 1, 2), + R"( +NEWCLOSURE R1 P0 +CAPTURE VAL R0 +MOVE R2 R1 +CALL R2 0 -1 +RETURN R2 -1 +)"); +} + +TEST_CASE("InlineArgMismatch") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // when inlining a function, we must respect all the usual rules + + // caller might not have enough arguments + // TODO: we don't inline this atm + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // caller might be using multret for arguments + // TODO: we don't inline this atm + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x = foo(math.modf(1.5)) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADK R3 K1 +FASTCALL1 20 R3 +2 +GETIMPORT R2 4 +CALL R2 1 -1 +CALL R1 -1 1 +RETURN R1 1 +)"); + + // caller might have too many arguments, but we still need to compute them for side effects + // TODO: we don't inline this atm + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local x = foo(42, print()) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +GETIMPORT R3 2 +CALL R3 0 -1 +CALL R1 -1 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineMultiple") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // we call this with a different set of variable/constant args + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x, y = ... +local a = foo(x, 1) +local b = foo(1, x) +local c = foo(1, 2) +local d = foo(x, y) +return a, b, c, d +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 2 +ADDK R3 R1 K1 +LOADN R5 1 +ADD R4 R5 R1 +LOADN R5 3 +ADD R6 R1 R2 +MOVE R7 R3 +MOVE R8 R4 +MOVE R9 R5 +MOVE R10 R6 +RETURN R7 4 +)"); +} + +TEST_CASE("InlineChain") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // inline a chain of functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local function bar(x) + return foo(x, 1) * foo(x, -1) +end + +local function baz() + return (bar(42)) +end + +return (baz()) +)", + 3, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +DUPCLOSURE R2 K2 +LOADN R4 43 +LOADN R5 41 +MUL R3 R4 R5 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineThresholds") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + ScopedFastInt sfis[] = { + {"LuauCompileInlineThreshold", 25}, + {"LuauCompileInlineThresholdMaxBoost", 300}, + {"LuauCompileInlineDepth", 2}, + }; + + // this function has enormous register pressure (50 regs) so we choose not to inline it + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return {{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}} +end + +return (foo()) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // this function has less register pressure but a large cost + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + return {},{},{},{},{} +end + +return (foo()) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // this chain of function is of length 3 but our limit in this test is 2, so we call foo twice + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local function bar(x) + return foo(x, 1) * foo(x, -1) +end + +local function baz() + return (bar(42)) +end + +return (baz()) +)", + 3, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +DUPCLOSURE R2 K2 +MOVE R4 R0 +LOADN R5 42 +LOADN R6 1 +CALL R4 2 1 +MOVE R5 R0 +LOADN R6 42 +LOADN R7 -1 +CALL R5 2 1 +MUL R3 R4 R5 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineIIFE") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // IIFE with arguments + CHECK_EQ("\n" + compileFunction(R"( +function choose(a, b, c) + return ((function(a, b, c) if a then return b else return c end end)(a, b, c)) +end +)", + 1, 2), + R"( +JUMPIFNOT R0 +2 +MOVE R3 R1 +RETURN R3 1 +MOVE R3 R2 +RETURN R3 1 +RETURN R3 1 +)"); + + // IIFE with upvalues + CHECK_EQ("\n" + compileFunction(R"( +function choose(a, b, c) + return ((function() if a then return b else return c end end)()) +end +)", + 1, 2), + R"( +JUMPIFNOT R0 +2 +MOVE R3 R1 +RETURN R3 1 +MOVE R3 R2 +RETURN R3 1 +RETURN R3 1 +)"); +} + +TEST_CASE("InlineRecurseArguments") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + // we can't inline a function if it's used to compute its own arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) +end +foo(foo(foo,foo(foo,foo))[foo]) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +MOVE R4 R0 +MOVE R5 R0 +MOVE R6 R0 +CALL R4 2 1 +LOADNIL R3 +GETTABLE R2 R3 R0 +CALL R1 1 0 +RETURN R0 0 +)"); +} + +TEST_CASE("InlineFastCallK") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + CHECK_EQ("\n" + compileFunction(R"( +local function set(l0) + rawset({}, l0) +end + +set(false) +set({}) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +NEWTABLE R2 0 0 +FASTCALL2K 49 R2 K1 +4 +LOADK R3 K1 +GETIMPORT R1 3 +CALL R1 2 0 +NEWTABLE R1 0 0 +NEWTABLE R3 0 0 +FASTCALL2 49 R3 R1 +4 +MOVE R4 R1 +GETIMPORT R2 3 +CALL R2 2 0 +RETURN R0 0 +)"); +} + +TEST_CASE("InlineExprIndexK") +{ + ScopedFastFlag sff("LuauCompileSupportInlining", true); + + CHECK_EQ("\n" + compileFunction(R"( +local _ = function(l0) +local _ = nil +while _(_)[_] do +end +end +local _ = _(0)[""] +if _ then +do +for l0=0,8 do +end +end +elseif _ then +_ = nil +do +for l0=0,8 do +return true +end +end +end +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R4 +LOADNIL R5 +CALL R4 1 1 +LOADNIL R5 +GETTABLE R3 R4 R5 +JUMPIFNOT R3 +1 +JUMPBACK -7 +LOADNIL R2 +GETTABLEKS R1 R2 K1 +JUMPIFNOT R1 +1 +RETURN R0 0 +JUMPIFNOT R1 +19 +LOADNIL R1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +LOADB R2 1 +RETURN R2 1 +RETURN R0 0 +)"); +} TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 6f136d3..a23ea47 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -241,6 +241,8 @@ TEST_CASE("Math") TEST_CASE("Table") { + ScopedFastFlag sff("LuauFixBuiltinsStackLimit", true); + runConformance("nextvar.lua"); } @@ -1099,4 +1101,14 @@ TEST_CASE("UserdataApi") CHECK(dtorhits == 42); } +TEST_CASE("Iter") +{ + ScopedFastFlag sffs[] = { + { "LuauCompileIter", true }, + { "LuauIter", true }, + }; + + runConformance("iter.lua"); +} + TEST_SUITE_END(); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index e771b6b..a10e8f7 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -386,8 +386,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_error_paths") TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") { - ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true}; - fileResolver.source["game/A"] = R"( return {hello = 2} )"; @@ -410,8 +408,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface") TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") { - ScopedFastFlag luauCyclicModuleTypeSurface{"LuauCyclicModuleTypeSurface", true}; - fileResolver.source["game/A"] = R"( return {mod_a = 2} )"; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 55eafe3..69ff73a 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2041,8 +2041,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_errors") { - ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true}; - matchParseError("type Y = {a: T..., b: number}", "Unexpected '...' after type name; type pack is not allowed in this context", Location{{0, 20}, {0, 23}}); matchParseError("type Y = {a: (number | string)...", "Unexpected '...' after type annotation", Location{{0, 36}, {0, 39}}); @@ -2618,8 +2616,6 @@ type Y = (T...) -> U... TEST_CASE_FIXTURE(Fixture, "recover_unexpected_type_pack") { - ScopedFastFlag luauParseRecoverUnexpectedPack{"LuauParseRecoverUnexpectedPack", true}; - ParseResult result = tryParse(R"( type X = { a: T..., b: number } type Y = { a: T..., b: number } diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 0b615e1..538f357 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -19,8 +19,8 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); struct LimitFixture : Fixture { -#if defined(_NOOPT) - ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 150}; +#if defined(_NOOPT) || defined(_DEBUG) + ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; #endif ScopedFastFlag LuauJustOneCallFrameForHaveSeen{"LuauJustOneCallFrameForHaveSeen", true}; @@ -35,12 +35,17 @@ bool hasError(const CheckResult& result, T* = nullptr) return it != result.errors.end(); } -TEST_SUITE_BEGIN("RuntimeLimitTests"); +TEST_SUITE_BEGIN("RuntimeLimits"); -TEST_CASE_FIXTURE(LimitFixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) +TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") { constexpr const char* src = R"LUA( --!strict + + -- Big thanks to Dionysusnu by letting us use this code as part of our test suite! + -- https://github.com/Dionysusnu/rbxts-rust-classes + -- Licensed under the MPL 2.0: https://raw.githubusercontent.com/Dionysusnu/rbxts-rust-classes/master/LICENSE + local TS = _G[script] local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 960c6ed..f9b510c 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -488,4 +488,71 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") )"); } +TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") +{ + ScopedFastFlag sff{"LuauTypecheckIter", true}; + + CheckResult result = check(R"( + local t: {string} = {} + local key + for k: number in t do + end + for k: number, v: string in t do + end + for k, v in t do + key = k + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + CHECK_EQ(*typeChecker.numberType, *requireType("key")); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") +{ + ScopedFastFlag sff{"LuauTypecheckIter", true}; + + CheckResult result = check(R"( + local t: {string} = {} + local extra + for k, v, e in t do + extra = e + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + CHECK_EQ(*typeChecker.nilType, *requireType("extra")); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") +{ + ScopedFastFlag sff{"LuauTypecheckIter", true}; + + CheckResult result = check(R"( + local t = {} + for k, v in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Cannot iterate over a table without indexer", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "loop_iter_iter_metamethod") +{ + ScopedFastFlag sff{"LuauTypecheckIter", true}; + + CheckResult result = check(R"( + local t = {} + setmetatable(t, { __iter = function(o) return next, o.children end }) + for k: number, v: string in t do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index fa1f519..b6f49f9 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -5,7 +5,6 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" -#include "Luau/VisitTypeVar.h" #include "Fixture.h" diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 5bd522a..8e53599 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2331,7 +2331,7 @@ TEST_CASE_FIXTURE(Fixture, "confusing_indexing") TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") { - ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; + ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter2", true}; CheckResult result = check(R"( local a: {x: number, y: number, [any]: any} | {y: number} @@ -2351,7 +2351,7 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") { - ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; + ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter2", true}; CheckResult result = check(R"( local a: {y: number} | {x: number, y: number, [any]: any} diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index a578b1c..e81ef1a 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1034,4 +1034,45 @@ TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") LUAU_REQUIRE_NO_ERRORS(result); } +/** + * The problem we had here was that the type of q in B.h was initially inferring to {} | {prop: free} before we bound + * that second table to the enclosing union. + */ +TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_table") +{ + ScopedFastFlag flag[] = { + {"LuauStatFunctionSimplify4", true}, + {"LuauLowerBoundsCalculation", true}, + {"LuauDifferentOrderOfUnificationDoesntMatter2", true}, + }; + + CheckResult result = check(R"( + --!strict + + local A = {} + + function A:f() + local t = {} + + for key, value in pairs(self) do + t[key] = value + end + + return t + end + + local B = A:f() + + function B.g(t) + assert(type(t) == "table") + assert(t.prop ~= nil) + end + + function B.h(q) + q = q or {} + return q or {} + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index b6e9326..8756264 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -242,8 +242,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { - ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true}; - TypeId a = arena.addType(TypeVar{FreeTypeVar{TypeLevel{}}}); TypeId b = typeChecker.numberType; @@ -255,8 +253,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") { - ScopedFastFlag luauTxnLogPreserveOwner{"LuauTxnLogPreserveOwner", true}; - TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); TypePackId b = typeChecker.anyTypePack; diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index e033fe2..a45af39 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -313,23 +313,33 @@ TEST_CASE("tagging_props") CHECK(Luau::hasTag(prop, "foo")); } -struct VisitCountTracker +struct VisitCountTracker final : TypeVarOnceVisitor { std::unordered_map tyVisits; std::unordered_map tpVisits; - void cycle(TypeId) {} - void cycle(TypePackId) {} + void cycle(TypeId) override {} + void cycle(TypePackId) override {} template bool operator()(TypeId ty, const T& t) + { + return visit(ty); + } + + template + bool operator()(TypePackId tp, const T&) + { + return visit(tp); + } + + bool visit(TypeId ty) override { tyVisits[ty]++; return true; } - template - bool operator()(TypePackId tp, const T&) + bool visit(TypePackId tp) override { tpVisits[tp]++; return true; @@ -348,7 +358,7 @@ local b: (T, T, T) -> T VisitCountTracker tester; DenseHashSet seen{nullptr}; - visitTypeVarOnce(bType, tester, seen); + DEPRECATED_visitTypeVarOnce(bType, tester, seen); for (auto [_, count] : tester.tyVisits) CHECK_EQ(count, 1); diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp new file mode 100644 index 0000000..3d426f1 --- /dev/null +++ b/tests/VisitTypeVar.test.cpp @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "Luau/RecursionCounter.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauUseVisitRecursionLimit) +LUAU_FASTINT(LuauVisitRecursionLimit) + +struct VisitTypeVarFixture : Fixture +{ + ScopedFastFlag flag1 = {"LuauUseVisitRecursionLimit", true}; + ScopedFastFlag flag2 = {"LuauRecursionLimitException", true}; +}; + +TEST_SUITE_BEGIN("VisitTypeVar"); + +TEST_CASE_FIXTURE(VisitTypeVarFixture, "throw_when_limit_is_exceeded") +{ + ScopedFastInt sfi{"LuauVisitRecursionLimit", 3}; + + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + + TypeId tType = requireType("t"); + + CHECK_THROWS_AS(toString(tType), RecursionLimitException); +} + +TEST_CASE_FIXTURE(VisitTypeVarFixture, "dont_throw_when_limit_is_high_enough") +{ + ScopedFastInt sfi{"LuauVisitRecursionLimit", 8}; + + CheckResult result = check(R"( + local t : {a: {b: {c: {d: {e: boolean}}}}} + )"); + + TypeId tType = requireType("t"); + + (void)toString(tType); +} + +TEST_SUITE_END(); diff --git a/tests/conformance/iter.lua b/tests/conformance/iter.lua new file mode 100644 index 0000000..468ffaf --- /dev/null +++ b/tests/conformance/iter.lua @@ -0,0 +1,196 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing iteration') + +-- basic for loop tests +do + local a + for a,b in pairs{} do error("not here") end + for i=1,0 do error("not here") end + for i=0,1,-1 do error("not here") end + a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) + a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) + a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) +end + +-- precision tests for for loops +do + local a + --a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) + a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) + a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) + a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) + a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) + a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) + a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) +end + +-- for loops do string->number coercion +do + local a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) +end + +-- generic for with function iterators +do + local function f (n, p) + local t = {}; for i=1,p do t[i] = i*10 end + return function (_,n) + if n > 0 then + n = n-1 + return n, unpack(t) + end + end, nil, n + end + + local x = 0 + for n,a,b,c,d in f(5,3) do + x = x+1 + assert(a == 10 and b == 20 and c == 30 and d == nil) + end + assert(x == 5) +end + +-- generic for with __call (tables) +do + local f = {} + setmetatable(f, { __call = function(_, _, n) if n > 0 then return n - 1 end end }) + + local x = 0 + for n in f, nil, 5 do + x += n + end + assert(x == 10) +end + +-- generic for with __call (userdata) +do + local f = newproxy(true) + getmetatable(f).__call = function(_, _, n) if n > 0 then return n - 1 end end + + local x = 0 + for n in f, nil, 5 do + x += n + end + assert(x == 10) +end + +-- generic for with pairs +do + local x = 0 + for k, v in pairs({a = 1, b = 2, c = 3}) do + x += v + end + assert(x == 6) +end + +-- generic for with pairs with holes +do + local x = 0 + for k, v in pairs({1, 2, 3, nil, 5}) do + x += v + end + assert(x == 11) +end + +-- generic for with ipairs +do + local x = 0 + for k, v in ipairs({1, 2, 3, nil, 5}) do + x += v + end + assert(x == 6) +end + +-- generic for with __iter (tables) +do + local f = {} + setmetatable(f, { __iter = function(x) + assert(f == x) + return next, {1, 2, 3, 4} + end }) + + local x = 0 + for n in f do + x += n + end + assert(x == 10) +end + +-- generic for with __iter (userdata) +do + local f = newproxy(true) + getmetatable(f).__iter = function(x) + assert(f == x) + return next, {1, 2, 3, 4} + end + + local x = 0 + for n in f do + x += n + end + assert(x == 10) +end + +-- generic for with tables (dictionary) +do + local x = 0 + for k, v in {a = 1, b = 2, c = 3} do + print(k, v) + x += v + end + assert(x == 6) +end + +-- generic for with tables (arrays) +do + local x = '' + for k, v in {1, 2, 3, nil, 5} do + x ..= tostring(v) + end + assert(x == "1235") +end + +-- generic for with tables (mixed) +do + local x = 0 + for k, v in {1, 2, 3, nil, 5, a = 1, b = 2, c = 3} do + x += v + end + assert(x == 17) +end + +-- generic for over a non-iterable object +do + local ok, err = pcall(function() for x in 42 do end end) + assert(not ok and err:match("attempt to iterate")) +end + +-- generic for over an iterable object that doesn't return a function +do + local obj = {} + setmetatable(obj, { __iter = function() end }) + + local ok, err = pcall(function() for x in obj do end end) + assert(not ok and err:match("attempt to call a nil value")) +end + +-- it's okay to iterate through a table with a single variable +do + local x = 0 + for k in {1, 2, 3, 4, 5} do + x += k + end + assert(x == 15) +end + +-- all extra variables should be set to nil during builtin traversal +do + local x = 0 + for k,v,a,b,c,d,e in {1, 2, 3, 4, 5} do + x += k + assert(a == nil and b == nil and c == nil and d == nil and e == nil) + end + assert(x == 15) +end + +return"OK" diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index c817645..0dba8fa 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -368,48 +368,6 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil) assert(next({}) == nil) assert(next({}, nil) == nil) -for a,b in pairs{} do error("not here") end -for i=1,0 do error("not here") end -for i=0,1,-1 do error("not here") end -a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) -a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) - -a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) --- precision problems ---a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) -a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) -a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) -a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) -a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) -a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) -a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) - --- conversion -a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) - - -collectgarbage() - - --- testing generic 'for' - -local function f (n, p) - local t = {}; for i=1,p do t[i] = i*10 end - return function (_,n) - if n > 0 then - n = n-1 - return n, unpack(t) - end - end, nil, n -end - -local x = 0 -for n,a,b,c,d in f(5,3) do - x = x+1 - assert(a == 10 and b == 20 and c == 30 and d == nil) -end -assert(x == 5) - -- testing table.create and table.find do local t = table.create(5) @@ -596,4 +554,17 @@ do assert(#t2 == 6) end +-- test table.unpack fastcall for rejecting large unpacks +do + local ok, res = pcall(function() + local a = table.create(7999, 0) + local b = table.create(8000, 0) + + local at = { table.unpack(a) } + local bt = { table.unpack(b) } + end) + + assert(not ok) +end + return"OK" diff --git a/tools/lldb_formatters.py b/tools/lldb_formatters.py index b3d2b4f..ff610d0 100644 --- a/tools/lldb_formatters.py +++ b/tools/lldb_formatters.py @@ -97,7 +97,7 @@ class LuauVariantSyntheticChildrenProvider: if self.current_type: storage = self.valobj.GetChildMemberWithName("storage") - self.stored_value = storage.Cast(self.current_type.GetPointerType()).Dereference() + self.stored_value = storage.Cast(self.current_type) else: self.stored_value = None else: From 298b33859b67014d76542b8a49c9be6b95730600 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 13 May 2022 12:16:50 -0700 Subject: [PATCH 09/19] Sync to upstream/release/527 --- Analysis/include/Luau/TypeVar.h | 2 +- Analysis/src/Clone.cpp | 4 +- Analysis/src/Normalize.cpp | 3 +- Analysis/src/Substitution.cpp | 4 +- Analysis/src/ToString.cpp | 16 +- Analysis/src/TypeInfer.cpp | 191 ++++++++------------- Ast/src/Parser.cpp | 3 +- CMakeLists.txt | 6 + Compiler/src/Compiler.cpp | 30 +++- Compiler/src/ConstantFolding.cpp | 6 +- Compiler/src/CostModel.cpp | 2 +- extern/isocline/src/bbcode.c | 1 + tests/AstQuery.test.cpp | 2 +- tests/Autocomplete.test.cpp | 44 +++-- tests/BuiltinDefinitions.test.cpp | 4 +- tests/Compiler.test.cpp | 71 ++++++-- tests/CostModel.test.cpp | 14 +- tests/Fixture.cpp | 19 +- tests/Fixture.h | 5 + tests/Frontend.test.cpp | 2 +- tests/Linter.test.cpp | 6 +- tests/Module.test.cpp | 4 +- tests/NonstrictMode.test.cpp | 6 +- tests/Normalize.test.cpp | 7 +- tests/Parser.test.cpp | 2 - tests/RuntimeLimits.test.cpp | 2 +- tests/ToDot.test.cpp | 2 +- tests/ToString.test.cpp | 4 +- tests/TypeInfer.aliases.test.cpp | 10 +- tests/TypeInfer.annotations.test.cpp | 8 +- tests/TypeInfer.anyerror.test.cpp | 4 +- tests/TypeInfer.builtins.test.cpp | 117 +++++++------ tests/TypeInfer.classes.test.cpp | 2 +- tests/TypeInfer.functions.test.cpp | 38 ++-- tests/TypeInfer.generics.test.cpp | 12 +- tests/TypeInfer.intersectionTypes.test.cpp | 6 +- tests/TypeInfer.loops.test.cpp | 28 +-- tests/TypeInfer.modules.test.cpp | 28 +-- tests/TypeInfer.oop.test.cpp | 4 +- tests/TypeInfer.operators.test.cpp | 71 ++++++-- tests/TypeInfer.provisional.test.cpp | 10 +- tests/TypeInfer.refinements.test.cpp | 20 +-- tests/TypeInfer.singletons.test.cpp | 2 +- tests/TypeInfer.tables.test.cpp | 74 ++++---- tests/TypeInfer.test.cpp | 17 +- tests/TypeInfer.tryUnify.test.cpp | 2 +- tests/TypeInfer.typePacks.cpp | 6 +- tests/TypeInfer.unionTypes.test.cpp | 6 +- 48 files changed, 511 insertions(+), 416 deletions(-) diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 8457675..9cacbc6 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -329,7 +329,7 @@ struct TableTypeVar // We need to know which is which when we stringify types. std::optional syntheticName; - std::map methodDefinitionLocations; + std::map methodDefinitionLocations; // TODO: Remove with FFlag::LuauNoMethodLocations std::vector instantiatedTypeParams; std::vector instantiatedTypePackParams; ModuleName definitionModuleName; diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index d5bd9da..1aa556e 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -11,6 +11,7 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAGVARIABLE(LuauLosslessClone, false) +LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { @@ -277,7 +278,8 @@ void TypeCloner::operator()(const TableTypeVar& t) } ttv->definitionModuleName = t.definitionModuleName; - ttv->methodDefinitionLocations = t.methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + ttv->methodDefinitionLocations = t.methodDefinitionLocations; ttv->tags = t.tags; } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index d8c1138..ef5377a 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineIntersectionFix, false); namespace Luau { @@ -863,7 +862,7 @@ struct Normalize final : TypeVarVisitor TypeId theTable = result->parts.back(); - if (!get(FFlag::LuauNormalizeCombineIntersectionFix ? follow(theTable) : theTable)) + if (!get(follow(theTable))) { result->parts.push_back(arena.addType(TableTypeVar{TableState::Sealed, TypeLevel{}})); theTable = result->parts.back(); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 30d8574..c5c7977 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -12,6 +12,7 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(LuauTypecheckOptPass) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowPossibleMutations, false) +LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { @@ -408,7 +409,8 @@ TypeId Substitution::clone(TypeId ty) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index b5d6a55..51665f7 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -248,6 +248,13 @@ struct StringifierState result.name += s; } + void emit(TypeLevel level) + { + emit(std::to_string(level.level)); + emit("-"); + emit(std::to_string(level.subLevel)); + } + void emit(const char* s) { if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) @@ -379,7 +386,7 @@ struct TypeVarStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(std::to_string(ftv.level.level)); + state.emit(ftv.level); } } @@ -403,7 +410,10 @@ struct TypeVarStringifier { state.result.invalid = true; - state.emit("[["); + state.emit("["); + if (FFlag::DebugLuauVerboseTypeNames) + state.emit(ctv.level); + state.emit("["); bool first = true; for (TypeId ty : ctv.parts) @@ -947,7 +957,7 @@ struct TypePackStringifier if (FFlag::DebugLuauVerboseTypeNames) { state.emit("-"); - state.emit(std::to_string(pack.level.level)); + state.emit(pack.level); } state.emit("..."); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4466ede..a13abd5 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,13 +36,11 @@ LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauInferStatFunction, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify4, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) @@ -53,12 +51,13 @@ LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) -LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAG(LuauLosslessClone) LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); +LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) +LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); namespace Luau { @@ -587,7 +586,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A { std::optional expectedType; - if (FFlag::LuauInferStatFunction && !fun->func->self) + if (!fun->func->self) { if (auto name = fun->name->as()) { @@ -1307,7 +1306,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify4) + else if (auto name = function.name->as()) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); @@ -1341,7 +1340,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify4) + else { LUAU_ASSERT(function.name->is()); @@ -1349,71 +1348,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); } - else if (function.func->self) - { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify4); - - AstExprIndexName* indexName = function.name->as(); - if (!indexName) - ice("member function declaration has malformed name expression"); - - TypeId selfTy = checkExpr(scope, *indexName->expr).type; - TableTypeVar* tableSelf = getMutableTableType(selfTy); - if (!tableSelf) - { - if (isTableIntersection(selfTy)) - reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); - else if (!get(selfTy) && !get(selfTy)) - reportError(TypeError{function.location, OnlyTablesCanHaveMethods{selfTy}}); - } - else if (tableSelf->state == TableState::Sealed) - reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); - - const bool tableIsExtendable = tableSelf && tableSelf->state != TableState::Sealed; - - ty = follow(ty); - - if (tableIsExtendable) - tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; - - const FunctionTypeVar* funTy = get(ty); - if (!funTy) - ice("Methods should be functions"); - - std::optional arg0 = first(funTy->argTypes); - if (!arg0) - ice("Methods should always have at least 1 argument (self)"); - - checkFunctionBody(funScope, ty, *function.func); - - if (tableIsExtendable) - tableSelf->props[indexName->index.value] = { - follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; - } - else - { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify4); - - TypeId leftType = checkLValueBinding(scope, *function.name); - - checkFunctionBody(funScope, ty, *function.func); - - unify(ty, leftType, function.location); - - LUAU_ASSERT(function.name->is() || function.name->is()); - - if (auto exprIndexName = function.name->as()) - { - if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) - { - if (auto ttv = getMutableTableType(*typeIt)) - { - if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) - it->second.type = follow(quantify(funScope, leftType, function.name->location)); - } - } - } - } } void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) @@ -1523,7 +1457,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; @@ -2652,13 +2587,58 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); - // TODO: this check seems odd, the second part is redundant - // is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable) - if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + if (FFlag::LuauSuccessTypingForEqualityOperations) { - reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorRecoveryType(booleanType); + if (leftMetatable != rightMetatable) + { + bool matches = false; + if (isEquality) + { + if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) + { + for (TypeId leftOption : utv) + { + if (getMetatable(follow(leftOption)) == rightMetatable) + { + matches = true; + break; + } + } + } + + if (!matches) + { + if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) + { + for (TypeId rightOption : utv) + { + if (getMetatable(follow(rightOption)) == leftMetatable) + { + matches = true; + break; + } + } + } + } + } + + + if (!matches) + { + reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + return errorRecoveryType(booleanType); + } + } + } + else + { + if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + { + reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + return errorRecoveryType(booleanType); + } } if (leftMetatable) @@ -2754,22 +2734,11 @@ TypeId TypeChecker::checkBinaryOperation( lhsType = follow(lhsType); rhsType = follow(rhsType); - if (FFlag::LuauDecoupleOperatorInferenceFromUnifiedTypeInference) + if (!isNonstrictMode() && get(lhsType)) { - if (!isNonstrictMode() && get(lhsType)) - { - auto name = getIdentifierOfBaseVar(expr.left); - reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - // We will fall-through to the `return anyType` check below. - } - } - else - { - if (!isNonstrictMode() && get(lhsType)) - { - auto name = getIdentifierOfBaseVar(expr.left); - reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - } + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); + // We will fall-through to the `return anyType` check below. } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -3231,43 +3200,27 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T else if (auto indexName = funName.as()) { TypeId lhsType = checkExpr(scope, *indexName->expr).type; - - if (!FFlag::LuauStatFunctionSimplify4 && (get(lhsType) || get(lhsType))) - return lhsType; - TableTypeVar* ttv = getMutableTableType(lhsType); - if (FFlag::LuauStatFunctionSimplify4) + if (!ttv || ttv->state == TableState::Sealed) { - if (!ttv || ttv->state == TableState::Sealed) - { - if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) - return *ty; + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) + return *ty; - return errorRecoveryType(scope); - } - } - else - { - if (!ttv || lhsType->persistent || ttv->state == TableState::Sealed) - return errorRecoveryType(scope); + return errorRecoveryType(scope); } Name name = indexName->index.value; if (ttv->props.count(name)) - { - if (FFlag::LuauStatFunctionSimplify4) - return ttv->props[name].type; - else - return errorRecoveryType(scope); - } + return ttv->props[name].type; Property& property = ttv->props[name]; property.type = freshTy(); property.location = indexName->indexLocation; - ttv->methodDefinitionLocations[name] = funName.location; + if (!FFlag::LuauNoMethodLocations) + ttv->methodDefinitionLocations[name] = funName.location; return property.type; } else if (funName.is()) @@ -4669,7 +4622,8 @@ TypeId ReplaceGenerics::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; return addType(std::move(clone)); } @@ -4715,7 +4669,8 @@ TypeId Anyification::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 91f5cd2..c053e6b 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false) namespace Luau { @@ -2821,7 +2820,7 @@ void Parser::nextLexeme() hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } - type = lexer.next(/* skipComments= */ false, !FFlag::LuauParseLocationIgnoreCommentSkipInCapture).type; + type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type; } } else diff --git a/CMakeLists.txt b/CMakeLists.txt index af03b33..ea35230 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -110,6 +110,12 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) endif() +if(MSVC AND LUAU_BUILD_CLI) + # the default stack size that MSVC linker uses is 1 MB; we need more stack space in Debug because stack frames are larger + set_target_properties(Luau.Analyze.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) + set_target_properties(Luau.Repl.CLI PROPERTIES LINK_FLAGS_DEBUG /STACK:2097152) +endif() + # embed .natvis inside the library debug information if(MSVC) target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis) diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 4fe2622..e177e92 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -628,7 +628,12 @@ struct Compiler return; if (fi && !fi->canInline) - bytecode.addDebugRemark("inlining failed: complex constructs in function body"); + { + if (func->vararg) + bytecode.addDebugRemark("inlining failed: function is variadic"); + else + bytecode.addDebugRemark("inlining failed: complex constructs in function body"); + } } RegScope rs(this); @@ -2342,17 +2347,28 @@ struct Compiler RegScope rs(this); uint8_t temp = 0; + bool consecutive = false; bool multRet = false; - // Optimization: return local value directly instead of copying it into a temporary - if (stat->list.size == 1 && isExprLocalReg(stat->list.data[0])) + // Optimization: return locals directly instead of copying them into a temporary + // this is very important for a single return value and occasionally effective for multiple values + if (stat->list.size > 0 && isExprLocalReg(stat->list.data[0])) { - AstExprLocal* le = stat->list.data[0]->as(); - LUAU_ASSERT(le); + temp = getLocal(stat->list.data[0]->as()->local); + consecutive = true; - temp = getLocal(le->local); + for (size_t i = 1; i < stat->list.size; ++i) + { + AstExpr* v = stat->list.data[i]; + if (!isExprLocalReg(v) || getLocal(v->as()->local) != temp + i) + { + consecutive = false; + break; + } + } } - else if (stat->list.size > 0) + + if (!consecutive && stat->list.size > 0) { temp = allocReg(stat, unsigned(stat->list.size)); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 52ece73..e4d59ea 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -195,12 +195,16 @@ struct ConstantVisitor : AstVisitor DenseHashMap& variables; DenseHashMap& locals; + bool wasEmpty = false; + ConstantVisitor( DenseHashMap& constants, DenseHashMap& variables, DenseHashMap& locals) : constants(constants) , variables(variables) , locals(locals) { + // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries + wasEmpty = constants.empty() && locals.empty(); } Constant analyze(AstExpr* node) @@ -326,7 +330,7 @@ struct ConstantVisitor : AstVisitor { if (value.type != Constant::Type_Unknown) map[key] = value; - else if (!FFlag::LuauCompileSupportInlining) + else if (!FFlag::LuauCompileSupportInlining || wasEmpty) ; else if (Constant* old = map.find(key)) old->type = Constant::Type_Unknown; diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index 9afd09f..f804e9d 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -187,7 +187,7 @@ struct CostVisitor : AstVisitor if (node->is()) result += 2; else if (node->is() || node->is() || node->is() || node->is()) - result += 2; + result += 5; else if (node->is() || node->is()) result += 1; diff --git a/extern/isocline/src/bbcode.c b/extern/isocline/src/bbcode.c index 4d11ac3..8722cbd 100644 --- a/extern/isocline/src/bbcode.c +++ b/extern/isocline/src/bbcode.c @@ -575,6 +575,7 @@ ic_private const char* parse_tag_value( tag_t* tag, char* idbuf, const char* s, } // limit name and attr to 128 bytes char valbuf[128]; + valbuf[0] = 0; // fixes gcc uninitialized warning ic_strncpy( idbuf, 128, id, idend - id); ic_strncpy( valbuf, 128, val, valend - val); ic_str_tolower(idbuf); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 292625b..12c6845 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -7,7 +7,7 @@ using namespace Luau; -struct DocumentationSymbolFixture : Fixture +struct DocumentationSymbolFixture : BuiltinsFixture { std::optional getDocSymbol(const std::string& source, Position position) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 1c284f1..b4e9340 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -27,7 +27,7 @@ template struct ACFixtureImpl : BaseType { ACFixtureImpl() - : Fixture(true, true) + : BaseType(true, true) { } @@ -111,6 +111,18 @@ struct ACFixtureImpl : BaseType }; struct ACFixture : ACFixtureImpl +{ + ACFixture() + : ACFixtureImpl() + { + addGlobalBinding(frontend.typeChecker, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeChecker, "math", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeCheckerForAutocomplete, "table", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.typeCheckerForAutocomplete, "math", Binding{typeChecker.anyType}); + } +}; + +struct ACBuiltinsFixture : ACFixtureImpl { }; @@ -277,7 +289,7 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") CHECK(ac.entryMap.count("test")); } -TEST_CASE_FIXTURE(ACFixture, "get_member_completions") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_member_completions") { check(R"( local a = table.@1 @@ -376,7 +388,7 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") CHECK(ac.entryMap.count("c3")); } -TEST_CASE_FIXTURE(ACFixture, "get_string_completions") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "get_string_completions") { check(R"( local a = ("foo"):@1 @@ -427,7 +439,7 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") CHECK(!ac.entryMap.count("math")); } -TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "method_call_inside_if_conditional") { check(R"( if table: @1 @@ -1884,7 +1896,7 @@ ex.b(function(x: CHECK(!ac.entryMap.count("(done) -> number")); } -TEST_CASE_FIXTURE(ACFixture, "suggest_external_module_type") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "suggest_external_module_type") { fileResolver.source["Module/A"] = R"( export type done = { x: number, y: number } @@ -2235,7 +2247,7 @@ local a: aaa.do CHECK(ac.entryMap.count("other")); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteSource") { std::string_view source = R"( local a = table. -- Line 1 @@ -2269,7 +2281,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_comments") CHECK_EQ(0, ac.entryMap.size()); } -TEST_CASE_FIXTURE(ACFixture, "autocompleteProp_index_function_metamethod_is_variadic") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod_is_variadic") { std::string_view source = R"( type Foo = {x: number} @@ -2720,7 +2732,7 @@ type A = () -> T CHECK(ac.entryMap.count("string")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_oop_implicit_self") { check(R"( --!strict @@ -2728,15 +2740,15 @@ local Class = {} Class.__index = Class type Class = typeof(setmetatable({} :: { x: number }, Class)) function Class.new(x: number): Class - return setmetatable({x = x}, Class) + return setmetatable({x = x}, Class) end function Class.getx(self: Class) - return self.x + return self.x end function test() - local c = Class.new(42) - local n = c:@1 - print(n) + local c = Class.new(42) + local n = c:@1 + print(n) end )"); @@ -2745,7 +2757,7 @@ end CHECK(ac.entryMap.count("getx")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") { check(R"( --!strict @@ -2989,7 +3001,7 @@ s.@1 CHECK(ac.entryMap["sub"].wrongIndexType == true); } -TEST_CASE_FIXTURE(ACFixture, "string_library_non_self_calls_are_fine") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_non_self_calls_are_fine") { ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; @@ -3007,7 +3019,7 @@ string.@1 CHECK(ac.entryMap["sub"].wrongIndexType == false); } -TEST_CASE_FIXTURE(ACFixture, "string_library_self_calls_are_invalid") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_self_calls_are_invalid") { ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp index dbe80f2..496df4b 100644 --- a/tests/BuiltinDefinitions.test.cpp +++ b/tests/BuiltinDefinitions.test.cpp @@ -10,8 +10,10 @@ using namespace Luau; TEST_SUITE_BEGIN("BuiltinDefinitionsTest"); -TEST_CASE_FIXTURE(Fixture, "lib_documentation_symbols") +TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") { + CHECK(!typeChecker.globalScope->bindings.empty()); + for (const auto& [name, binding] : typeChecker.globalScope->bindings) { std::string nameString(name.c_str()); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index f206438..b032060 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4713,7 +4713,6 @@ local function foo() end local a, b = foo() - return a, b )", 1, 2), @@ -4721,9 +4720,7 @@ return a, b DUPCLOSURE R0 K0 LOADNIL R1 LOADNIL R2 -MOVE R3 R1 -MOVE R4 R2 -RETURN R3 2 +RETURN R1 2 )"); // this happens even if the function returns conditionally @@ -4733,7 +4730,6 @@ local function foo(a) end local a, b = foo(false) - return a, b )", 1, 2), @@ -4741,9 +4737,7 @@ return a, b DUPCLOSURE R0 K0 LOADNIL R1 LOADNIL R2 -MOVE R3 R1 -MOVE R4 R2 -RETURN R3 2 +RETURN R1 2 )"); // note though that we can't inline a function like this in multret context @@ -4880,11 +4874,7 @@ LOADN R5 1 ADD R4 R5 R1 LOADN R5 3 ADD R6 R1 R2 -MOVE R7 R3 -MOVE R8 R4 -MOVE R9 R5 -MOVE R10 R6 -RETURN R7 4 +RETURN R3 4 )"); } @@ -5151,4 +5141,59 @@ RETURN R0 0 )"); } +TEST_CASE("ReturnConsecutive") +{ + // we can return a single local directly + CHECK_EQ("\n" + compileFunction0(R"( +local x = ... +return x +)"), + R"( +GETVARARGS R0 1 +RETURN R0 1 +)"); + + // or multiple, when they are allocated in consecutive registers + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return x, y +)"), + R"( +GETVARARGS R0 2 +RETURN R0 2 +)"); + + // but not if it's an expression + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return x, y + 1 +)"), + R"( +GETVARARGS R0 2 +MOVE R2 R0 +ADDK R3 R1 K0 +RETURN R2 2 +)"); + + // or a local with wrong register number + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return y, x +)"), + R"( +GETVARARGS R0 2 +MOVE R2 R1 +MOVE R3 R0 +RETURN R2 2 +)"); + + // also double check the optimization doesn't trip on no-argument return (these are rare) + CHECK_EQ("\n" + compileFunction0(R"( +return +)"), + R"( +RETURN R0 0 +)"); +} + TEST_SUITE_END(); diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index aa5b728..2fa0659 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -76,9 +76,9 @@ end const bool args1[] = {false}; const bool args2[] = {true}; - // loop baseline cost is 2 - CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); - CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); + // loop baseline cost is 5 + CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("MutableVariable") @@ -154,8 +154,8 @@ end const bool args1[] = {false}; const bool args2[] = {true}; - CHECK_EQ(38, Luau::Compile::computeCost(model, args1, 1)); - CHECK_EQ(37, Luau::Compile::computeCost(model, args2, 1)); + CHECK_EQ(50, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(49, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("Conditional") @@ -219,8 +219,8 @@ end const bool args1[] = {false}; const bool args2[] = {true}; - CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); - CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); + CHECK_EQ(7, Luau::Compile::computeCost(model, args1, 1)); + CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); } TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index d8b37a6..03f3e15 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -92,10 +92,6 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend.typeChecker); - if (prepareAutocomplete) - registerBuiltinTypes(frontend.typeCheckerForAutocomplete); - registerTestTypes(); Luau::freeze(frontend.typeChecker.globalTypes); Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); @@ -410,6 +406,21 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) return result; } +BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) + : Fixture(freeze, prepareAutocomplete) +{ + Luau::unfreeze(frontend.typeChecker.globalTypes); + Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + + registerBuiltinTypes(frontend.typeChecker); + if (prepareAutocomplete) + registerBuiltinTypes(frontend.typeCheckerForAutocomplete); + registerTestTypes(); + + Luau::freeze(frontend.typeChecker.globalTypes); + Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); +} + ModuleName fromString(std::string_view name) { return ModuleName(name); diff --git a/tests/Fixture.h b/tests/Fixture.h index 0d1233b..901f7d4 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -151,6 +151,11 @@ struct Fixture LoadDefinitionFileResult loadDefinition(const std::string& source); }; +struct BuiltinsFixture : Fixture +{ + BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); +}; + ModuleName fromString(std::string_view name); template diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index a10e8f7..33b81be 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -77,7 +77,7 @@ struct NaiveFileResolver : NullFileResolver } // namespace -struct FrontendFixture : Fixture +struct FrontendFixture : BuiltinsFixture { FrontendFixture() { diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 6649cb7..202aece 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -75,7 +75,7 @@ _ = 6 CHECK_EQ(result.warnings.size(), 0); } -TEST_CASE_FIXTURE(Fixture, "BuiltinGlobalWrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "BuiltinGlobalWrite") { LintResult result = lint(R"( math = {} @@ -309,7 +309,7 @@ print(arg) CHECK_EQ(result.warnings[0].text, "Variable 'arg' shadows previous declaration at line 2"); } -TEST_CASE_FIXTURE(Fixture, "LocalShadowGlobal") +TEST_CASE_FIXTURE(BuiltinsFixture, "LocalShadowGlobal") { LintResult result = lint(R"( local math = math @@ -1470,7 +1470,7 @@ end CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); } -TEST_CASE_FIXTURE(Fixture, "TableOperations") +TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") { LintResult result = lintTyped(R"( local t = {} diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 44cc20a..4a99986 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -113,7 +113,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") CHECK_EQ(2, dest.typeVars.size()); // One table and one function } -TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") { CheckResult result = check(R"( return {sign=math.sign} @@ -250,7 +250,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") CHECK_EQ(getSingletonTypes().stringType, ctv->parts[1]); } -TEST_CASE_FIXTURE(Fixture, "clone_self_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") { ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index feeaf2c..69430b1 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,7 +13,7 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); -TEST_CASE_FIXTURE(Fixture, "function_returns_number_or_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_number_or_string") { ScopedFastFlag sff[]{ {"LuauReturnTypeInferenceInNonstrict", true}, @@ -224,7 +224,7 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") CHECK_MESSAGE(get(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type); } -TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any") { CheckResult result = check(R"( --!nonstrict @@ -243,7 +243,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_dot_insert_and_recursive_calls") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_insert_and_recursive_calls") { CheckResult result = check(R"( --!nonstrict diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index d3778f6..4183068 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -739,7 +739,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly") CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); } -TEST_CASE_FIXTURE(Fixture, "union_of_distinct_free_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_distinct_free_types") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, @@ -760,7 +760,7 @@ TEST_CASE_FIXTURE(Fixture, "union_of_distinct_free_types") CHECK("(a, b) -> a | b" == toString(requireType("fussy"))); } -TEST_CASE_FIXTURE(Fixture, "constrained_intersection_of_intersections") +TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_intersection_of_intersections") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, @@ -951,7 +951,7 @@ TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "visiting_a_type_twice_is_not_considered_normal") +TEST_CASE_FIXTURE(BuiltinsFixture, "visiting_a_type_twice_is_not_considered_normal") { ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; @@ -976,7 +976,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_failure_instersection_combine_must_follow") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, - {"LuauNormalizeCombineIntersectionFix", true}, }; CheckResult result = check(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 69ff73a..c9d8d0b 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1618,8 +1618,6 @@ TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments") TEST_CASE_FIXTURE(Fixture, "end_extent_doesnt_consume_comments_even_with_capture") { - ScopedFastFlag luauParseLocationIgnoreCommentSkipInCapture{"LuauParseLocationIgnoreCommentSkipInCapture", true}; - // Same should hold when comments are captured ParseOptions opts; opts.captureComments = true; diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 538f357..c16f60d 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -17,7 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); -struct LimitFixture : Fixture +struct LimitFixture : BuiltinsFixture { #if defined(_NOOPT) || defined(_DEBUG) ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 332a4b2..e9fa5b2 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -224,7 +224,7 @@ n1 -> n4 [label="typePackParam"]; (void)toDot(requireType("a")); } -TEST_CASE_FIXTURE(Fixture, "metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable") { CheckResult result = check(R"( local a: typeof(setmetatable({}, {})) diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index ccf5c58..50d0838 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -60,7 +60,7 @@ TEST_CASE_FIXTURE(Fixture, "named_table") CHECK_EQ("TheTable", toString(&table)); } -TEST_CASE_FIXTURE(Fixture, "exhaustive_toString_of_cyclic_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( --!strict @@ -338,7 +338,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") REQUIRE_EQ("c", toString(params[2], opts)); } -TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") +TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index b0eb31c..7562a4d 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -279,7 +279,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") CHECK_EQ("Node", toString(e->wantedType)); } -TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") { fileResolver.source["workspace/A"] = R"( export type myvec2 = {x: number, y: number} @@ -317,7 +317,7 @@ TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") REQUIRE(bType->props.size() == 3); } -TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") { CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); @@ -385,7 +385,7 @@ type Cool = typeof(c) CHECK_EQ(ttv->name, "Cool"); } -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_type") { fileResolver.source["game/A"] = R"( export type X = { a: number, b: X? } @@ -410,7 +410,7 @@ type X = Import.X CHECK_EQ(follow(*ty1), follow(*ty2)); } -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_of_an_imported_recursive_generic_type") { fileResolver.source["game/A"] = R"( export type X = { a: T, b: U, C: X? } @@ -564,7 +564,7 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") * * We solved this by ascribing a unique subLevel to each prototyped alias. */ -TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_quantify_unresolved_aliases") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 7f1c757..b9e1ae9 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -528,7 +528,7 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti CHECK_EQ(recordType, bType); } -TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") +TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -554,7 +554,7 @@ TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -580,7 +580,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "builtin_types_are_not_exported") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_are_not_exported") { addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); @@ -676,7 +676,7 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") )"); } -TEST_CASE_FIXTURE(Fixture, "luau_print_is_magic_if_the_flag_is_set") +TEST_CASE_FIXTURE(BuiltinsFixture, "luau_print_is_magic_if_the_flag_is_set") { // Luau::resetPrintLine(); ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 5224b5d..bc55940 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -237,7 +237,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") CHECK_EQ("*unknown*", toString(requireType("a"))); } -TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") { CheckResult result = check(R"( local a: any @@ -285,7 +285,7 @@ end LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_of_any_can_be_a_table") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1ae6594..b710ea0 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -12,7 +12,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); TEST_SUITE_BEGIN("BuiltinTests"); -TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_things_are_defined") { CheckResult result = check(R"( local a00 = math.frexp @@ -50,7 +50,7 @@ TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "next_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "next_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( local a: string, b: number = next({ 1 }) @@ -63,7 +63,7 @@ TEST_CASE_FIXTURE(Fixture, "next_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "pairs_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( type Map = { [K]: V } @@ -75,7 +75,7 @@ TEST_CASE_FIXTURE(Fixture, "pairs_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "ipairs_iterator_should_infer_types_and_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_iterator_should_infer_types_and_type_check") { CheckResult result = check(R"( type Map = { [K]: V } @@ -87,7 +87,7 @@ TEST_CASE_FIXTURE(Fixture, "ipairs_iterator_should_infer_types_and_type_check") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_remove_optionally_returns_generic") { CheckResult result = check(R"( local t = { 1 } @@ -98,7 +98,7 @@ TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") CHECK_EQ(toString(requireType("n")), "number?"); } -TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_concat_returns_string") { CheckResult result = check(R"( local r = table.concat({1,2,3,4}, ",", 2); @@ -108,7 +108,7 @@ TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") CHECK_EQ(*typeChecker.stringType, *requireType("r")); } -TEST_CASE_FIXTURE(Fixture, "sort") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort") { CheckResult result = check(R"( local t = {1, 2, 3}; @@ -118,7 +118,7 @@ TEST_CASE_FIXTURE(Fixture, "sort") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_predicate") { CheckResult result = check(R"( --!strict @@ -130,7 +130,7 @@ TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") +TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") { CheckResult result = check(R"( --!strict @@ -140,6 +140,12 @@ TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '(number, number) -> boolean' could not be converted into '((a, a) -> boolean)?' +caused by: + None of the union options are compatible. For example: Type '(number, number) -> boolean' could not be converted into '(a, a) -> boolean' +caused by: + Argument #1 type is not compatible. Type 'string' could not be converted into 'number')", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "strings_have_methods") @@ -152,7 +158,7 @@ TEST_CASE_FIXTURE(Fixture, "strings_have_methods") CHECK_EQ(*typeChecker.stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "math_max_variatic") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") { CheckResult result = check(R"( local n = math.max(1,2,3,4,5,6,7,8,9,0) @@ -162,16 +168,17 @@ TEST_CASE_FIXTURE(Fixture, "math_max_variatic") CHECK_EQ(*typeChecker.numberType, *requireType("n")); } -TEST_CASE_FIXTURE(Fixture, "math_max_checks_for_numbers") +TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_checks_for_numbers") { CheckResult result = check(R"( local n = math.max(1,2,"3") )"); CHECK(!result.errors.empty()); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_tables_sealed") { CheckResult result = check(R"LUA( local b = bit32 @@ -183,7 +190,7 @@ TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") CHECK_EQ(bit32t->state, TableState::Sealed); } -TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") +TEST_CASE_FIXTURE(BuiltinsFixture, "lua_51_exported_globals_all_exist") { // Extracted from lua5.1 CheckResult result = check(R"( @@ -340,7 +347,7 @@ TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_unpacks_arg_types_correctly") { CheckResult result = check(R"( setmetatable({}, setmetatable({}, {})) @@ -348,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_overload") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_2_args_overload") { CheckResult result = check(R"( local t = {} @@ -360,7 +367,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_o CHECK_EQ(typeChecker.stringType, requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_overload") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_3_args_overload") { CheckResult result = check(R"( local t = {} @@ -372,7 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_o CHECK_EQ("string", toString(requireType("s"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack") { CheckResult result = check(R"( local t = table.pack(1, "foo", true) @@ -382,7 +389,7 @@ TEST_CASE_FIXTURE(Fixture, "table_pack") CHECK_EQ("{| [number]: boolean | number | string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_variadic") { CheckResult result = check(R"( --!strict @@ -397,7 +404,7 @@ local t = table.pack(f()) CHECK_EQ("{| [number]: number | string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_pack_reduce") { CheckResult result = check(R"( local t = table.pack(1, 2, true) @@ -414,7 +421,7 @@ TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") CHECK_EQ("{| [number]: string, n: number |}", toString(requireType("t"))); } -TEST_CASE_FIXTURE(Fixture, "gcinfo") +TEST_CASE_FIXTURE(BuiltinsFixture, "gcinfo") { CheckResult result = check(R"( local n = gcinfo() @@ -424,12 +431,12 @@ TEST_CASE_FIXTURE(Fixture, "gcinfo") CHECK_EQ(*typeChecker.numberType, *requireType("n")); } -TEST_CASE_FIXTURE(Fixture, "getfenv") +TEST_CASE_FIXTURE(BuiltinsFixture, "getfenv") { LUAU_REQUIRE_NO_ERRORS(check("getfenv(1)")); } -TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "os_time_takes_optional_date_table") { CheckResult result = check(R"( local n1 = os.time() @@ -443,7 +450,7 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") { CheckResult result = check(R"( local co = coroutine.create(function() end) @@ -453,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") CHECK_EQ(*typeChecker.threadType, *requireType("co")); } -TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_resume_anything_goes") { CheckResult result = check(R"( local function nifty(x, y) @@ -471,7 +478,7 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_wrap_anything_goes") { CheckResult result = check(R"( --!nonstrict @@ -490,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_should_not_mutate_persisted_types") { CheckResult result = check(R"( local string = string @@ -505,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") REQUIRE(ttv); } -TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_types_inference") { CheckResult result = check(R"( --!strict @@ -518,7 +525,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") CHECK_EQ("(number, number, string) -> string", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_arg_count_mismatch") { CheckResult result = check(R"( --!strict @@ -534,7 +541,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") CHECK_EQ(result.errors[2].location.begin.line, 4); } -TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") { CheckResult result = check(R"( --!strict @@ -548,7 +555,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") CHECK_EQ(tm->givenType, typeChecker.numberType); } -TEST_CASE_FIXTURE(Fixture, "xpcall") +TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall") { CheckResult result = check(R"( --!strict @@ -564,7 +571,7 @@ TEST_CASE_FIXTURE(Fixture, "xpcall") CHECK_EQ("boolean", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "see_thru_select") +TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select") { CheckResult result = check(R"( local a:number, b:boolean = select(2,"hi", 10, true) @@ -573,7 +580,7 @@ TEST_CASE_FIXTURE(Fixture, "see_thru_select") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") +TEST_CASE_FIXTURE(BuiltinsFixture, "see_thru_select_count") { CheckResult result = check(R"( local a = select("#","hi", 10, true) @@ -583,7 +590,7 @@ TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "select_with_decimal_argument_is_rounded_down") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_decimal_argument_is_rounded_down") { CheckResult result = check(R"( local a: number, b: boolean = select(2.9, "foo", 1, true) @@ -608,7 +615,7 @@ TEST_CASE_FIXTURE(Fixture, "bad_select_should_not_crash") CHECK_LE(0, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_way_out_of_range") { CheckResult result = check(R"( select(5432598430953240958) @@ -619,7 +626,7 @@ TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_slightly_out_of_range") { CheckResult result = check(R"( select(3, "a", 1) @@ -630,7 +637,7 @@ TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail") { CheckResult result = check(R"( --!nonstrict @@ -649,7 +656,7 @@ TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") CHECK_EQ("any", toString(requireType("quux"))); } -TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail_and_string_head") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_string_head") { CheckResult result = check(R"( --!nonstrict @@ -703,7 +710,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "debug_traceback_is_crazy") +TEST_CASE_FIXTURE(BuiltinsFixture, "debug_traceback_is_crazy") { CheckResult result = check(R"( local co: thread = ... @@ -720,7 +727,7 @@ debug.traceback(co, "msg", 1) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "debug_info_is_crazy") +TEST_CASE_FIXTURE(BuiltinsFixture, "debug_info_is_crazy") { CheckResult result = check(R"( local co: thread, f: ()->() = ... @@ -734,7 +741,7 @@ debug.info(f, "n") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "aliased_string_format") +TEST_CASE_FIXTURE(BuiltinsFixture, "aliased_string_format") { CheckResult result = check(R"( local fmt = string.format @@ -745,7 +752,7 @@ TEST_CASE_FIXTURE(Fixture, "aliased_string_format") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_lib_self_noself") { CheckResult result = check(R"( --!nonstrict @@ -764,7 +771,7 @@ TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "gmatch_definition") +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_definition") { CheckResult result = check(R"_( local a, b, c = ("hey"):gmatch("(.)(.)(.)")() @@ -777,7 +784,7 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "select_on_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "select_on_variadic") { CheckResult result = check(R"( local function f(): (number, ...(boolean | number)) @@ -793,7 +800,7 @@ TEST_CASE_FIXTURE(Fixture, "select_on_variadic") CHECK_EQ("any", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_positions") +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_report_all_type_errors_at_correct_positions") { CheckResult result = check(R"( ("%s%d%s"):format(1, "hello", true) @@ -825,7 +832,7 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[5].data); } -TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type") { CheckResult result = check(R"( --!strict @@ -836,7 +843,7 @@ TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") +TEST_CASE_FIXTURE(BuiltinsFixture, "tonumber_returns_optional_number_type2") { CheckResult result = check(R"( --!strict @@ -846,7 +853,7 @@ TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_add_definitions_to_persistent_types") { CheckResult result = check(R"( local f = math.sin @@ -868,7 +875,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") REQUIRE(gtv->definition); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, @@ -889,7 +896,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, @@ -907,7 +914,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, @@ -924,7 +931,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_ CHECK_EQ("(...number?) -> (number, ...number?)", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, @@ -941,7 +948,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_fir CHECK_EQ("(nil) -> nil", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") { CheckResult result = check(R"( local t1: {a: number} = {a = 42} @@ -968,7 +975,7 @@ TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") CHECK_EQ("*unknown*", toString(requireType("d"))); } -TEST_CASE_FIXTURE(Fixture, "set_metatable_needs_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; CheckResult result = check(R"( @@ -991,7 +998,7 @@ local function f(a: typeof(f)) end CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") +TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") { TypeId mathTy = requireType(typeChecker.globalScope, "math"); REQUIRE(mathTy); @@ -1008,7 +1015,7 @@ TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") CHECK(ftv->level.subLevel == original.subLevel); } -TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") +TEST_CASE_FIXTURE(BuiltinsFixture, "global_singleton_types_are_sealed") { CheckResult result = check(R"( local function f(x: string) diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 5a6e403..d90129d 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; using std::nullopt; -struct ClassFixture : Fixture +struct ClassFixture : BuiltinsFixture { ClassFixture() { diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 0e07121..14f1f70 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -85,7 +85,7 @@ TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") +TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") { CheckResult result = check(R"( local T = {} @@ -555,7 +555,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") CHECK(bool(argType->indexer)); } -TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") +TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") { CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) @@ -620,7 +620,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); } -TEST_CASE_FIXTURE(Fixture, "mutual_recursion") +TEST_CASE_FIXTURE(BuiltinsFixture, "mutual_recursion") { CheckResult result = check(R"( --!strict @@ -639,7 +639,7 @@ TEST_CASE_FIXTURE(Fixture, "mutual_recursion") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") +TEST_CASE_FIXTURE(BuiltinsFixture, "toposort_doesnt_break_mutual_recursion") { CheckResult result = check(R"( --!strict @@ -676,7 +676,7 @@ TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") { CheckResult result = check(R"( function onerror() end @@ -794,7 +794,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields }})); } -TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_leak_free_types") { ScopedFastFlag sff[]{ {"LuauReturnTypeInferenceInNonstrict", true}, @@ -966,7 +966,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") CHECK_EQ("string", toString(requireType("z"))); } -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation CheckResult result = check(R"( @@ -1068,7 +1068,7 @@ f(function(x) return x * 2 end) } } -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation CheckResult result = check(R"( @@ -1287,10 +1287,8 @@ caused by: Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); } -TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_quantify_right_type") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - fileResolver.source["game/isAMagicMock"] = R"( --!nonstrict return function(value) @@ -1311,10 +1309,8 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - CheckResult result = check(R"( function string.len(): number return 1 @@ -1333,11 +1329,8 @@ print(string.len('hello')) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_sealed_overwrite_2") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; - CheckResult result = check(R"( local t: { f: ((x: number) -> number)? } = {} @@ -1477,11 +1470,8 @@ TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_th LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_unsealed_overwrite") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_non_self_unsealed_overwrite") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - ScopedFastFlag inferStatFunction{"LuauInferStatFunction", true}; - CheckResult result = check(R"( local t = { f = nil :: ((x: number) -> number)? } @@ -1518,8 +1508,6 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - CheckResult result = check(R"( local t: {[string]: () -> number} = {} @@ -1580,7 +1568,7 @@ wrapper(test) CHECK(acm->isVariadic); } -TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic2") +TEST_CASE_FIXTURE(BuiltinsFixture, "too_few_arguments_variadic_generic2") { CheckResult result = check(R"( function test(a: number, b: string, ...) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 91be2c1..de0c939 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -67,7 +67,7 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") +TEST_CASE_FIXTURE(BuiltinsFixture, "inferred_local_vars_can_be_polytypes") { CheckResult result = check(R"( local function id(x) return x end @@ -79,7 +79,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") +TEST_CASE_FIXTURE(BuiltinsFixture, "local_vars_can_be_instantiated_polytypes") { CheckResult result = check(R"( local function id(x) return x end @@ -609,7 +609,7 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") CHECK(requireType("y1") == requireType("y2")); } -TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") +TEST_CASE_FIXTURE(BuiltinsFixture, "bound_tables_do_not_clone_original_fields") { CheckResult result = check(R"( local exports = {} @@ -675,7 +675,7 @@ local d: D = c R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); } -TEST_CASE_FIXTURE(Fixture, "generic_functions_dont_cache_type_parameters") +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_functions_dont_cache_type_parameters") { CheckResult result = check(R"( -- See https://github.com/Roblox/luau/issues/332 @@ -1013,7 +1013,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") CHECK(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; @@ -1078,7 +1078,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded" LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_lib_function_function_argument") { CheckResult result = check(R"( local a = {{x=4}, {x=7}, {x=1}} diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 3675919..41bc0c2 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -316,8 +316,6 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - CheckResult result = check(R"( type X = { x: (number) -> number } type Y = { y: (string) -> string } @@ -351,8 +349,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - // After normalization, previous 'table_intersection_write_sealed_indirect' is identical to this one CheckResult result = check(R"( type XY = { x: (number) -> number, y: (string) -> string } @@ -375,7 +371,7 @@ caused by: CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'XY'"); } -TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_intersection_setmetatable") { CheckResult result = check(R"( local t: {} & {} diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index f9b510c..765419c 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -29,7 +29,7 @@ TEST_CASE_FIXTURE(Fixture, "for_loop") CHECK_EQ(*typeChecker.numberType, *requireType("q")); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") { CheckResult result = check(R"( local n @@ -46,7 +46,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop") CHECK_EQ(*typeChecker.stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") { CheckResult result = check(R"( local n @@ -90,7 +90,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") CHECK_EQ("Cannot call non-function string", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_just_one_iterator_is_ok") { CheckResult result = check(R"( local function keys(dictionary) @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_a_custom_iterator_should_type_check") { CheckResult result = check(R"( local function range(l, h): () -> number @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") REQUIRE(get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") { CheckResult result = check(R"( local function hasDivisors(value: number, table) @@ -210,7 +210,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right CHECK_EQ(typeChecker.stringType, tm->givenType); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") +TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") { CheckResult result = check(R"( function prime_iter(state, index) @@ -288,7 +288,7 @@ TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") +TEST_CASE_FIXTURE(BuiltinsFixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") { CheckResult result = check(R"( repeat @@ -301,7 +301,7 @@ TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") +TEST_CASE_FIXTURE(BuiltinsFixture, "varlist_declared_by_for_in_loop_should_be_free") { CheckResult result = check(R"( local T = {} @@ -316,7 +316,7 @@ TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "properly_infer_iteratee_is_a_free_table") { // In this case, we cannot know the element type of the table {}. It could be anything. // We therefore must initially ascribe a free typevar to iter. @@ -329,7 +329,7 @@ TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_while") { CheckResult result = check(R"( while true do @@ -346,7 +346,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") +TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_produces_integral_indices") { CheckResult result = check(R"( local key @@ -378,7 +378,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") )"); } -TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") +TEST_CASE_FIXTURE(BuiltinsFixture, "unreachable_code_after_infinite_loop") { { CheckResult result = check(R"( @@ -460,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") } } -TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") { CheckResult result = check(R"( local t = {} @@ -541,7 +541,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") CHECK_EQ("Cannot iterate over a table without indexer", ge->message); } -TEST_CASE_FIXTURE(Fixture, "loop_iter_iter_metamethod") +TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") { ScopedFastFlag sff{"LuauTypecheckIter", true}; diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index b6f49f9..efa2a98 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -16,7 +16,7 @@ LUAU_FASTFLAG(LuauTableSubtypingVariance2) TEST_SUITE_BEGIN("TypeInferModules"); -TEST_CASE_FIXTURE(Fixture, "require") +TEST_CASE_FIXTURE(BuiltinsFixture, "require") { fileResolver.source["game/A"] = R"( local function hooty(x: number): string @@ -54,7 +54,7 @@ TEST_CASE_FIXTURE(Fixture, "require") REQUIRE_EQ("number", toString(*hType)); } -TEST_CASE_FIXTURE(Fixture, "require_types") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_types") { fileResolver.source["workspace/A"] = R"( export type Point = {x: number, y: number} @@ -69,7 +69,7 @@ TEST_CASE_FIXTURE(Fixture, "require_types") )"; CheckResult bResult = frontend.check("workspace/B"); - dumpErrors(bResult); + LUAU_REQUIRE_NO_ERRORS(bResult); ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; REQUIRE(b != nullptr); @@ -78,7 +78,7 @@ TEST_CASE_FIXTURE(Fixture, "require_types") REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); } -TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") { fileResolver.source["game/A"] = R"( local T = {} @@ -121,7 +121,7 @@ TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); } -TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") { const std::string sourceA = R"( )"; @@ -148,7 +148,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") CHECK_EQ("*unknown*", toString(hootyType)); } -TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") +TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") { fileResolver.source["Modules/A"] = ""; fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; @@ -164,7 +164,7 @@ TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") CHECK(get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_call_expression") { fileResolver.source["game/A"] = R"( --!strict @@ -183,7 +183,7 @@ a = tbl.abc.def CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") +TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_type_mismatch") { fileResolver.source["game/A"] = R"( return { def = 4 } @@ -219,7 +219,7 @@ return m LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "custom_require_global") +TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") { CheckResult result = check(R"( --!nonstrict @@ -231,7 +231,7 @@ local crash = require(game.A) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "require_failed_module") +TEST_CASE_FIXTURE(BuiltinsFixture, "require_failed_module") { fileResolver.source["game/A"] = R"( return unfortunately() @@ -267,7 +267,7 @@ function x:Destroy(): () end LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_2") { fileResolver.source["game/A"] = R"( export type Type = { x: { a: number } } @@ -285,7 +285,7 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_3") { fileResolver.source["game/A"] = R"( local y = setmetatable({}, {}) @@ -304,7 +304,7 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "module_type_conflict") +TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict") { fileResolver.source["game/A"] = R"( export type T = { x: number } @@ -338,7 +338,7 @@ caused by: } } -TEST_CASE_FIXTURE(Fixture, "module_type_conflict_instantiated") +TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") { fileResolver.source["game/A"] = R"( export type Wrap = { x: T } diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 5cd3f3b..4169070 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -142,7 +142,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocat CHECK_GE(50, module->internalTypes.typeVars.size()); } -TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") +TEST_CASE_FIXTURE(BuiltinsFixture, "object_constructor_can_refer_to_method_of_self") { // CLI-30902 CheckResult result = check(R"( @@ -243,7 +243,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_ )"); } -TEST_CASE_FIXTURE(Fixture, "table_oop") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_oop") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index a2787ca..51f6fdf 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "and_or_ternary") CHECK_EQ(toString(*requireType("s")), "number | string"); } -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) @@ -140,7 +140,7 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") CHECK_EQ("number", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection") { CheckResult result = check(R"( --!strict @@ -174,7 +174,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("e"))); } -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { CheckResult result = check(R"( --!strict @@ -245,7 +245,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_m REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); } -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") { CheckResult result = check(R"( local M = {} @@ -266,7 +266,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_ov REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); } -TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") { CheckResult result = check(R"( --!strict @@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_meta REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); } -TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") +TEST_CASE_FIXTURE(BuiltinsFixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") { CheckResult result = check(R"( --!strict @@ -361,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); } -TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") { CheckResult result = check(R"( --!strict @@ -381,7 +381,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") CHECK_EQ(0, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_mismatch_metatable") { CheckResult result = check(R"( --!strict @@ -428,7 +428,7 @@ local x = false LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") { CheckResult result = check(R"( --!strict @@ -461,7 +461,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); } -TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") +TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") { CheckResult result = check(R"( local b = not "string" @@ -473,7 +473,7 @@ TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") REQUIRE_EQ("boolean", toString(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") +TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") { CheckResult result = check(R"( --!strict @@ -573,7 +573,7 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") +TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") { CheckResult result = check(R"( --!strict @@ -628,7 +628,7 @@ TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") +TEST_CASE_FIXTURE(BuiltinsFixture, "UnknownGlobalCompoundAssign") { // In non-strict mode, global definition is still allowed { @@ -755,8 +755,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") { - ScopedFastFlag sff{"LuauDecoupleOperatorInferenceFromUnifiedTypeInference", true}; - CheckResult result = check(Mode::Strict, R"( local function f(x, y) return x + y @@ -779,4 +777,47 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") // the case right now, though. } +TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_branch_succeeds") +{ + ScopedFastFlag sff("LuauSuccessTypingForEqualityOperations", true); + + CheckResult result = check(R"( + local mm = {} + type Foo = typeof(setmetatable({}, mm)) + local x: Foo + local y: Foo? + + local v1 = x == y + local v2 = y == x + local v3 = x ~= y + local v4 = y ~= x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CheckResult result2 = check(R"( + local mm1 = { + x = "foo", + } + + local mm2 = { + y = "bar", + } + + type Foo = typeof(setmetatable({}, mm1)) + type Bar = typeof(setmetatable({}, mm2)) + + local x1: Foo + local x2: Foo? + local y1: Bar + local y2: Bar? + + local v1 = x1 == y1 + local v2 = x2 == y2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result2); + CHECK(toString(result2.errors[0]) == "Types Foo and Bar cannot be compared with == because they do not have the same metatable"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2ef7741..ee3ae97 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -53,7 +53,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") CHECK_EQ(expected, decorateWithTypes(code)); } -TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") +TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") { const std::string code = R"( local a, b, c = xpcall(function() return 1, "foo" end, function() return "foo", 1 end) @@ -105,7 +105,7 @@ TEST_CASE_FIXTURE(Fixture, "it_should_be_agnostic_of_actual_size") // Ideally setmetatable's second argument would be an optional free table. // For now, infer it as just a free table. -TEST_CASE_FIXTURE(Fixture, "setmetatable_constrains_free_type_into_free_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_constrains_free_type_into_free_table") { CheckResult result = check(R"( local a = {} @@ -146,7 +146,7 @@ TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") // Originally from TypeInfer.test.cpp. // I dont think type checking the metamethod at every site of == is the correct thing to do. // We should be type checking the metamethod at the call site of setmetatable. -TEST_CASE_FIXTURE(Fixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") { CheckResult result = check(R"( local tab = {a = 1} @@ -428,7 +428,7 @@ TEST_CASE_FIXTURE(Fixture, "pcall_returns_at_least_two_value_but_function_return } // Belongs in TypeInfer.builtins.test.cpp. -TEST_CASE_FIXTURE(Fixture, "choose_the_right_overload_for_pcall") +TEST_CASE_FIXTURE(BuiltinsFixture, "choose_the_right_overload_for_pcall") { CheckResult result = check(R"( local function f(): number @@ -449,7 +449,7 @@ TEST_CASE_FIXTURE(Fixture, "choose_the_right_overload_for_pcall") } // Belongs in TypeInfer.builtins.test.cpp. -TEST_CASE_FIXTURE(Fixture, "function_returns_many_things_but_first_of_it_is_forgotten") +TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_many_things_but_first_of_it_is_forgotten") { CheckResult result = check(R"( local function f(): (number, string, boolean) diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 136ca00..8c13049 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -240,7 +240,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); } -TEST_CASE_FIXTURE(Fixture, "typeguard_in_assert_position") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") { CheckResult result = check(R"( local a @@ -300,7 +300,7 @@ TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "impossible_type_narrow_is_not_an_error") { // This unit test serves as a reminder to not implement this warning until Luau is intelligent enough. // For instance, getting a value out of the indexer and checking whether the value exists is not an error. @@ -333,7 +333,7 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") CHECK_EQ("number?", toString(requireType("bar"))); } -TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "index_on_a_refined_property") { CheckResult result = check(R"( local t: {x: {y: string}?} = {x = {y = "hello!"}} @@ -346,7 +346,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_constraints") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_non_binary_expressions_actually_resolve_constraints") { CheckResult result = check(R"( local foo: string? = "hello" @@ -730,7 +730,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") { CheckResult result = check(R"( local function f(t: {x: number}) @@ -846,7 +846,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); } -TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") +TEST_CASE_FIXTURE(BuiltinsFixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { CheckResult result = check(R"( local a: (number | string)? @@ -862,7 +862,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); } -TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") +TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( @@ -899,7 +899,7 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); } -TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") { CheckResult result = check(R"( function f(v:string?) @@ -913,7 +913,7 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); } -TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expression") { CheckResult result = check(R"( function f(v:string?) @@ -945,7 +945,7 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } -TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") { CheckResult result = check(R"( local foo: string? = "hi" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 8d6682b..79eeb82 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -415,7 +415,7 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") { ScopedFastFlag sff[]{ {"LuauWidenIfSupertypeIsFree2", true}, diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 8e53599..5078b0b 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -201,7 +201,7 @@ TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon") REQUIRE(it != result.errors.end()); } -TEST_CASE_FIXTURE(Fixture, "used_colon_correctly") +TEST_CASE_FIXTURE(BuiltinsFixture, "used_colon_correctly") { CheckResult result = check(R"( --!nonstrict @@ -883,7 +883,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s CHECK_EQ(*typeChecker.stringType, *propertyA); } -TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") +TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") { CheckResult result = check(R"( local clazz = {} @@ -906,7 +906,7 @@ TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") CHECK_EQ(*typeChecker.stringType, *requireType("words")); } -TEST_CASE_FIXTURE(Fixture, "indexer_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") { CheckResult result = check(R"( local clazz = {a="hello"} @@ -919,7 +919,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_table") CHECK_EQ(*typeChecker.stringType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "indexer_fn") +TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") { CheckResult result = check(R"( local instanace = setmetatable({}, {__index=function() return 10 end}) @@ -930,7 +930,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_fn") CHECK_EQ(*typeChecker.numberType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "meta_add") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add") { // Note: meta_add_inferred and this unit test are currently the same exact thing. // We'll want to change this one in particular when we add real syntax for metatables. @@ -947,7 +947,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add") CHECK_EQ(follow(requireType("a")), follow(requireType("c"))); } -TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_inferred") { CheckResult result = check(R"( local a = {} @@ -960,7 +960,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") CHECK_EQ(*requireType("a"), *requireType("c")); } -TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") +TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_both_ways") { CheckResult result = check(R"( type VectorMt = { __add: (Vector, number) -> Vector } @@ -980,7 +980,7 @@ TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") // This test exposed a bug where we let go of the "seen" stack while unifying table types // As a result, type inference crashed with a stack overflow. -TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") +TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_type") { CheckResult result = check(R"( type A = {} @@ -1009,7 +1009,7 @@ TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") CHECK_EQ(bmtv->metatable, requireType("bmt")); } -TEST_CASE_FIXTURE(Fixture, "oop_polymorphic") +TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") { CheckResult result = check(R"( local animal = {} @@ -1060,7 +1060,7 @@ TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") CHECK_EQ("Vector3", toString(requireType("v"))); } -TEST_CASE_FIXTURE(Fixture, "result_is_always_any_if_lhs_is_any") +TEST_CASE_FIXTURE(BuiltinsFixture, "result_is_always_any_if_lhs_is_any") { CheckResult result = check(R"( type Vector3MT = { @@ -1133,7 +1133,7 @@ TEST_CASE_FIXTURE(Fixture, "nice_error_when_trying_to_fetch_property_of_boolean" CHECK_EQ("Type 'boolean' does not have key 'some_prop'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") { CheckResult result = check(R"( function string.m() end @@ -1142,7 +1142,7 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fa LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") { CheckResult result = check(R"( function string:m() end @@ -1261,7 +1261,7 @@ TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_function_call") CHECK_EQ(toString(te), "Key 'fOo' not found in table 't'. Did you mean 'Foo'?"); } -TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") +TEST_CASE_FIXTURE(BuiltinsFixture, "found_like_key_in_table_property_access") { CheckResult result = check(R"( local t = {X = 1} @@ -1286,7 +1286,7 @@ TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") CHECK_EQ(toString(te), "Key 'x' not found in table 't'. Did you mean 'X'?"); } -TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") +TEST_CASE_FIXTURE(BuiltinsFixture, "found_multiple_like_keys") { CheckResult result = check(R"( local t = {Foo = 1, foO = 2} @@ -1312,7 +1312,7 @@ TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") CHECK_EQ(toString(te), "Key 'foo' not found in table 't'. Did you mean one of 'Foo', 'foO'?"); } -TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_suggest_exact_match_keys") { CheckResult result = check(R"( local t = {} @@ -1339,7 +1339,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") CHECK_EQ(toString(te), "Key 'Foo' not found in table 't'. Did you mean 'foO'?"); } -TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_pointer_to_metatable") { CheckResult result = check(R"( local t = {x = 1} @@ -1352,7 +1352,7 @@ TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") CHECK_EQ(*requireType("mt"), *requireType("returnedMT")); } -TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_mismatch_should_fail") { CheckResult result = check(R"( local t1 = {x = 1} @@ -1374,7 +1374,7 @@ TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") CHECK_EQ(*tm->givenType, *requireType("t2")); } -TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") +TEST_CASE_FIXTURE(BuiltinsFixture, "property_lookup_through_tabletypevar_metatable") { CheckResult result = check(R"( local t = {x = 1} @@ -1393,7 +1393,7 @@ TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") CHECK_EQ(up->key, "z"); } -TEST_CASE_FIXTURE(Fixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") +TEST_CASE_FIXTURE(BuiltinsFixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") { CheckResult result = check(R"( local t = {x = 1} @@ -1742,7 +1742,7 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") CHECK_EQ("Cannot add property 'b' to table '{| x: number |}'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "builtin_table_names") +TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") { CheckResult result = check(R"( os.h = 2 @@ -1755,7 +1755,7 @@ TEST_CASE_FIXTURE(Fixture, "builtin_table_names") CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); } -TEST_CASE_FIXTURE(Fixture, "persistent_sealed_table_is_immutable") +TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") { CheckResult result = check(R"( --!nonstrict @@ -1858,7 +1858,7 @@ local foos: {Foo} = { LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") +TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") { CheckResult result = check(R"( local clazz = {} @@ -1983,7 +1983,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") { CheckResult result = check(R"( --!nonstrict @@ -1996,7 +1996,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_strict") { ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; @@ -2052,7 +2052,7 @@ caused by: Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); } -TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") { ScopedFastFlag sff[]{ {"LuauTableSubtypingVariance2", true}, @@ -2183,7 +2183,7 @@ a.p = { x = 9 } LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_type_call") { ScopedFastFlag sff[]{ {"LuauUnsealedTableLiteral", true}, @@ -2277,7 +2277,7 @@ local y = #x LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") { ScopedFastFlag sff{"LuauTerminateCyclicMetatableIndexLookup", true}; @@ -2296,7 +2296,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable CHECK_EQ("Type 't' does not have key 'p'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "give_up_after_one_metatable_index_look_up") +TEST_CASE_FIXTURE(BuiltinsFixture, "give_up_after_one_metatable_index_look_up") { CheckResult result = check(R"( local data = { x = 5 } @@ -2478,7 +2478,7 @@ TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") )"); } -TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_unifies_into_map") { CheckResult result = check(R"( local Instance: any @@ -2564,7 +2564,7 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") * the generalization process), then it loses the knowledge that its metatable will have an :incr() * method. */ -TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_scope") { CheckResult result = check(R"( local Counter = {} @@ -2606,7 +2606,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") } // TODO: CLI-39624 -TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_tables_at_scope_level") { CheckResult result = check(R"( --!strict @@ -2690,7 +2690,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_meta LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning") { CheckResult result = check(R"( --!nonstrict @@ -2711,7 +2711,7 @@ type t0 = any CHECK(ttv->instantiatedTypeParams.empty()); } -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") +TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning_2") { ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; @@ -2767,7 +2767,7 @@ local baz = foo[bar] CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); } -TEST_CASE_FIXTURE(Fixture, "table_simple_call") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { CheckResult result = check(R"( local a = setmetatable({ x = 2 }, { @@ -2783,7 +2783,7 @@ local c = a(2) -- too many arguments CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") +TEST_CASE_FIXTURE(BuiltinsFixture, "access_index_metamethod_that_returns_variadic") { CheckResult result = check(R"( type Foo = {x: string} @@ -2878,7 +2878,7 @@ TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") )"); } -TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_function_check_use_after_free") { CheckResult result = check(R"( local t = {} @@ -2916,7 +2916,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the } // The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. -TEST_CASE_FIXTURE(Fixture, "dont_leak_free_table_props") +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") { CheckResult result = check(R"( local function a(state) diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e81ef1a..48cd1c3 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") +TEST_CASE_FIXTURE(BuiltinsFixture, "warn_on_lowercase_parent_property") { CheckResult result = check(R"( local M = require(script.parent.DoesNotMatter) @@ -175,7 +175,7 @@ TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") REQUIRE_EQ("parent", ed->symbol); } -TEST_CASE_FIXTURE(Fixture, "weird_case") +TEST_CASE_FIXTURE(BuiltinsFixture, "weird_case") { CheckResult result = check(R"( local function f() return 4 end @@ -419,7 +419,7 @@ TEST_CASE_FIXTURE(Fixture, "globals_everywhere") CHECK_EQ("any", toString(requireType("bar"))); } -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") { CheckResult result = check(R"( do @@ -534,7 +534,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error") +TEST_CASE_FIXTURE(BuiltinsFixture, "tc_after_error_recovery_no_replacement_name_in_error") { { CheckResult result = check(R"( @@ -587,7 +587,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error } } -TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") +TEST_CASE_FIXTURE(BuiltinsFixture, "index_expr_should_be_checked") { CheckResult result = check(R"( local foo: any @@ -768,7 +768,7 @@ b, c = {2, "s"}, {"b", 4} LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types_mutable_lval") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_assignment_value_types_mutable_lval") { CheckResult result = check(R"( local a = {} @@ -836,7 +836,7 @@ local a: number? = if true then 1 else nil LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") +TEST_CASE_FIXTURE(BuiltinsFixture, "tc_if_else_expressions_expected_type_3") { CheckResult result = check(R"( local function times(n: any, f: () -> T) @@ -907,7 +907,7 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") )"); } -TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_crash") { CheckResult result = check(R"( local function getIt() @@ -1041,7 +1041,6 @@ TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_table") { ScopedFastFlag flag[] = { - {"LuauStatFunctionSimplify4", true}, {"LuauLowerBoundsCalculation", true}, {"LuauDifferentOrderOfUnificationDoesntMatter2", true}, }; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 8756264..49deae7 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -196,7 +196,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") CHECK_EQ(toString(tm->wantedType), "string"); } -TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unification") +TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unification") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index f141622..fd66b08 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -339,7 +339,7 @@ local c: Packed CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean"); } -TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_type_packs_import") { fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } @@ -369,7 +369,7 @@ local d: { a: typeof(c) } CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); } -TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_pack_type_parameters") { fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } @@ -784,7 +784,7 @@ local a: Y<...number> LUAU_REQUIRE_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_default_export") { fileResolver.source["Module/Types"] = R"( export type A = { a: T, b: U } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 96bdd53..277f388 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -104,7 +104,7 @@ TEST_CASE_FIXTURE(Fixture, "optional_arguments_table2") REQUIRE(!result.errors.empty()); } -TEST_CASE_FIXTURE(Fixture, "error_takes_optional_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "error_takes_optional_arguments") { CheckResult result = check(R"( error("message") @@ -517,10 +517,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "table_union_write_indirect") +TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") { - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify4", true}; - CheckResult result = check(R"( type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } From 7e9e697489c886773d289020afa210a1166ea7d6 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 19 May 2022 16:46:52 -0700 Subject: [PATCH 10/19] Sync to upstream/release/527 --- Analysis/include/Luau/Clone.h | 2 +- Analysis/include/Luau/Error.h | 4 +- Analysis/include/Luau/LValue.h | 4 - Analysis/include/Luau/Module.h | 36 +- Analysis/include/Luau/Substitution.h | 3 +- Analysis/include/Luau/ToString.h | 1 + Analysis/include/Luau/TxnLog.h | 11 - Analysis/include/Luau/TypeArena.h | 42 ++ Analysis/include/Luau/TypeInfer.h | 19 +- Analysis/include/Luau/Unifier.h | 4 +- Analysis/include/Luau/VisitTypeVar.h | 20 +- Analysis/src/AstQuery.cpp | 20 +- Analysis/src/BuiltinDefinitions.cpp | 53 +- Analysis/src/Clone.cpp | 44 +- Analysis/src/Error.cpp | 13 +- Analysis/src/IostreamHelpers.cpp | 2 +- Analysis/src/LValue.cpp | 21 - Analysis/src/Module.cpp | 123 ++-- Analysis/src/Normalize.cpp | 32 +- Analysis/src/Quantify.cpp | 39 +- Analysis/src/Substitution.cpp | 113 ++-- Analysis/src/ToString.cpp | 25 +- Analysis/src/TxnLog.cpp | 34 +- Analysis/src/TypeArena.cpp | 88 +++ Analysis/src/TypeInfer.cpp | 650 +++++---------------- Analysis/src/TypeUtils.cpp | 11 +- Analysis/src/TypeVar.cpp | 9 +- Analysis/src/Unifier.cpp | 84 +-- Compiler/include/Luau/BytecodeBuilder.h | 1 + Compiler/src/BytecodeBuilder.cpp | 10 + Compiler/src/Compiler.cpp | 289 ++++++--- Compiler/src/ConstantFolding.cpp | 4 +- Sources.cmake | 2 + VM/src/ltablib.cpp | 27 - VM/src/lvmexecute.cpp | 14 - tests/AstQuery.test.cpp | 13 + tests/Autocomplete.test.cpp | 4 + tests/Compiler.test.cpp | 235 ++++++-- tests/Module.test.cpp | 8 - tests/NonstrictMode.test.cpp | 4 - tests/Normalize.test.cpp | 45 ++ tests/RuntimeLimits.test.cpp | 2 - tests/ToString.test.cpp | 81 +++ tests/TypeInfer.builtins.test.cpp | 14 - tests/TypeInfer.functions.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 74 +++ tests/TypeInfer.intersectionTypes.test.cpp | 2 - tests/TypeInfer.loops.test.cpp | 2 - tests/TypeInfer.operators.test.cpp | 2 - tests/TypeInfer.provisional.test.cpp | 36 +- tests/TypeInfer.refinements.test.cpp | 143 +---- tests/TypeInfer.singletons.test.cpp | 25 +- tests/TypeInfer.tables.test.cpp | 34 +- tests/TypeInfer.test.cpp | 4 - tests/TypeInfer.unionTypes.test.cpp | 1 - 55 files changed, 1213 insertions(+), 1372 deletions(-) create mode 100644 Analysis/include/Luau/TypeArena.h create mode 100644 Analysis/src/TypeArena.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 9b6ffa6..9fcbce0 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/TypeArena.h" #include "Luau/TypeVar.h" #include @@ -18,7 +19,6 @@ struct CloneState SeenTypePacks seenTypePacks; int recursionCount = 0; - bool encounteredFreeType = false; // TODO: Remove with LuauLosslessClone. }; TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 7068314..b453067 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -5,6 +5,7 @@ #include "Luau/Location.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" +#include "Luau/TypeArena.h" namespace Luau { @@ -108,9 +109,6 @@ struct FunctionDoesNotTakeSelf struct FunctionRequiresSelf { - // TODO: Delete with LuauAnyInIsOptionalIsOptional - int requiredExtraNils = 0; - bool operator==(const FunctionRequiresSelf& rhs) const; }; diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index afb7141..1a92d52 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -34,10 +34,6 @@ const LValue* baseof(const LValue& lvalue); std::optional tryGetLValue(const class AstExpr& expr); -// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. -// TODO: remove with FFlagLuauTypecheckOptPass -std::pair> getFullName(const LValue& lvalue); - // Utility function: breaks down an LValue to get at the Symbol Symbol getBaseSymbol(const LValue& lvalue); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 0dd4418..00e1e63 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -2,11 +2,10 @@ #pragma once #include "Luau/FileResolver.h" -#include "Luau/TypePack.h" -#include "Luau/TypedAllocator.h" #include "Luau/ParseOptions.h" #include "Luau/Error.h" #include "Luau/ParseResult.h" +#include "Luau/TypeArena.h" #include #include @@ -54,35 +53,6 @@ struct RequireCycle std::vector path; // one of the paths for a require() to go all the way back to the originating module }; -struct TypeArena -{ - TypedAllocator typeVars; - TypedAllocator typePacks; - - void clear(); - - template - TypeId addType(T tv) - { - if constexpr (std::is_same_v) - LUAU_ASSERT(tv.options.size() >= 2); - - return addTV(TypeVar(std::move(tv))); - } - - TypeId addTV(TypeVar&& tv); - - TypeId freshType(TypeLevel level); - - TypePackId addTypePack(std::initializer_list types); - TypePackId addTypePack(std::vector types); - TypePackId addTypePack(TypePack pack); - TypePackId addTypePack(TypePackVar pack); -}; - -void freeze(TypeArena& arena); -void unfreeze(TypeArena& arena); - struct Module { ~Module(); @@ -111,9 +81,7 @@ struct Module // Once a module has been typechecked, we clone its public interface into a separate arena. // This helps us to force TypeVar ownership into a DAG rather than a DCG. - // Returns true if there were any free types encountered in the public interface. This - // indicates a bug in the type checker that we want to surface. - bool clonePublicInterface(InternalErrorReporter& ice); + void clonePublicInterface(InternalErrorReporter& ice); }; } // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 6f5931e..f3c3ae9 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -1,8 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Module.h" -#include "Luau/ModuleResolver.h" +#include "Luau/TypeArena.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/DenseHash.h" diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index f4db5e3..3b380a6 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -28,6 +28,7 @@ struct ToStringOptions bool functionTypeArguments = false; // If true, output function type argument names when they are available bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. + bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self bool indent = false; size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 995ed6c..cd115e3 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -7,8 +7,6 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauTypecheckOptPass) - namespace Luau { @@ -93,15 +91,6 @@ struct TxnLog { } - TxnLog(TxnLog* parent, std::vector>* sharedSeen) - : typeVarChanges(nullptr) - , typePackChanges(nullptr) - , parent(parent) - , sharedSeen(sharedSeen) - { - LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); - } - TxnLog(const TxnLog&) = delete; TxnLog& operator=(const TxnLog&) = delete; diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h new file mode 100644 index 0000000..7c74158 --- /dev/null +++ b/Analysis/include/Luau/TypeArena.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypedAllocator.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ + +struct TypeArena +{ + TypedAllocator typeVars; + TypedAllocator typePacks; + + void clear(); + + template + TypeId addType(T tv) + { + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); + + return addTV(TypeVar(std::move(tv))); + } + + TypeId addTV(TypeVar&& tv); + + TypeId freshType(TypeLevel level); + + TypePackId addTypePack(std::initializer_list types); + TypePackId addTypePack(std::vector types); + TypePackId addTypePack(TypePack pack); + TypePackId addTypePack(TypePackVar pack); +}; + +void freeze(TypeArena& arena); +void unfreeze(TypeArena& arena); + +} diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index ac88013..fcaf5ba 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -187,7 +187,6 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); @@ -395,7 +394,7 @@ private: const AstArray& genericNames, const AstArray& genericPackNames, bool useCache = false); public: - ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); + void resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); private: void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); @@ -403,14 +402,14 @@ private: std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); - void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); - void resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); - void resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); - void resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); + void resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; bool useConstrainedIntersections() const; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 418d4ca..0e24c8b 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -5,7 +5,7 @@ #include "Luau/Location.h" #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" -#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. +#include "Luau/TypeArena.h" #include "Luau/UnifierSharedState.h" #include @@ -55,8 +55,6 @@ struct Unifier UnifierSharedState& sharedState; Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, - UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 67fce5e..2e98f52 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -10,6 +10,7 @@ LUAU_FASTFLAG(LuauUseVisitRecursionLimit) LUAU_FASTINT(LuauVisitRecursionLimit) +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) namespace Luau { @@ -471,18 +472,21 @@ struct GenericTypeVarVisitor else if (auto pack = get(tp)) { - visit(tp, *pack); + bool res = visit(tp, *pack); + if (!FFlag::LuauNormalizeFlagIsConservative || res) + { + for (TypeId ty : pack->head) + traverse(ty); - for (TypeId ty : pack->head) - traverse(ty); - - if (pack->tail) - traverse(*pack->tail); + if (pack->tail) + traverse(*pack->tail); + } } else if (auto pack = get(tp)) { - visit(tp, *pack); - traverse(pack->ty); + bool res = visit(tp, *pack); + if (!FFlag::LuauNormalizeFlagIsConservative || res) + traverse(pack->ty); } else LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 0aed34c..0522b1f 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -71,9 +71,11 @@ struct FindFullAncestry final : public AstVisitor { std::vector nodes; Position pos; + Position documentEnd; - explicit FindFullAncestry(Position pos) + explicit FindFullAncestry(Position pos, Position documentEnd) : pos(pos) + , documentEnd(documentEnd) { } @@ -84,6 +86,16 @@ struct FindFullAncestry final : public AstVisitor nodes.push_back(node); return true; } + + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. + + if (node->location.end == documentEnd && pos >= documentEnd) + { + nodes.push_back(node); + return true; + } + return false; } }; @@ -92,7 +104,11 @@ struct FindFullAncestry final : public AstVisitor std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) { - FindFullAncestry finder(pos); + const Position end = source.root->location.end; + if (pos > end) + pos = end; + + FindFullAncestry finder(pos, end); source.root->visit(&finder); return std::move(finder.nodes); } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 3895b01..5ed6de6 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,7 +8,6 @@ #include -LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) /** FIXME: Many of these type definitions are not quite completely accurate. @@ -408,41 +407,29 @@ static std::optional> magicFunctionAssert( { auto [paramPack, predicates] = exprResult; - if (FFlag::LuauAssertStripsFalsyTypes) + TypeArena& arena = typechecker.currentModule->internalTypes; + + auto [head, tail] = flatten(paramPack); + if (head.empty() && tail) { - TypeArena& arena = typechecker.currentModule->internalTypes; - - auto [head, tail] = flatten(paramPack); - if (head.empty() && tail) - { - std::optional fst = first(*tail); - if (!fst) - return ExprResult{paramPack}; - head.push_back(*fst); - } - - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); - - if (head.size() > 0) - { - std::optional newhead = typechecker.pickTypesFromSense(head[0], true); - if (!newhead) - head = {typechecker.nilType}; - else - head[0] = *newhead; - } - - return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; - } - else - { - if (expr.args.size < 1) + std::optional fst = first(*tail); + if (!fst) return ExprResult{paramPack}; - - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); - - return ExprResult{paramPack}; + head.push_back(*fst); } + + typechecker.resolve(predicates, scope, true); + + if (head.size() > 0) + { + std::optional newhead = typechecker.pickTypesFromSense(head[0], true); + if (!newhead) + head = {typechecker.nilType}; + else + head[0] = *newhead; + } + + return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; } static std::optional> magicFunctionPack( diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 1aa556e..a3611f5 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -1,7 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Clone.h" -#include "Luau/Module.h" #include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" #include "Luau/Unifiable.h" @@ -9,8 +8,6 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAG(LuauTypecheckOptPass) -LUAU_FASTFLAGVARIABLE(LuauLosslessClone, false) LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau @@ -89,20 +86,8 @@ struct TypePackCloner void operator()(const Unifiable::Free& t) { - if (FFlag::LuauLosslessClone) - { - defaultClone(t); - } - else - { - cloneState.encounteredFreeType = true; - - TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); - TypePackId cloned = dest.addTypePack(*err); - seenTypePacks[typePackId] = cloned; - } + defaultClone(t); } - void operator()(const Unifiable::Generic& t) { defaultClone(t); @@ -152,18 +137,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { - if (FFlag::LuauLosslessClone) - { - defaultClone(t); - } - else - { - cloneState.encounteredFreeType = true; - - TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); - TypeId cloned = dest.addType(*err); - seenTypes[typeId] = cloned; - } + defaultClone(t); } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -191,9 +165,6 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const ConstrainedTypeVar& t) { - if (!FFlag::LuauLosslessClone) - cloneState.encounteredFreeType = true; - TypeId res = dest.addType(ConstrainedTypeVar{t.level}); ConstrainedTypeVar* ctv = getMutable(res); LUAU_ASSERT(ctv); @@ -230,9 +201,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; ftv->retType = clone(t.retType, dest, cloneState); - - if (FFlag::LuauTypecheckOptPass) - ftv->hasNoGenerics = t.hasNoGenerics; + ftv->hasNoGenerics = t.hasNoGenerics; } void TypeCloner::operator()(const TableTypeVar& t) @@ -270,13 +239,6 @@ void TypeCloner::operator()(const TableTypeVar& t) for (TypePackId& arg : ttv->instantiatedTypePackParams) arg = clone(arg, dest, cloneState); - if (!FFlag::LuauLosslessClone && ttv->state == TableState::Free) - { - cloneState.encounteredFreeType = true; - - ttv->state = TableState::Sealed; - } - ttv->definitionModuleName = t.definitionModuleName; if (!FFlag::LuauNoMethodLocations) ttv->methodDefinitionLocations = t.methodDefinitionLocations; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 24ed4ac..f443a3c 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -2,7 +2,6 @@ #include "Luau/Error.h" #include "Luau/Clone.h" -#include "Luau/Module.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" @@ -178,15 +177,7 @@ struct ErrorConverter std::string operator()(const Luau::FunctionRequiresSelf& e) const { - if (e.requiredExtraNils) - { - const char* plural = e.requiredExtraNils == 1 ? "" : "s"; - return format("This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a dot or " - "pass %i extra nil%s to suppress this warning", - e.requiredExtraNils, plural); - } - else - return "This function must be called with self. Did you mean to use a colon instead of a dot?"; + return "This function must be called with self. Did you mean to use a colon instead of a dot?"; } std::string operator()(const Luau::OccursCheckFailed&) const @@ -539,7 +530,7 @@ bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const bool FunctionRequiresSelf::operator==(const FunctionRequiresSelf& e) const { - return requiredExtraNils == e.requiredExtraNils; + return true; } bool OccursCheckFailed::operator==(const OccursCheckFailed&) const diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 0eaa485..048167a 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -48,7 +48,7 @@ static void errorToString(std::ostream& stream, const T& err) else if constexpr (std::is_same_v) stream << "FunctionDoesNotTakeSelf { }"; else if constexpr (std::is_same_v) - stream << "FunctionRequiresSelf { extraNils " << err.requiredExtraNils << " }"; + stream << "FunctionRequiresSelf { }"; else if constexpr (std::is_same_v) stream << "OccursCheckFailed { }"; else if constexpr (std::is_same_v) diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp index 72555ab..38dfe1a 100644 --- a/Analysis/src/LValue.cpp +++ b/Analysis/src/LValue.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauTypecheckOptPass) - namespace Luau { @@ -79,27 +77,8 @@ std::optional tryGetLValue(const AstExpr& node) return std::nullopt; } -std::pair> getFullName(const LValue& lvalue) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); - - const LValue* current = &lvalue; - std::vector keys; - while (auto field = get(*current)) - { - keys.push_back(field->key); - current = baseof(*current); - } - - const Symbol* symbol = get(*current); - LUAU_ASSERT(symbol); - return {*symbol, std::vector(keys.rbegin(), keys.rend())}; -} - Symbol getBaseSymbol(const LValue& lvalue) { - LUAU_ASSERT(FFlag::LuauTypecheckOptPass); - const LValue* current = &lvalue; while (auto field = get(*current)) current = baseof(*current); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index bafd437..074a41e 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,9 +13,8 @@ #include -LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTFLAG(LuauLosslessClone) +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative); namespace Luau { @@ -55,89 +54,25 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) return contains(pos, *iter); } -void TypeArena::clear() +struct ForceNormal : TypeVarOnceVisitor { - typeVars.clear(); - typePacks.clear(); -} + bool visit(TypeId ty) override + { + asMutable(ty)->normal = true; + return true; + } -TypeId TypeArena::addTV(TypeVar&& tv) -{ - TypeId allocated = typeVars.allocate(std::move(tv)); + bool visit(TypeId ty, const FreeTypeVar& ftv) override + { + visit(ty); + return true; + } - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypeId TypeArena::freshType(TypeLevel level) -{ - TypeId allocated = typeVars.allocate(FreeTypeVar{level}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(std::initializer_list types) -{ - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(std::vector types) -{ - TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(TypePack tp) -{ - TypePackId allocated = typePacks.allocate(std::move(tp)); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -TypePackId TypeArena::addTypePack(TypePackVar tp) -{ - TypePackId allocated = typePacks.allocate(std::move(tp)); - - asMutable(allocated)->owningArena = this; - - return allocated; -} - -ScopePtr Module::getModuleScope() const -{ - LUAU_ASSERT(!scopes.empty()); - return scopes.front().second; -} - -void freeze(TypeArena& arena) -{ - if (!FFlag::DebugLuauFreezeArena) - return; - - arena.typeVars.freeze(); - arena.typePacks.freeze(); -} - -void unfreeze(TypeArena& arena) -{ - if (!FFlag::DebugLuauFreezeArena) - return; - - arena.typeVars.unfreeze(); - arena.typePacks.unfreeze(); -} + bool visit(TypePackId tp, const FreeTypePack& ftp) override + { + return true; + } +}; Module::~Module() { @@ -145,7 +80,7 @@ Module::~Module() unfreeze(internalTypes); } -bool Module::clonePublicInterface(InternalErrorReporter& ice) +void Module::clonePublicInterface(InternalErrorReporter& ice) { LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); @@ -165,11 +100,22 @@ bool Module::clonePublicInterface(InternalErrorReporter& ice) normalize(*moduleScope->varargPack, interfaceTypes, ice); } + ForceNormal forceNormal; + for (auto& [name, tf] : moduleScope->exportedTypeBindings) { tf = clone(tf, interfaceTypes, cloneState); if (FFlag::LuauLowerBoundsCalculation) + { normalize(tf.type, interfaceTypes, ice); + + if (FFlag::LuauNormalizeFlagIsConservative) + { + // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables + // won't be marked normal. If the types aren't normal by now, they never will be. + forceNormal.traverse(tf.type); + } + } } for (TypeId ty : moduleScope->returnType) @@ -191,11 +137,12 @@ bool Module::clonePublicInterface(InternalErrorReporter& ice) freeze(internalTypes); freeze(interfaceTypes); +} - if (FFlag::LuauLosslessClone) - return false; // TODO: make function return void. - else - return cloneState.encounteredFreeType; +ScopePtr Module::getModuleScope() const +{ + LUAU_ASSERT(!scopes.empty()); + return scopes.front().second; } } // namespace Luau diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index ef5377a..30fd4af 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); namespace Luau { @@ -260,8 +261,13 @@ static bool areNormal_(const T& t, const std::unordered_set& seen, Intern if (count >= FInt::LuauNormalizeIterationLimit) ice.ice("Luau::areNormal hit iteration limit"); - // The follow is here because a bound type may not be normal, but the bound type is normal. - return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end(); + if (FFlag::LuauNormalizeFlagIsConservative) + return ty->normal; + else + { + // The follow is here because a bound type may not be normal, but the bound type is normal. + return ty->normal || follow(ty)->normal || seen.find(asMutable(ty)) != seen.end(); + } }; return std::all_of(begin(t), end(t), isNormal); @@ -1003,8 +1009,15 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo (void)clone(ty, arena, state); Normalize n{arena, ice}; - std::unordered_set seen; - DEPRECATED_visitTypeVar(ty, n, seen); + if (FFlag::LuauNormalizeFlagIsConservative) + { + DEPRECATED_visitTypeVar(ty, n); + } + else + { + std::unordered_set seen; + DEPRECATED_visitTypeVar(ty, n, seen); + } return {ty, !n.limitExceeded}; } @@ -1028,8 +1041,15 @@ std::pair normalize(TypePackId tp, TypeArena& arena, InternalE (void)clone(tp, arena, state); Normalize n{arena, ice}; - std::unordered_set seen; - DEPRECATED_visitTypeVar(tp, n, seen); + if (FFlag::LuauNormalizeFlagIsConservative) + { + DEPRECATED_visitTypeVar(tp, n); + } + else + { + std::unordered_set seen; + DEPRECATED_visitTypeVar(tp, n, seen); + } return {tp, !n.limitExceeded}; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 4f3e446..018d563 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,7 +4,7 @@ #include "Luau/VisitTypeVar.h" -LUAU_FASTFLAG(LuauTypecheckOptPass) +LUAU_FASTFLAG(LuauAlwaysQuantify) namespace Luau { @@ -59,8 +59,7 @@ struct Quantifier final : TypeVarOnceVisitor bool visit(TypeId ty, const FreeTypeVar& ftv) override { - if (FFlag::LuauTypecheckOptPass) - seenMutableType = true; + seenMutableType = true; if (!level.subsumes(ftv.level)) return false; @@ -76,20 +75,17 @@ struct Quantifier final : TypeVarOnceVisitor LUAU_ASSERT(getMutable(ty)); TableTypeVar& ttv = *getMutable(ty); - if (FFlag::LuauTypecheckOptPass) - { - if (ttv.state == TableState::Generic) - seenGenericType = true; + if (ttv.state == TableState::Generic) + seenGenericType = true; - if (ttv.state == TableState::Free) - seenMutableType = true; - } + if (ttv.state == TableState::Free) + seenMutableType = true; if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) return false; if (!level.subsumes(ttv.level)) { - if (FFlag::LuauTypecheckOptPass && ttv.state == TableState::Unsealed) + if (ttv.state == TableState::Unsealed) seenMutableType = true; return false; } @@ -97,9 +93,7 @@ struct Quantifier final : TypeVarOnceVisitor if (ttv.state == TableState::Free) { ttv.state = TableState::Generic; - - if (FFlag::LuauTypecheckOptPass) - seenGenericType = true; + seenGenericType = true; } else if (ttv.state == TableState::Unsealed) ttv.state = TableState::Sealed; @@ -111,8 +105,7 @@ struct Quantifier final : TypeVarOnceVisitor bool visit(TypePackId tp, const FreeTypePack& ftp) override { - if (FFlag::LuauTypecheckOptPass) - seenMutableType = true; + seenMutableType = true; if (!level.subsumes(ftp.level)) return false; @@ -131,10 +124,18 @@ void quantify(TypeId ty, TypeLevel level) FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); - ftv->generics = q.generics; - ftv->genericPacks = q.genericPacks; + if (FFlag::LuauAlwaysQuantify) + { + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } + else + { + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; + } - if (FFlag::LuauTypecheckOptPass && ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) ftv->hasNoGenerics = true; } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index c5c7977..e40bedb 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -9,9 +9,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAG(LuauTypecheckOptPass) -LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowNewTypes, false) -LUAU_FASTFLAGVARIABLE(LuauSubstituteFollowPossibleMutations, false) LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau @@ -19,26 +16,20 @@ namespace Luau void Tarjan::visitChildren(TypeId ty, int index) { - if (FFlag::LuauTypecheckOptPass) - LUAU_ASSERT(ty == log->follow(ty)); - else - ty = log->follow(ty); + LUAU_ASSERT(ty == log->follow(ty)); if (ignoreChildren(ty)) return; - if (FFlag::LuauTypecheckOptPass) - { - if (auto pty = log->pending(ty)) - ty = &pty->pending; - } + if (auto pty = log->pending(ty)) + ty = &pty->pending; - if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + if (const FunctionTypeVar* ftv = get(ty)) { visitChild(ftv->argTypes); visitChild(ftv->retType); } - else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const TableTypeVar* ttv = get(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -55,17 +46,17 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = get(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const UnionTypeVar* utv = get(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = get(ty)) { for (TypeId part : itv->parts) visitChild(part); @@ -79,28 +70,22 @@ void Tarjan::visitChildren(TypeId ty, int index) void Tarjan::visitChildren(TypePackId tp, int index) { - if (FFlag::LuauTypecheckOptPass) - LUAU_ASSERT(tp == log->follow(tp)); - else - tp = log->follow(tp); + LUAU_ASSERT(tp == log->follow(tp)); if (ignoreChildren(tp)) return; - if (FFlag::LuauTypecheckOptPass) - { - if (auto ptp = log->pending(tp)) - tp = &ptp->pending; - } + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; - if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) + if (const TypePack* tpp = get(tp)) { for (TypeId tv : tpp->head) visitChild(tv); if (tpp->tail) visitChild(*tpp->tail); } - else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) + else if (const VariadicTypePack* vtp = get(tp)) { visitChild(vtp->ty); } @@ -108,10 +93,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) - LUAU_ASSERT(ty == log->follow(ty)); - else - ty = log->follow(ty); + ty = log->follow(ty); bool fresh = !typeToIndex.contains(ty); int& index = typeToIndex[ty]; @@ -129,10 +111,7 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) - LUAU_ASSERT(tp == log->follow(tp)); - else - tp = log->follow(tp); + tp = log->follow(tp); bool fresh = !packToIndex.contains(tp); int& index = packToIndex[tp]; @@ -150,8 +129,7 @@ std::pair Tarjan::indexify(TypePackId tp) void Tarjan::visitChild(TypeId ty) { - if (!FFlag::LuauSubstituteFollowPossibleMutations) - ty = log->follow(ty); + ty = log->follow(ty); edgesTy.push_back(ty); edgesTp.push_back(nullptr); @@ -159,8 +137,7 @@ void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypePackId tp) { - if (!FFlag::LuauSubstituteFollowPossibleMutations) - tp = log->follow(tp); + tp = log->follow(tp); edgesTy.push_back(nullptr); edgesTp.push_back(tp); @@ -389,13 +366,10 @@ TypeId Substitution::clone(TypeId ty) TypeId result = ty; - if (FFlag::LuauTypecheckOptPass) - { - if (auto pty = log->pending(ty)) - ty = &pty->pending; - } + if (auto pty = log->pending(ty)) + ty = &pty->pending; - if (const FunctionTypeVar* ftv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + if (const FunctionTypeVar* ftv = get(ty)) { FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; @@ -405,7 +379,7 @@ TypeId Substitution::clone(TypeId ty) clone.argNames = ftv->argNames; result = addType(std::move(clone)); } - else if (const TableTypeVar* ttv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const TableTypeVar* ttv = get(ty)) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -419,19 +393,19 @@ TypeId Substitution::clone(TypeId ty) clone.tags = ttv->tags; result = addType(std::move(clone)); } - else if (const MetatableTypeVar* mtv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const MetatableTypeVar* mtv = get(ty)) { MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; clone.syntheticName = mtv->syntheticName; result = addType(std::move(clone)); } - else if (const UnionTypeVar* utv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const UnionTypeVar* utv = get(ty)) { UnionTypeVar clone; clone.options = utv->options; result = addType(std::move(clone)); } - else if (const IntersectionTypeVar* itv = FFlag::LuauTypecheckOptPass ? get(ty) : log->getMutable(ty)) + else if (const IntersectionTypeVar* itv = get(ty)) { IntersectionTypeVar clone; clone.parts = itv->parts; @@ -451,20 +425,17 @@ TypePackId Substitution::clone(TypePackId tp) { tp = log->follow(tp); - if (FFlag::LuauTypecheckOptPass) - { - if (auto ptp = log->pending(tp)) - tp = &ptp->pending; - } + if (auto ptp = log->pending(tp)) + tp = &ptp->pending; - if (const TypePack* tpp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) + if (const TypePack* tpp = get(tp)) { TypePack clone; clone.head = tpp->head; clone.tail = tpp->tail; return addTypePack(std::move(clone)); } - else if (const VariadicTypePack* vtp = FFlag::LuauTypecheckOptPass ? get(tp) : log->getMutable(tp)) + else if (const VariadicTypePack* vtp = get(tp)) { VariadicTypePack clone; clone.ty = vtp->ty; @@ -476,28 +447,22 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) - LUAU_ASSERT(ty == log->follow(ty)); - else - ty = log->follow(ty); + ty = log->follow(ty); if (isDirty(ty)) - newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(ty)) : clean(ty); + newTypes[ty] = follow(clean(ty)); else - newTypes[ty] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(ty)) : clone(ty); + newTypes[ty] = follow(clone(ty)); } void Substitution::foundDirty(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass && !FFlag::LuauSubstituteFollowPossibleMutations) - LUAU_ASSERT(tp == log->follow(tp)); - else - tp = log->follow(tp); + tp = log->follow(tp); if (isDirty(tp)) - newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clean(tp)) : clean(tp); + newPacks[tp] = follow(clean(tp)); else - newPacks[tp] = FFlag::LuauSubstituteFollowNewTypes ? follow(clone(tp)) : clone(tp); + newPacks[tp] = follow(clone(tp)); } TypeId Substitution::replace(TypeId ty) @@ -525,10 +490,7 @@ void Substitution::replaceChildren(TypeId ty) if (BoundTypeVar* btv = log->getMutable(ty); FFlag::LuauLowerBoundsCalculation && btv) btv->boundTo = replace(btv->boundTo); - if (FFlag::LuauTypecheckOptPass) - LUAU_ASSERT(ty == log->follow(ty)); - else - ty = log->follow(ty); + LUAU_ASSERT(ty == log->follow(ty)); if (ignoreChildren(ty)) return; @@ -579,10 +541,7 @@ void Substitution::replaceChildren(TypeId ty) void Substitution::replaceChildren(TypePackId tp) { - if (FFlag::LuauTypecheckOptPass) - LUAU_ASSERT(tp == log->follow(tp)); - else - tp = log->follow(tp); + LUAU_ASSERT(tp == log->follow(tp)); if (ignoreChildren(tp)) return; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 51665f7..f90f701 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -219,6 +219,8 @@ struct StringifierState return generateName(s); } + int previousNameIndex = 0; + std::string getName(TypePackId ty) { const size_t s = result.nameMap.typePacks.size(); @@ -228,9 +230,10 @@ struct StringifierState for (int count = 0; count < 256; ++count) { - std::string candidate = generateName(usedNames.size() + count); + std::string candidate = generateName(previousNameIndex + count); if (!usedNames.count(candidate)) { + previousNameIndex += count; usedNames.insert(candidate); n = candidate; return candidate; @@ -399,6 +402,7 @@ struct TypeVarStringifier { if (gtv.explicitName) { + state.usedNames.insert(gtv.name); state.result.nameMap.typeVars[ty] = gtv.name; state.emit(gtv.name); } @@ -745,7 +749,10 @@ struct TypeVarStringifier for (std::string& ss : results) { if (!first) - state.emit(" | "); + { + state.newline(); + state.emit("| "); + } state.emit(ss); first = false; } @@ -798,7 +805,10 @@ struct TypeVarStringifier for (std::string& ss : results) { if (!first) - state.emit(" & "); + { + state.newline(); + state.emit("& "); + } state.emit(ss); first = false; } @@ -937,6 +947,7 @@ struct TypePackStringifier state.emit("gen-"); if (pack.explicitName) { + state.usedNames.insert(pack.name); state.result.nameMap.typePacks[tp] = pack.name; state.emit(pack.name); } @@ -1230,6 +1241,14 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp size_t idx = 0; while (argPackIter != end(ftv.argTypes)) { + // ftv takes a self parameter as the first argument, skip it if specified in option + if (idx == 0 && ftv.hasSelf && opts.hideFunctionSelfArgument) + { + ++argPackIter; + ++idx; + continue; + } + if (!first) state.emit(", "); first = false; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 1fb5a61..e45c0cb 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,8 +7,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauJustOneCallFrameForHaveSeen, false) - namespace Luau { @@ -150,37 +148,13 @@ void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { - if (FFlag::LuauJustOneCallFrameForHaveSeen && !FFlag::LuauTypecheckOptPass) + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) { - // This function will technically work if `this` is nullptr, but this - // indicates a bug, so we explicitly assert. - LUAU_ASSERT(static_cast(this) != nullptr); - - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - - for (const TxnLog* current = this; current; current = current->parent) - { - if (current->sharedSeen->end() != std::find(current->sharedSeen->begin(), current->sharedSeen->end(), sortedPair)) - return true; - } - - return false; + return true; } - else - { - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) - { - return true; - } - if (!FFlag::LuauTypecheckOptPass && parent) - { - return parent->haveSeen(lhs, rhs); - } - - return false; - } + return false; } void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp new file mode 100644 index 0000000..673b002 --- /dev/null +++ b/Analysis/src/TypeArena.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeArena.h" + +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false); + +namespace Luau +{ + +void TypeArena::clear() +{ + typeVars.clear(); + typePacks.clear(); +} + +TypeId TypeArena::addTV(TypeVar&& tv) +{ + TypeId allocated = typeVars.allocate(std::move(tv)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(TypeLevel level) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{level}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::initializer_list types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::vector types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePack tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePackVar tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +void freeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.freeze(); + arena.typePacks.freeze(); +} + +void unfreeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.unfreeze(); + arena.typePacks.unfreeze(); +} + +} diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index a13abd5..208b3f2 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,32 +32,25 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) -LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) +LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauInstantiateFollows, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) -LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauTypecheckOptPass, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) -LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. +LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) -LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) -LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) -LUAU_FASTFLAGVARIABLE(LuauCheckImplicitNumbericKeys, false) -LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) -LUAU_FASTFLAGVARIABLE(LuauTableUseCounterInstead, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); -LUAU_FASTFLAG(LuauLosslessClone) +LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); +LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); namespace Luau { @@ -371,12 +364,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo prepareErrorsForDisplay(currentModule->errors); - bool encounteredFreeType = currentModule->clonePublicInterface(*iceHandler); - if (!FFlag::LuauLosslessClone && encounteredFreeType) - { - reportError(TypeError{module.root->location, - GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); - } + currentModule->clonePublicInterface(*iceHandler); // Clear unifier cache since it's keyed off internal types that get deallocated // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. @@ -701,7 +689,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) ExprResult result = checkExpr(scope, *statement.condition); ScopePtr ifScope = childScope(scope, statement.thenbody->location); - reportErrors(resolve(result.predicates, ifScope, true)); + resolve(result.predicates, ifScope, true); check(ifScope, *statement.thenbody); if (statement.elsebody) @@ -734,7 +722,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) ExprResult result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); - reportErrors(resolve(result.predicates, whileScope, true)); + resolve(result.predicates, whileScope, true); check(whileScope, *statement.body); } @@ -1154,10 +1142,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) } else { - if (FFlag::LuauInstantiateFollows) - iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); - else - iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); + iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } if (FFlag::LuauTypecheckIter) @@ -1849,23 +1834,11 @@ std::optional TypeChecker::getIndexTypeFromType( tablify(type); - if (FFlag::LuauDiscriminableUnions2) + if (isString(type)) { - if (isString(type)) - { - std::optional mtIndex = findMetatableEntry(stringType, "__index", location); - LUAU_ASSERT(mtIndex); - type = *mtIndex; - } - } - else - { - const PrimitiveTypeVar* primitiveType = get(type); - if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) - { - if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) - type = *mtIndex; - } + std::optional mtIndex = findMetatableEntry(stringType, "__index", location); + LUAU_ASSERT(mtIndex); + type = *mtIndex; } if (TableTypeVar* tableType = getMutableTableType(type)) @@ -1966,23 +1939,10 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - if (FFlag::LuauDoNotTryToReduce) - { - if (parts.size() == 1) - return parts[0]; + if (parts.size() == 1) + return parts[0]; - return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. - } - else - { - // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 - std::vector result = reduceUnion(parts); - - if (result.size() == 1) - return result[0]; - - return addType(IntersectionTypeVar{result}); - } + return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. } if (addErrors) @@ -1993,103 +1953,55 @@ std::optional TypeChecker::getIndexTypeFromType( std::vector TypeChecker::reduceUnion(const std::vector& types) { - if (FFlag::LuauDoNotAccidentallyDependOnPointerOrdering) + std::vector result; + for (TypeId t : types) { - std::vector result; - for (TypeId t : types) + t = follow(t); + if (get(t) || get(t)) + return {t}; + + if (const UnionTypeVar* utv = get(t)) { - t = follow(t); - if (get(t) || get(t)) - return {t}; - - if (const UnionTypeVar* utv = get(t)) + if (FFlag::LuauReduceUnionRecursion) { - if (FFlag::LuauReduceUnionRecursion) + for (TypeId ty : utv) { - for (TypeId ty : utv) - { - if (get(ty) || get(ty)) - return {ty}; - - if (result.end() == std::find(result.begin(), result.end(), ty)) - result.push_back(ty); - } - } - else - { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) - { + if (FFlag::LuauNormalizeFlagIsConservative) ty = follow(ty); - if (get(ty) || get(ty)) - return {ty}; + if (get(ty) || get(ty)) + return {ty}; - if (std::find(result.begin(), result.end(), ty) == result.end()) - result.push_back(ty); - } + if (result.end() == std::find(result.begin(), result.end(), ty)) + result.push_back(ty); } } - else if (std::find(result.begin(), result.end(), t) == result.end()) - result.push_back(t); - } - - return result; - } - else - { - std::set s; - - for (TypeId t : types) - { - if (const UnionTypeVar* utv = get(follow(t))) + else { std::vector r = reduceUnion(utv->options); for (TypeId ty : r) - s.insert(ty); + { + ty = follow(ty); + if (get(ty) || get(ty)) + return {ty}; + + if (std::find(result.begin(), result.end(), ty) == result.end()) + result.push_back(ty); + } } - else - s.insert(t); } - - // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. - for (TypeId t : s) - { - t = follow(t); - if (get(t) || get(t)) - return {t}; - } - - std::vector r(s.begin(), s.end()); - std::sort(r.begin(), r.end()); - return r; + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); } + + return result; } std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) { - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (!std::any_of(begin(utv), end(utv), isNil)) - return ty; - } - else - { - bool hasNil = false; - - for (TypeId option : utv) - { - if (isNil(option)) - { - hasNil = true; - break; - } - } - - if (!hasNil) - return ty; - } + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; std::vector result; @@ -2110,32 +2022,18 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { - if (FFlag::LuauAnyInIsOptionalIsOptional) + ty = follow(ty); + + if (auto utv = get(ty)) { - ty = follow(ty); - - if (auto utv = get(ty)) - { - if (!std::any_of(begin(utv), end(utv), isNil)) - return ty; - } - - if (std::optional strippedUnion = tryStripUnionFromNil(ty)) - { - reportError(location, OptionalValueAccess{ty}); - return follow(*strippedUnion); - } + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; } - else + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) { - if (isOptional(ty)) - { - if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) - { - reportError(location, OptionalValueAccess{ty}); - return follow(*strippedUnion); - } - } + reportError(location, OptionalValueAccess{ty}); + return follow(*strippedUnion); } return ty; @@ -2194,8 +2092,7 @@ TypeId TypeChecker::checkExprTable( if (indexer) { - if (FFlag::LuauCheckImplicitNumbericKeys) - unify(numberType, indexer->indexType, value->location); + unify(numberType, indexer->indexType, value->location); unify(valueType, indexer->indexResultType, value->location); } else @@ -2219,7 +2116,8 @@ TypeId TypeChecker::checkExprTable( if (errors.empty()) exprType = expectedProp.type; } - else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) + else if (expectedTable->indexer && (FFlag::LuauExpectedPropTypeFromIndexer ? maybeString(expectedTable->indexer->indexType) + : isString(expectedTable->indexer->indexType))) { ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) @@ -2259,26 +2157,13 @@ TypeId TypeChecker::checkExprTable( ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { - if (FFlag::LuauTableUseCounterInstead) + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { - RecursionCounter _rc(&checkRecursionCount); - if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) - { - reportErrorCodeTooComplex(expr.location); - return {errorRecoveryType(scope)}; - } - - return checkExpr_(scope, expr, expectedType); + reportErrorCodeTooComplex(expr.location); + return {errorRecoveryType(scope)}; } - else - { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "checkExpr for tables"); - return checkExpr_(scope, expr, expectedType); - } -} -ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) -{ std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; @@ -2324,6 +2209,8 @@ ExprResult TypeChecker::checkExpr_(const ScopePtr& scope, const AstExprT { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; + else if (FFlag::LuauExpectedPropTypeFromIndexer && expectedIndexType && maybeString(*expectedIndexType)) + expectedResultType = expectedIndexResultType; } else if (expectedUnion) { @@ -2529,7 +2416,7 @@ TypeId TypeChecker::checkRelationalOperation( if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) { ScopePtr subScope = childScope(scope, subexp->location); - reportErrors(resolve(predicates, subScope, true)); + resolve(predicates, subScope, true); return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); } } @@ -2851,8 +2738,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy), - {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; + return {checkBinaryOperation(scope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) { @@ -2864,7 +2750,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); + TypeId result = checkBinaryOperation(scope, expr, lhsTy, rhsTy, lhsPredicates); return {result, {OrPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) @@ -2872,8 +2758,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); - ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); + ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); + ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); PredicateVec predicates; @@ -2931,12 +2817,12 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { ExprResult result = checkExpr(scope, *expr.condition); + ScopePtr trueScope = childScope(scope, expr.trueExpr->location); - reportErrors(resolve(result.predicates, trueScope, true)); + resolve(result.predicates, trueScope, true); ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); - // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. resolve(result.predicates, falseScope, false); ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); @@ -3668,9 +3554,6 @@ void TypeChecker::checkArgumentList( else if (state.log.getMutable(t)) { } // ok - else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.get(t)) - { - } // ok else { size_t minParams = getMinParameterCount(&state.log, paramPack); @@ -3823,9 +3706,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = instantiate(scope, functionType, expr.func->location); } - if (!FFlag::LuauInstantiateFollows) - actualFunctionType = follow(actualFunctionType); - TypePackId retPack; if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) { @@ -4096,32 +3976,6 @@ std::optional> TypeChecker::checkCallOverload(const Scope { state.log.commit(); - if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) - { - // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND - // the function is declared with colon notation AND we use dot notation, warn. - auto [providedArgs, providedTail] = flatten(argPack); - - // If we have a variadic tail, we can't say how many arguments were actually provided - if (!providedTail) - { - std::vector actualArgs = flatten(ftv->argTypes).first; - - size_t providedCount = providedArgs.size(); - size_t requiredCount = actualArgs.size(); - - // Ignore optional arguments - while (providedCount < requiredCount && requiredCount != 0 && isOptional(actualArgs[requiredCount - 1])) - requiredCount--; - - if (providedCount < requiredCount) - { - int requiredExtraNils = int(requiredCount - providedCount); - reportError(TypeError{expr.func->location, FunctionRequiresSelf{requiredExtraNils}}); - } - } - } - currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload @@ -4525,7 +4379,7 @@ bool Instantiation::isDirty(TypeId ty) { if (const FunctionTypeVar* ftv = log->getMutable(ty)) { - if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + if (ftv->hasNoGenerics) return false; return true; @@ -4582,7 +4436,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionTypeVar* ftv = log->getMutable(ty)) { - if (FFlag::LuauTypecheckOptPass && ftv->hasNoGenerics) + if (ftv->hasNoGenerics) return true; // We aren't recursing in the case of a generic function which @@ -4701,8 +4555,17 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location ty = follow(ty); const FunctionTypeVar* ftv = get(ty); - if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) - Luau::quantify(ty, scope->level); + + if (FFlag::LuauAlwaysQuantify) + { + if (ftv) + Luau::quantify(ty, scope->level); + } + else + { + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + } if (FFlag::LuauLowerBoundsCalculation && ftv) { @@ -4717,15 +4580,11 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { - if (FFlag::LuauInstantiateFollows) - ty = follow(ty); + ty = follow(ty); - if (FFlag::LuauTypecheckOptPass) - { - const FunctionTypeVar* ftv = get(FFlag::LuauInstantiateFollows ? ty : follow(ty)); - if (ftv && ftv->hasNoGenerics) - return ty; - } + const FunctionTypeVar* ftv = get(ty); + if (ftv && ftv->hasNoGenerics) + return ty; Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; @@ -5392,10 +5251,9 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack bool ApplyTypeFunction::isDirty(TypeId ty) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics. - if (get(ty)) + if (FFlag::LuauApplyTypeFunctionFix && typeArguments.count(ty)) + return true; + else if (!FFlag::LuauApplyTypeFunctionFix && get(ty)) return true; else if (const FreeTypeVar* ftv = get(ty)) { @@ -5409,10 +5267,9 @@ bool ApplyTypeFunction::isDirty(TypeId ty) bool ApplyTypeFunction::isDirty(TypePackId tp) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics. - if (get(tp)) + if (FFlag::LuauApplyTypeFunctionFix && typePackArguments.count(tp)) + return true; + else if (!FFlag::LuauApplyTypeFunctionFix && get(tp)) return true; else return false; @@ -5436,11 +5293,13 @@ bool ApplyTypeFunction::ignoreChildren(TypePackId tp) TypeId ApplyTypeFunction::clean(TypeId ty) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics by free type variables. TypeId& arg = typeArguments[ty]; - if (arg) + if (FFlag::LuauApplyTypeFunctionFix) + { + LUAU_ASSERT(arg); + return arg; + } + else if (arg) return arg; else return addType(FreeTypeVar{level}); @@ -5448,11 +5307,13 @@ TypeId ApplyTypeFunction::clean(TypeId ty) TypePackId ApplyTypeFunction::clean(TypePackId tp) { - // Really this should just replace the arguments, - // but for bug-compatibility with existing code, we replace - // all generics by free type variables. TypePackId& arg = typePackArguments[tp]; - if (arg) + if (FFlag::LuauApplyTypeFunctionFix) + { + LUAU_ASSERT(arg); + return arg; + } + else if (arg) return arg; else return addTypePack(FreeTypePack{level}); @@ -5596,8 +5457,6 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) { - LUAU_ASSERT(FFlag::LuauDiscriminableUnions2 || FFlag::LuauAssertStripsFalsyTypes); - const LValue* target = &lvalue; std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. @@ -5683,66 +5542,6 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV // We need to search in the provided Scope. Find t.x.y first. // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. - if (!FFlag::LuauTypecheckOptPass) - { - const auto& [symbol, keys] = getFullName(lvalue); - - ScopePtr currentScope = scope; - while (currentScope) - { - std::optional found; - - std::vector childKeys; - const LValue* currentLValue = &lvalue; - while (currentLValue) - { - if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) - { - found = it->second; - break; - } - - childKeys.push_back(*currentLValue); - currentLValue = baseof(*currentLValue); - } - - if (!found) - { - // Should not be using scope->lookup. This is already recursive. - if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) - found = it->second.typeId; - else - { - // Nothing exists in this Scope. Just skip and try the parent one. - currentScope = currentScope->parent; - continue; - } - } - - for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) - { - const LValue& key = *it; - - // Symbol can happen. Skip. - if (get(key)) - continue; - else if (auto field = get(key)) - { - found = getIndexTypeFromType(scope, *found, field->key, Location(), false); - if (!found) - return std::nullopt; // Turns out this type doesn't have the property at all. We're done. - } - else - LUAU_ASSERT(!"New LValue alternative not handled here."); - } - - return found; - } - - // No entry for it at all. Can happen when LValue root is a global. - return std::nullopt; - } - const Symbol symbol = getBaseSymbol(lvalue); ScopePtr currentScope = scope; @@ -5820,85 +5619,47 @@ static bool isUndecidable(TypeId ty) return get(ty) || get(ty) || get(ty); } -ErrorVec TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) { - ErrorVec errVec; - resolve(predicates, errVec, scope->refinements, scope, sense); - return errVec; + resolve(predicates, scope->refinements, scope, sense); } -void TypeChecker::resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const PredicateVec& predicates, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { for (const Predicate& c : predicates) - resolve(c, errVec, refis, scope, sense, fromOr); + resolve(c, refis, scope, sense, fromOr); } -void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const Predicate& predicate, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { if (auto truthyP = get(predicate)) - resolve(*truthyP, errVec, refis, scope, sense, fromOr); + resolve(*truthyP, refis, scope, sense, fromOr); else if (auto andP = get(predicate)) - resolve(*andP, errVec, refis, scope, sense); + resolve(*andP, refis, scope, sense); else if (auto orP = get(predicate)) - resolve(*orP, errVec, refis, scope, sense); + resolve(*orP, refis, scope, sense); else if (auto notP = get(predicate)) - resolve(notP->predicates, errVec, refis, scope, !sense, fromOr); + resolve(notP->predicates, refis, scope, !sense, fromOr); else if (auto isaP = get(predicate)) - resolve(*isaP, errVec, refis, scope, sense); + resolve(*isaP, refis, scope, sense); else if (auto typeguardP = get(predicate)) - resolve(*typeguardP, errVec, refis, scope, sense); + resolve(*typeguardP, refis, scope, sense); else if (auto eqP = get(predicate)) - resolve(*eqP, errVec, refis, scope, sense); + resolve(*eqP, refis, scope, sense); else ice("Unhandled predicate kind"); } -void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +void TypeChecker::resolve(const TruthyPredicate& truthyP, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { - if (FFlag::LuauAssertStripsFalsyTypes) - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (ty && fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (ty && fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); - refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); - } - else - { - auto predicate = [sense](TypeId option) -> std::optional { - if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) - return option; - - return std::nullopt; - }; - - if (FFlag::LuauDiscriminableUnions2) - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (ty && fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); - - refineLValue(truthyP.lvalue, refis, scope, predicate); - } - else - { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; - - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); - } - } + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); } -void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense) { if (!sense) { @@ -5907,14 +5668,14 @@ void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, Refinement {NotPredicate{std::move(andP.rhs)}}, }; - return resolve(orP, errVec, refis, scope, !sense); + return resolve(orP, refis, scope, !sense); } - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); + resolve(andP.lhs, refis, scope, sense); + resolve(andP.rhs, refis, scope, sense); } -void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const OrPredicate& orP, RefinementMap& refis, const ScopePtr& scope, bool sense) { if (!sense) { @@ -5923,28 +5684,24 @@ void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMa {NotPredicate{std::move(orP.rhs)}}, }; - return resolve(andP, errVec, refis, scope, !sense); + return resolve(andP, refis, scope, !sense); } - ErrorVec discarded; - RefinementMap leftRefis; - resolve(orP.lhs, errVec, leftRefis, scope, sense); + resolve(orP.lhs, leftRefis, scope, sense); RefinementMap rightRefis; - resolve(orP.lhs, discarded, rightRefis, scope, !sense); - resolve(orP.rhs, errVec, rightRefis, scope, sense, true); // :( + resolve(orP.lhs, rightRefis, scope, !sense); + resolve(orP.rhs, rightRefis, scope, sense, true); // :( merge(refis, leftRefis); merge(refis, rightRefis); } -void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const ScopePtr& scope, bool sense) { auto predicate = [&](TypeId option) -> std::optional { // This by itself is not truly enough to determine that A is stronger than B or vice versa. - // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. - // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) bool optionIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); bool targetIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); @@ -5985,32 +5742,15 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return res; }; - if (FFlag::LuauDiscriminableUnions2) - { - refineLValue(isaP.lvalue, refis, scope, predicate); - } - else - { - std::optional ty = resolveLValue(refis, scope, isaP.lvalue); - if (!ty) - return; - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, isaP.lvalue, *result); - else - { - addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); - errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); - } - } + refineLValue(isaP.lvalue, refis, scope, predicate); } -void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // Rewrite the predicate 'type(foo) == "vector"' to be 'typeof(foo) == "Vector3"'. They're exactly identical. // This allows us to avoid writing in edge cases. if (!typeguardP.isTypeof && typeguardP.kind == "vector") - return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, errVec, refis, scope, sense); + return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, refis, scope, sense); std::optional ty = resolveLValue(refis, scope, typeguardP.lvalue); if (!ty) @@ -6060,52 +5800,29 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) { - if (FFlag::LuauDiscriminableUnions2) - { - refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); - return; - } - else - { - if (std::optional result = filterMap(*ty, it->second(sense))) - addRefinement(refis, typeguardP.lvalue, *result); - else - { - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - if (sense) - errVec.push_back( - TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); - } - - return; - } + refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); + return; } - auto fail = [&](const TypeErrorData& err) { - if (!FFlag::LuauDiscriminableUnions2) - errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - }; - if (!typeguardP.isTypeof) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); auto typeFun = globalScope->lookupType(typeguardP.kind); if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); TypeId type = follow(typeFun->type); // We're only interested in the root class of any classes. if (auto ctv = get(type); !ctv || ctv->parent) - return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. // Until then, we rewrite this to be the same as using IsA. - return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); + return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, refis, scope, sense); } -void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. auto options = [](TypeId ty) -> std::vector { @@ -6114,82 +5831,33 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return {ty}; }; - if (FFlag::LuauDiscriminableUnions2) - { - std::vector rhs = options(eqP.type); + std::vector rhs = options(eqP.type); - if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - auto predicate = [&](TypeId option) -> std::optional { - if (sense && isUndecidable(option)) - return FFlag::LuauWeakEqConstraint ? option : eqP.type; + auto predicate = [&](TypeId option) -> std::optional { + if (sense && isUndecidable(option)) + return FFlag::LuauWeakEqConstraint ? option : eqP.type; - if (!sense && isNil(eqP.type)) - return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; + if (!sense && isNil(eqP.type)) + return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; - if (maybeSingleton(eqP.type)) - { - // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. - if (!sense || canUnify(eqP.type, option, eqP.location).empty()) - return sense ? eqP.type : option; - - // local variable works around an odd gcc 9.3 warning: may be used uninitialized - std::optional res = std::nullopt; - return res; - } - - return option; - }; - - refineLValue(eqP.lvalue, refis, scope, predicate); - } - else - { - if (FFlag::LuauWeakEqConstraint) + if (maybeSingleton(eqP.type)) { - if (!sense && isNil(eqP.type)) - resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); + // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. + if (!sense || canUnify(eqP.type, option, eqP.location).empty()) + return sense ? eqP.type : option; - return; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; } - if (FFlag::LuauEqConstraint) - { - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; + return option; + }; - std::vector lhs = options(*ty); - std::vector rhs = options(eqP.type); - - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) - { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) - { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); - } - } - - if (set.empty()) - return; - - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); - } - } + refineLValue(eqP.lvalue, refis, scope, predicate); } bool TypeChecker::isNonstrictMode() const diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index c243589..ba09df5 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -5,8 +5,6 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAGVARIABLE(LuauTerminateCyclicMetatableIndexLookup, false) - namespace Luau { @@ -55,13 +53,10 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t { TypeId index = follow(*mtIndex); - if (FFlag::LuauTerminateCyclicMetatableIndexLookup) - { - if (count >= 100) - return std::nullopt; + if (count >= 100) + return std::nullopt; - ++count; - } + ++count; if (const auto& itt = getTableType(index)) { diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 463b465..2355dab 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,8 +24,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) -LUAU_FASTFLAG(LuauDiscriminableUnions2) -LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) LUAU_FASTFLAGVARIABLE(LuauClassDefinitionModuleInError, false) namespace Luau @@ -204,14 +202,14 @@ bool isOptional(TypeId ty) ty = follow(ty); - if (FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) + if (get(ty)) return true; auto utv = get(ty); if (!utv) return false; - return std::any_of(begin(utv), end(utv), FFlag::LuauAnyInIsOptionalIsOptional ? isOptional : isNil); + return std::any_of(begin(utv), end(utv), isOptional); } bool isTableIntersection(TypeId ty) @@ -378,8 +376,7 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) if (seen.contains(ty)) return true; - bool isStr = FFlag::LuauDiscriminableUnions2 ? isString(ty) : isPrim(ty, PrimitiveTypeVar::String); - if (isStr || get(ty) || get(ty) || get(ty)) + if (isString(ty) || get(ty) || get(ty) || get(ty)) return true; if (auto uty = get(ty)) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index f5c1dde..9308e9f 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -24,8 +24,6 @@ LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) -LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) -LUAU_FASTFLAG(LuauTypecheckOptPass) namespace Luau { @@ -382,19 +380,6 @@ Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) - : types(types) - , mode(mode) - , log(parentLog, sharedSeen) - , location(location) - , variance(variance) - , sharedState(sharedState) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOptPass); - LUAU_ASSERT(sharedState.iceHandler); -} - void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { sharedState.counters.iterationCount = 0; @@ -1219,14 +1204,6 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal continue; } - // In nonstrict mode, any also marks an optional argument. - else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && - log.getMutable(log.follow(*superIter))) - { - superIter.advance(); - continue; - } - if (log.getMutable(superIter.packId)) { tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); @@ -1454,21 +1431,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (subIter == subTable->props.end() && - (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) - missingProperties.push_back(propName); - } - else - { - bool isAny = log.getMutable(log.follow(superProp.type)); - - if (subIter == subTable->props.end() && - (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && - !isAny) - missingProperties.push_back(propName); - } + if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && + !isOptional(superProp.type)) + missingProperties.push_back(propName); } if (!missingProperties.empty()) @@ -1485,18 +1450,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto superIter = superTable->props.find(propName); - if (FFlag::LuauAnyInIsOptionalIsOptional) - { - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) - extraProperties.push_back(propName); - } - else - { - bool isAny = log.is(log.follow(subProp.type)); - if (superIter == superTable->props.end() && - (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) - extraProperties.push_back(propName); - } + if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) + extraProperties.push_back(propName); } if (!extraProperties.empty()) @@ -1540,21 +1495,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } - else if (FFlag::LuauAnyInIsOptionalIsOptional && - (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) { } - else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && - (isOptional(prop.type) || get(follow(prop.type)))) - // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` - // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. - // TODO: should isOptional(anyType) be true? - // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) - { - } else if (subTable->state == TableState::Free) { PendingType* pendingSub = log.queue(subTy); @@ -1618,10 +1564,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } - else if (FFlag::LuauAnyInIsOptionalIsOptional && !FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) - { - } - else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && (isOptional(prop.type) || get(follow(prop.type)))) + else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) { } else if (superTable->state == TableState::Free) @@ -1753,9 +1696,7 @@ TypePackId Unifier::widen(TypePackId tp) TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); - if (!FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) - return ty; - else if (isOptional(ty)) + if (isOptional(ty)) return ty; else if (const TableTypeVar* ttv = get(ty)) { @@ -2666,14 +2607,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - if (FFlag::LuauTypecheckOptPass) - { - Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; - u.anyIsTop = anyIsTop; - return u; - } - - Unifier u = Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; + Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; u.anyIsTop = anyIsTop; return u; } diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index b00440a..1246537 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -224,6 +224,7 @@ private: DenseHashMap constantMap; DenseHashMap tableShapeMap; + DenseHashMap protoMap; int debugLine = 0; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index fb70392..beeda29 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAG(LuauCompileNestedClosureO2) + namespace Luau { @@ -181,6 +183,7 @@ size_t BytecodeBuilder::TableShapeHash::operator()(const TableShape& v) const BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder) : constantMap({Constant::Type_Nil, ~0ull}) , tableShapeMap(TableShape()) + , protoMap(~0u) , stringTable({nullptr, 0}) , encoder(encoder) { @@ -250,6 +253,7 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) constantMap.clear(); tableShapeMap.clear(); + protoMap.clear(); debugRemarks.clear(); debugRemarkBuffer.clear(); @@ -372,11 +376,17 @@ int32_t BytecodeBuilder::addConstantClosure(uint32_t fid) int16_t BytecodeBuilder::addChildFunction(uint32_t fid) { + if (FFlag::LuauCompileNestedClosureO2) + if (int16_t* cache = protoMap.find(fid)) + return *cache; + uint32_t id = uint32_t(protos.size()); if (id >= kMaxClosureCount) return -1; + if (FFlag::LuauCompileNestedClosureO2) + protoMap[fid] = int16_t(id); protos.push_back(fid); return int16_t(id); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index e177e92..4f26ceb 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -17,8 +17,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileSupportInlining, false) - LUAU_FASTFLAGVARIABLE(LuauCompileIter, false) LUAU_FASTFLAGVARIABLE(LuauCompileIterNoReserve, false) LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) @@ -30,6 +28,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) +LUAU_FASTFLAGVARIABLE(LuauCompileNestedClosureO2, false) + namespace Luau { @@ -100,13 +100,11 @@ struct Compiler upvals.reserve(16); } - uint8_t getLocal(AstLocal* local) + int getLocalReg(AstLocal* local) { Local* l = locals.find(local); - LUAU_ASSERT(l); - LUAU_ASSERT(l->allocated); - return l->reg; + return l && l->allocated ? l->reg : -1; } uint8_t getUpval(AstLocal* local) @@ -159,17 +157,19 @@ struct Compiler AstExprFunction* getFunctionExpr(AstExpr* node) { - if (AstExprLocal* le = node->as()) + if (AstExprLocal* expr = node->as()) { - Variable* lv = variables.find(le->local); + Variable* lv = variables.find(expr->local); if (!lv || lv->written || !lv->init) return nullptr; return getFunctionExpr(lv->init); } - else if (AstExprGroup* ge = node->as()) - return getFunctionExpr(ge->expr); + else if (AstExprGroup* expr = node->as()) + return getFunctionExpr(expr->expr); + else if (AstExprTypeAssertion* expr = node->as()) + return getFunctionExpr(expr->expr); else return node->as(); } @@ -180,13 +180,13 @@ struct Compiler { bool result = true; - bool visit(AstExpr* node) override + bool visit(AstExprFunction* node) override { - // nested functions may capture function arguments, and our upval handling doesn't handle elided variables (constant) - // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues - // TODO: additionally we would need to change upvalue handling in compileExprFunction to handle upvalue->local migration - result = result && !node->is(); - return result; + if (!FFlag::LuauCompileNestedClosureO2) + result = false; + + // short-circuit to avoid analyzing nested closure bodies + return false; } bool visit(AstStat* node) override @@ -275,8 +275,7 @@ struct Compiler f.upvals = upvals; // record information for inlining - if (FFlag::LuauCompileSupportInlining && options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && - !getfenvUsed && !setfenvUsed) + if (options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && !getfenvUsed && !setfenvUsed) { f.canInline = true; f.stackSize = stackSize; @@ -346,8 +345,8 @@ struct Compiler uint8_t argreg; - if (isExprLocalReg(arg)) - argreg = getLocal(arg->as()->local); + if (int reg = getExprLocalReg(arg); reg >= 0) + argreg = uint8_t(reg); else { argreg = uint8_t(regs + 1); @@ -403,8 +402,8 @@ struct Compiler } } - if (isExprLocalReg(expr->args.data[i])) - args[i] = getLocal(expr->args.data[i]->as()->local); + if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0) + args[i] = uint8_t(reg); else { args[i] = uint8_t(regs + 1 + i); @@ -489,19 +488,18 @@ struct Compiler return false; } - // TODO: we can compile functions with mismatching arity at call site but it's more annoying - if (func->args.size != expr->args.size) - { - bytecode.addDebugRemark("inlining failed: argument count mismatch (expected %d, got %d)", int(func->args.size), int(expr->args.size)); - return false; - } - - // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining + // compute constant bitvector for all arguments to feed the cost model bool varc[8] = {}; - for (size_t i = 0; i < expr->args.size && i < 8; ++i) + for (size_t i = 0; i < func->args.size && i < expr->args.size && i < 8; ++i) varc[i] = isConstant(expr->args.data[i]); - int inlinedCost = computeCost(fi->costModel, varc, std::min(int(expr->args.size), 8)); + // if the last argument only returns a single value, all following arguments are nil + if (expr->args.size != 0 && !(expr->args.data[expr->args.size - 1]->is() || expr->args.data[expr->args.size - 1]->is())) + for (size_t i = expr->args.size; i < func->args.size && i < 8; ++i) + varc[i] = true; + + // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining + int inlinedCost = computeCost(fi->costModel, varc, std::min(int(func->args.size), 8)); int baselineCost = computeCost(fi->costModel, nullptr, 0) + 3; int inlineProfit = (inlinedCost == 0) ? thresholdMaxBoost : std::min(thresholdMaxBoost, 100 * baselineCost / inlinedCost); @@ -533,15 +531,44 @@ struct Compiler for (size_t i = 0; i < func->args.size; ++i) { AstLocal* var = func->args.data[i]; - AstExpr* arg = expr->args.data[i]; + AstExpr* arg = i < expr->args.size ? expr->args.data[i] : nullptr; - if (Variable* vv = variables.find(var); vv && vv->written) + if (i + 1 == expr->args.size && func->args.size > expr->args.size && (arg->is() || arg->is())) + { + // if the last argument can return multiple values, we need to compute all of them into the remaining arguments + unsigned int tail = unsigned(func->args.size - expr->args.size) + 1; + uint8_t reg = allocReg(arg, tail); + + if (AstExprCall* expr = arg->as()) + compileExprCall(expr, reg, tail, /* targetTop= */ true); + else if (AstExprVarargs* expr = arg->as()) + compileExprVarargs(expr, reg, tail); + else + LUAU_ASSERT(!"Unexpected expression type"); + + for (size_t j = i; j < func->args.size; ++j) + pushLocal(func->args.data[j], uint8_t(reg + (j - i))); + + // all remaining function arguments have been allocated and assigned to + break; + } + else if (Variable* vv = variables.find(var); vv && vv->written) { // if the argument is mutated, we need to allocate a fresh register even if it's a constant uint8_t reg = allocReg(arg, 1); - compileExprTemp(arg, reg); + + if (arg) + compileExprTemp(arg, reg); + else + bytecode.emitABC(LOP_LOADNIL, reg, 0, 0); + pushLocal(var, reg); } + else if (arg == nullptr) + { + // since the argument is not mutated, we can simply fold the value into the expressions that need it + locstants[var] = {Constant::Type_Nil}; + } else if (const Constant* cv = constants.find(arg); cv && cv->type != Constant::Type_Unknown) { // since the argument is not mutated, we can simply fold the value into the expressions that need it @@ -553,20 +580,26 @@ struct Compiler Variable* lv = le ? variables.find(le->local) : nullptr; // if the argument is a local that isn't mutated, we will simply reuse the existing register - if (isExprLocalReg(arg) && (!lv || !lv->written)) + if (int reg = le ? getExprLocalReg(le) : -1; reg >= 0 && (!lv || !lv->written)) { - uint8_t reg = getLocal(le->local); - pushLocal(var, reg); + pushLocal(var, uint8_t(reg)); } else { - uint8_t reg = allocReg(arg, 1); - compileExprTemp(arg, reg); - pushLocal(var, reg); + uint8_t temp = allocReg(arg, 1); + compileExprTemp(arg, temp); + pushLocal(var, temp); } } } + // evaluate extra expressions for side effects + for (size_t i = func->args.size; i < expr->args.size; ++i) + { + RegScope rsi(this); + compileExprAuto(expr->args.data[i], rsi); + } + // fold constant values updated above into expressions in the function body foldConstants(constants, variables, locstants, func->body); @@ -627,12 +660,15 @@ struct Compiler FInt::LuauCompileInlineThresholdMaxBoost, FInt::LuauCompileInlineDepth)) return; - if (fi && !fi->canInline) + // add a debug remark for cases when we didn't even call tryCompileInlinedCall + if (func && !(fi && fi->canInline)) { if (func->vararg) bytecode.addDebugRemark("inlining failed: function is variadic"); - else + else if (fi) bytecode.addDebugRemark("inlining failed: complex constructs in function body"); + else + bytecode.addDebugRemark("inlining failed: can't inline recursive calls"); } } @@ -677,9 +713,9 @@ struct Compiler LUAU_ASSERT(fi); // Optimization: use local register directly in NAMECALL if possible - if (isExprLocalReg(fi->expr)) + if (int reg = getExprLocalReg(fi->expr); reg >= 0) { - selfreg = getLocal(fi->expr->as()->local); + selfreg = uint8_t(reg); } else { @@ -785,6 +821,8 @@ struct Compiler void compileExprFunction(AstExprFunction* expr, uint8_t target) { + RegScope rs(this); + const Function* f = functions.find(expr); LUAU_ASSERT(f); @@ -795,6 +833,67 @@ struct Compiler if (pid < 0) CompileError::raise(expr->location, "Exceeded closure limit; simplify the code to compile"); + if (FFlag::LuauCompileNestedClosureO2) + { + captures.clear(); + captures.reserve(f->upvals.size()); + + for (AstLocal* uv : f->upvals) + { + LUAU_ASSERT(uv->functionDepth < expr->functionDepth); + + if (int reg = getLocalReg(uv); reg >= 0) + { + // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; + + captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)}); + } + else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown) + { + // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register + uint8_t reg = allocReg(expr, 1); + compileExprConstant(expr, uc, reg); + + captures.push_back({LCT_VAL, reg}); + } + else + { + LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1); + + // get upvalue from parent frame + // note: this will add uv to the current upvalue list if necessary + uint8_t uid = getUpval(uv); + + captures.push_back({LCT_UPVAL, uid}); + } + } + + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + int16_t shared = -1; + + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) + { + int32_t cid = bytecode.addConstantClosure(f->id); + + if (cid >= 0 && cid < 32768) + shared = int16_t(cid); + } + + if (shared >= 0) + bytecode.emitAD(LOP_DUPCLOSURE, target, shared); + else + bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + + for (const Capture& c : captures) + bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0); + + return; + } + bool shared = false; // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure @@ -824,9 +923,10 @@ struct Compiler if (uv->functionDepth == expr->functionDepth - 1) { // get local variable - uint8_t reg = getLocal(uv); + int reg = getLocalReg(uv); + LUAU_ASSERT(reg >= 0); - bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), reg, 0); + bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), uint8_t(reg), 0); } else { @@ -1213,10 +1313,10 @@ struct Compiler if (!isConditionFast(expr->left)) { // Optimization: when right hand side is a local variable, we can use AND/OR - if (isExprLocalReg(expr->right)) + if (int reg = getExprLocalReg(expr->right); reg >= 0) { uint8_t lr = compileExprAuto(expr->left, rs); - uint8_t rr = getLocal(expr->right->as()->local); + uint8_t rr = uint8_t(reg); bytecode.emitABC(and_ ? LOP_AND : LOP_OR, target, lr, rr); return; @@ -1803,19 +1903,18 @@ struct Compiler } else if (AstExprLocal* expr = node->as()) { - if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue) + // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining + if (int reg = getExprLocalReg(expr); reg >= 0) + { + bytecode.emitABC(LOP_MOVE, target, uint8_t(reg), 0); + } + else { LUAU_ASSERT(expr->upvalue); uint8_t uid = getUpval(expr->local); bytecode.emitABC(LOP_GETUPVAL, target, uid, 0); } - else - { - uint8_t reg = getLocal(expr->local); - - bytecode.emitABC(LOP_MOVE, target, reg, 0); - } } else if (AstExprGlobal* expr = node->as()) { @@ -1879,8 +1978,8 @@ struct Compiler uint8_t compileExprAuto(AstExpr* node, RegScope&) { // Optimization: directly return locals instead of copying them to a temporary - if (isExprLocalReg(node)) - return getLocal(node->as()->local); + if (int reg = getExprLocalReg(node); reg >= 0) + return uint8_t(reg); // note: the register is owned by the parent scope uint8_t reg = allocReg(node, 1); @@ -1910,7 +2009,7 @@ struct Compiler for (size_t i = 0; i < targetCount; ++i) compileExprTemp(list.data[i], uint8_t(target + i)); - // compute expressions with values that go nowhere; this is required to run side-effecting code if any + // evaluate extra expressions for side effects for (size_t i = targetCount; i < list.size; ++i) { RegScope rsi(this); @@ -2008,20 +2107,21 @@ struct Compiler if (AstExprLocal* expr = node->as()) { - if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue) + // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining + if (int reg = getExprLocalReg(expr); reg >= 0) { - LUAU_ASSERT(expr->upvalue); - - LValue result = {LValue::Kind_Upvalue}; - result.upval = getUpval(expr->local); + LValue result = {LValue::Kind_Local}; + result.reg = uint8_t(reg); result.location = node->location; return result; } else { - LValue result = {LValue::Kind_Local}; - result.reg = getLocal(expr->local); + LUAU_ASSERT(expr->upvalue); + + LValue result = {LValue::Kind_Upvalue}; + result.upval = getUpval(expr->local); result.location = node->location; return result; @@ -2115,15 +2215,21 @@ struct Compiler compileLValueUse(lv, source, /* set= */ true); } - bool isExprLocalReg(AstExpr* expr) + int getExprLocalReg(AstExpr* node) { - AstExprLocal* le = expr->as(); - if (!le || (!FFlag::LuauCompileSupportInlining && le->upvalue)) - return false; + if (AstExprLocal* expr = node->as()) + { + // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining + Local* l = locals.find(expr->local); - Local* l = locals.find(le->local); - - return l && l->allocated; + return l && l->allocated ? l->reg : -1; + } + else if (AstExprGroup* expr = node->as()) + return getExprLocalReg(expr->expr); + else if (AstExprTypeAssertion* expr = node->as()) + return getExprLocalReg(expr->expr); + else + return -1; } bool isStatBreak(AstStat* node) @@ -2352,20 +2458,17 @@ struct Compiler // Optimization: return locals directly instead of copying them into a temporary // this is very important for a single return value and occasionally effective for multiple values - if (stat->list.size > 0 && isExprLocalReg(stat->list.data[0])) + if (int reg = stat->list.size > 0 ? getExprLocalReg(stat->list.data[0]) : -1; reg >= 0) { - temp = getLocal(stat->list.data[0]->as()->local); + temp = uint8_t(reg); consecutive = true; for (size_t i = 1; i < stat->list.size; ++i) - { - AstExpr* v = stat->list.data[i]; - if (!isExprLocalReg(v) || getLocal(v->as()->local) != temp + i) + if (getExprLocalReg(stat->list.data[i]) != int(temp + i)) { consecutive = false; break; } - } } if (!consecutive && stat->list.size > 0) @@ -2438,12 +2541,13 @@ struct Compiler { bool result = true; - bool visit(AstExpr* node) override + bool visit(AstExprFunction* node) override { - // functions may capture loop variable, and our upval handling doesn't handle elided variables (constant) - // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues - result = result && !node->is(); - return result; + if (!FFlag::LuauCompileNestedClosureO2) + result = false; + + // short-circuit to avoid analyzing nested closure bodies + return false; } bool visit(AstStat* node) override @@ -2874,12 +2978,9 @@ struct Compiler void compileStatFunction(AstStatFunction* stat) { // Optimization: compile value expresion directly into target local register - if (isExprLocalReg(stat->name)) + if (int reg = getExprLocalReg(stat->name); reg >= 0) { - AstExprLocal* le = stat->name->as(); - LUAU_ASSERT(le); - - compileExpr(stat->func, getLocal(le->local)); + compileExpr(stat->func, uint8_t(reg)); return; } @@ -3399,6 +3500,12 @@ struct Compiler std::vector returnJumps; }; + struct Capture + { + LuauCaptureType type; + uint8_t data; + }; + BytecodeBuilder& bytecode; CompileOptions options; @@ -3422,6 +3529,7 @@ struct Compiler std::vector loopJumps; std::vector loops; std::vector inlineFrames; + std::vector captures; }; void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) @@ -3465,6 +3573,9 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName /* self= */ nullptr, AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main); + const Compiler::Function* mainf = compiler.functions.find(&main); + LUAU_ASSERT(mainf && mainf->upvals.empty()); + bytecode.setMainFunction(mainid); bytecode.finalize(); } diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index e4d59ea..a62beeb 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -3,8 +3,6 @@ #include -LUAU_FASTFLAG(LuauCompileSupportInlining) - namespace Luau { namespace Compile @@ -330,7 +328,7 @@ struct ConstantVisitor : AstVisitor { if (value.type != Constant::Type_Unknown) map[key] = value; - else if (!FFlag::LuauCompileSupportInlining || wasEmpty) + else if (wasEmpty) ; else if (Constant* old = map.find(key)) old->type = Constant::Type_Unknown; diff --git a/Sources.cmake b/Sources.cmake index d2430cc..297f561 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -73,6 +73,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/ToString.h Analysis/include/Luau/Transpiler.h Analysis/include/Luau/TxnLog.h + Analysis/include/Luau/TypeArena.h Analysis/include/Luau/TypeAttach.h Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeInfer.h @@ -108,6 +109,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/ToString.cpp Analysis/src/Transpiler.cpp Analysis/src/TxnLog.cpp + Analysis/src/TypeArena.cpp Analysis/src/TypeAttach.cpp Analysis/src/TypedAllocator.cpp Analysis/src/TypeInfer.cpp diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 9c1f387..27187c6 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,10 +10,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry2, false) - -void (*lua_table_move_telemetry)(lua_State* L, int f, int e, int t, int nf, int nt); - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -199,29 +195,6 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - void (*telemetrycb)(lua_State * L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry; - - if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f) - { - int nf = lua_objlen(L, 1); - int nt = lua_objlen(L, tt); - - bool report = false; - - // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) - if (!(f == 1 || (f >= 1 && f <= nf))) - report = true; - if (!(e == nf || (e >= 1 && e <= nf))) - report = true; - - // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) - if (!(t == nt + 1 || (t >= 1 && t <= nt))) - report = true; - - if (report) - telemetrycb(L, f, e, t, nf, nt); - } - if (e >= f) { /* otherwise, nothing to move */ luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 3c7c276..9e2eb26 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -17,9 +17,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauIter, false) -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauIterCallTelemetry, false) - -void (*lua_iter_call_telemetry)(lua_State* L); // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -157,17 +154,6 @@ LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) StkId ra = &L->base[a]; LUAU_ASSERT(ra + 3 <= L->top); - if (DFFlag::LuauIterCallTelemetry) - { - /* TODO: we might be able to stop supporting this depending on whether it's used in practice */ - void (*telemetrycb)(lua_State* L) = lua_iter_call_telemetry; - - if (telemetrycb && ttistable(ra) && fasttm(L, hvalue(ra)->metatable, TM_CALL)) - telemetrycb(L); - if (telemetrycb && ttisuserdata(ra) && fasttm(L, uvalue(ra)->metatable, TM_CALL)) - telemetrycb(L); - } - setobjs2s(L, ra + 3 + 2, ra + 2); setobjs2s(L, ra + 3 + 1, ra + 1); setobjs2s(L, ra + 3, ra); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 12c6845..f001750 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -92,4 +92,17 @@ bar(foo()) CHECK_EQ("number", toString(*expectedOty)); } +TEST_CASE_FIXTURE(Fixture, "ast_ancestry_at_eof") +{ + check(R"( +if true then + )"); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(2, 4)); + REQUIRE_GE(ancestry.size(), 2); + AstStat* parentStat = ancestry[ancestry.size() - 2]->asStat(); + REQUIRE(bool(parentStat)); + REQUIRE(parentStat->is()); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index b4e9340..caaccf4 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2772,6 +2772,8 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { + ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; + check(R"( type tag = "cat" | "dog" local function f(a: tag) end @@ -2844,6 +2846,8 @@ f(@1) TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") { + ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; + check(R"( type tag = "strange\t\"cat\"" | 'nice\t"dog"' local function f(x: tag) end diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index b032060..cf27d19 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4269,22 +4269,26 @@ FORNLOOP R3 -6 FORNLOOP R0 -11 RETURN R0 0 )"); +} - // can't unroll loops if the body has functions that refer to loop variables +TEST_CASE("LoopUnrollNestedClosure") +{ + ScopedFastFlag sff("LuauCompileNestedClosureO2", true); + + // if the body has functions that refer to loop variables, we unroll the loop and use MOVE+CAPTURE for upvalues CHECK_EQ("\n" + compileFunction(R"( -for i=1,1 do +for i=1,2 do local x = function() return i end end )", 1, 2), R"( -LOADN R2 1 -LOADN R0 1 LOADN R1 1 -FORNPREP R0 +3 -NEWCLOSURE R3 P0 -CAPTURE VAL R2 -FORNLOOP R0 -3 +NEWCLOSURE R0 P0 +CAPTURE VAL R1 +LOADN R1 2 +NEWCLOSURE R0 P0 +CAPTURE VAL R1 RETURN R0 0 )"); } @@ -4469,8 +4473,6 @@ RETURN R0 0 TEST_CASE("InlineBasic") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // inline function that returns a constant CHECK_EQ("\n" + compileFunction(R"( local function foo() @@ -4550,10 +4552,72 @@ RETURN R1 1 )"); } +TEST_CASE("InlineBasicProhibited") +{ + ScopedFastFlag sff("LuauCompileNestedClosureO2", true); + + // we can't inline variadic functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(...) + return 42 +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); + + // we also can't inline functions that have internal loops + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + for i=1,4 do end +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +CALL R1 0 1 +RETURN R1 1 +)"); +} + +TEST_CASE("InlineNestedClosures") +{ + ScopedFastFlag sff("LuauCompileNestedClosureO2", true); + + // we can inline functions that contain/return functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(x) + return function(y) return x + y end +end + +local x = foo(1)(2) +return x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 1 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +LOADN R2 2 +CALL R1 1 1 +RETURN R1 1 +)"); +} + TEST_CASE("InlineMutate") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // if the argument is mutated, it gets a register even if the value is constant CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -4636,8 +4700,6 @@ RETURN R1 1 TEST_CASE("InlineUpval") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // if the argument is an upvalue, we naturally need to copy it to a local CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -4705,8 +4767,6 @@ RETURN R1 1 TEST_CASE("InlineFallthrough") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // if the function doesn't return, we still fill the results with nil CHECK_EQ("\n" + compileFunction(R"( local function foo() @@ -4759,8 +4819,6 @@ RETURN R1 -1 TEST_CASE("InlineCapture") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // can't inline function with nested functions that capture locals because they might be constants CHECK_EQ("\n" + compileFunction(R"( local function foo(a) @@ -4782,12 +4840,9 @@ RETURN R2 -1 TEST_CASE("InlineArgMismatch") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // when inlining a function, we must respect all the usual rules // caller might not have enough arguments - // TODO: we don't inline this atm CHECK_EQ("\n" + compileFunction(R"( local function foo(a) return a @@ -4799,13 +4854,11 @@ return x 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R1 R0 -CALL R1 0 1 +LOADNIL R1 RETURN R1 1 )"); // caller might be using multret for arguments - // TODO: we don't inline this atm CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) return a + b @@ -4817,17 +4870,32 @@ return x 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R1 R0 LOADK R3 K1 FASTCALL1 20 R3 +2 GETIMPORT R2 4 -CALL R2 1 -1 -CALL R1 -1 1 +CALL R2 1 2 +ADD R1 R2 R3 +RETURN R1 1 +)"); + + // caller might be using varargs for arguments + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a, b) + return a + b +end + +local x = foo(...) +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R2 2 +ADD R1 R2 R3 RETURN R1 1 )"); // caller might have too many arguments, but we still need to compute them for side effects - // TODO: we don't inline this atm CHECK_EQ("\n" + compileFunction(R"( local function foo(a) return a @@ -4839,19 +4907,34 @@ return x 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R1 R0 +GETIMPORT R2 2 +CALL R2 0 1 +LOADN R1 42 +RETURN R1 1 +)"); + + // caller might not have enough arguments, and the arg might be mutated so it needs a register + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = 42 + return a +end + +local x = foo() +return x +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 LOADN R2 42 -GETIMPORT R3 2 -CALL R3 0 -1 -CALL R1 -1 1 +MOVE R1 R2 RETURN R1 1 )"); } TEST_CASE("InlineMultiple") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // we call this with a different set of variable/constant args CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) @@ -4880,8 +4963,6 @@ RETURN R3 4 TEST_CASE("InlineChain") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // inline a chain of functions CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) @@ -4912,8 +4993,6 @@ RETURN R3 1 TEST_CASE("InlineThresholds") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - ScopedFastInt sfis[] = { {"LuauCompileInlineThreshold", 25}, {"LuauCompileInlineThresholdMaxBoost", 300}, @@ -4988,8 +5067,6 @@ RETURN R3 1 TEST_CASE("InlineIIFE") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // IIFE with arguments CHECK_EQ("\n" + compileFunction(R"( function choose(a, b, c) @@ -5025,8 +5102,6 @@ RETURN R3 1 TEST_CASE("InlineRecurseArguments") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - // we can't inline a function if it's used to compute its own arguments CHECK_EQ("\n" + compileFunction(R"( local function foo(a, b) @@ -5036,22 +5111,20 @@ foo(foo(foo,foo(foo,foo))[foo]) 1, 2), R"( DUPCLOSURE R0 K0 -MOVE R1 R0 +MOVE R2 R0 +MOVE R3 R0 MOVE R4 R0 MOVE R5 R0 MOVE R6 R0 -CALL R4 2 1 -LOADNIL R3 -GETTABLE R2 R3 R0 -CALL R1 1 0 +CALL R4 2 -1 +CALL R2 -1 1 +GETTABLE R1 R2 R0 RETURN R0 0 )"); } TEST_CASE("InlineFastCallK") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - CHECK_EQ("\n" + compileFunction(R"( local function set(l0) rawset({}, l0) @@ -5080,8 +5153,6 @@ RETURN R0 0 TEST_CASE("InlineExprIndexK") { - ScopedFastFlag sff("LuauCompileSupportInlining", true); - CHECK_EQ("\n" + compileFunction(R"( local _ = function(l0) local _ = nil @@ -5141,6 +5212,58 @@ RETURN R0 0 )"); } +TEST_CASE("InlineHiddenMutation") +{ + // when the argument is assigned inside the function, we can't reuse the local + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = 42 + return a +end + +local x = ... +local y = foo(x :: number) +return y +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +LOADN R3 42 +MOVE R2 R3 +RETURN R2 1 +)"); + + // and neither can we do that when it's assigned outside the function + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + mutator() + return a +end + +local x = ... +mutator = function() x = 42 end + +local y = foo(x :: number) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE REF R1 +SETGLOBAL R2 K1 +MOVE R3 R1 +GETGLOBAL R4 K1 +CALL R4 0 0 +MOVE R2 R3 +CLOSEUPVALS R1 +RETURN R2 1 +)"); +} + TEST_CASE("ReturnConsecutive") { // we can return a single local directly @@ -5193,6 +5316,16 @@ return )"), R"( RETURN R0 0 +)"); + + // this optimization also works in presence of group / type casts + CHECK_EQ("\n" + compileFunction0(R"( +local x, y = ... +return (x), y :: number +)"), + R"( +GETVARARGS R0 2 +RETURN R0 2 )"); } diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 4a99986..c7e18ef 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -198,10 +198,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - ScopedFastFlag sff[]{ - {"LuauLosslessClone", true}, - }; - TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -218,8 +214,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_types") TEST_CASE_FIXTURE(Fixture, "clone_free_tables") { - ScopedFastFlag sff{"LuauLosslessClone", true}; - TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->state = TableState::Free; @@ -252,8 +246,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_constrained_intersection") TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - fileResolver.source["Module/A"] = R"( --!nonstrict local a = {} diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 69430b1..83c526e 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -150,8 +150,6 @@ TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( --!nonstrict local T = {} @@ -169,8 +167,6 @@ TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( --!nonstrict local T = {} diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 4183068..dd49eb0 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -683,6 +683,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", false} }; check(R"( @@ -697,6 +698,26 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_marked_normal") CHECK(t->normal); } +// Unfortunately, getting this right in the general case is difficult. +TEST_CASE_FIXTURE(Fixture, "cyclic_table_is_not_marked_normal") +{ + ScopedFastFlag flags[] = { + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", true} + }; + + check(R"( + type Fiber = { + return_: Fiber? + } + + local f: Fiber + )"); + + TypeId t = requireType("f"); + CHECK(!t->normal); +} + TEST_CASE_FIXTURE(Fixture, "variadic_tail_is_marked_normal") { ScopedFastFlag flags[] = { @@ -997,4 +1018,28 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_failure_bound_type_is_normal_but_not_its_bounde LUAU_REQUIRE_ERRORS(result); } +// We had an issue where a normal BoundTypeVar might point at a non-normal BoundTypeVar if it in turn pointed to a +// normal TypeVar because we were calling follow() in an improper place. +TEST_CASE_FIXTURE(Fixture, "bound_typevars_should_only_be_marked_normal_if_their_pointee_is_normal") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", true}, + }; + + CheckResult result = check(R"( + local T = {} + + function T:M() + local function f(a) + print(self.prop) + self:g(a) + self.prop = a + end + end + + return T + )"); +} + TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index c16f60d..14c1761 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -22,8 +22,6 @@ struct LimitFixture : BuiltinsFixture #if defined(_NOOPT) || defined(_DEBUG) ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; #endif - - ScopedFastFlag LuauJustOneCallFrameForHaveSeen{"LuauJustOneCallFrameForHaveSeen", true}; }; template diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 50d0838..b854bc5 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -126,6 +126,39 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_inte CHECK_EQ(toString(&itv), "((number, string) -> (string, number)) & ((string, number) -> (number, string))"); } +TEST_CASE_FIXTURE(Fixture, "intersections_respects_use_line_breaks") +{ + CheckResult result = check(R"( + local a: ((string) -> string) & ((number) -> number) + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + //clang-format off + CHECK_EQ("((number) -> number)\n" + "& ((string) -> string)", + toString(requireType("a"), opts)); + //clang-format on +} + +TEST_CASE_FIXTURE(Fixture, "unions_respects_use_line_breaks") +{ + CheckResult result = check(R"( + local a: string | number | boolean + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + + //clang-format off + CHECK_EQ("boolean\n" + "| number\n" + "| string", + toString(requireType("a"), opts)); + //clang-format on +} + TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded") { TableTypeVar ttv{}; @@ -617,4 +650,52 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") CHECK_EQ("test(first: a, second: string, ...: number): a", toStringNamedFunction("test", *ftv, opts)); } +TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_generics") +{ + ScopedFastFlag sff[] = { + {"LuauAlwaysQuantify", true}, + }; + + CheckResult result = check(R"( + function foo(x: a, y) end + )"); + + CHECK("(a, b) -> ()" == toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") +{ + ScopedFastFlag flag{"LuauDocFuncParameters", true}; + CheckResult result = check(R"( + local foo = {} + function foo:method(arg: string): () + end + )"); + + TypeId parentTy = requireType("foo"); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); + + CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); +} + + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") +{ + ScopedFastFlag flag{"LuauDocFuncParameters", true}; + CheckResult result = check(R"( + local foo = {} + function foo:method(arg: string): () + end + )"); + + TypeId parentTy = requireType("foo"); + auto ttv = get(follow(parentTy)); + auto ftv = get(ttv->props.at("method").type); + + ToStringOptions opts; + opts.hideFunctionSelfArgument = true; + CHECK_EQ("foo:method(arg: string): ()", toStringNamedFunction("foo:method", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index b710ea0..aa4ca41 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -878,8 +878,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_add_definitions_to_persistent_types") TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") { ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -899,8 +897,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types") TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") { ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -916,11 +912,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types2") TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") { - ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(...: number?) return assert(...) @@ -933,11 +924,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_removes_falsy_types_even_from_type_pa TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") { - ScopedFastFlag sff[]{ - {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(x: nil) return assert(x, "hmm") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 14f1f70..a28ba49 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1496,8 +1496,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") { - ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; - CheckResult result = check(R"( local function f(x: any) end f() diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index de0c939..78a5fee 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1121,4 +1121,78 @@ TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics1") +{ + ScopedFastFlag sff{"LuauApplyTypeFunctionFix", true}; + + // https://github.com/Roblox/luau/issues/484 + CheckResult result = check(R"( +--!strict +type MyObject = { + getReturnValue: (cb: () -> V) -> V +} +local object: MyObject = { + getReturnValue = function(cb: () -> U): U + return cb() + end, +} + +type ComplexObject = { + id: T, + nested: MyObject +} + +local complex: ComplexObject = { + id = "Foo", + nested = object, +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics2") +{ + ScopedFastFlag sff{"LuauApplyTypeFunctionFix", true}; + + // https://github.com/Roblox/luau/issues/484 + CheckResult result = check(R"( +--!strict +type MyObject = { + getReturnValue: (cb: () -> V) -> V +} +type ComplexObject = { + id: T, + nested: MyObject +} + +local complex2: ComplexObject = nil + +local x = complex2.nested.getReturnValue(function(): string + return "" +end) + +local y = complex2.nested.getReturnValue(function() + return 3 +end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_generic") +{ + ScopedFastFlag sff[] = { + {"LuauAlwaysQuantify", true}, + }; + + CheckResult result = check(R"( + function foo(f, x: X) + return f(x) + end + )"); + + CHECK("((X) -> (a...), X) -> (a...)" == toString(requireType("foo"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 41bc0c2..f75b2d1 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -177,8 +177,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guarante TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_depth") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type A = {x: {y: {z: {thing: string}}}} type B = {x: {y: {z: {thing: string}}}} diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 765419c..a3cae3d 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -475,8 +475,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") { - ScopedFastFlag luauInstantiateFollows{"LuauInstantiateFollows", true}; - // Just check that this doesn't assert check(R"( --!nonstrict diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 51f6fdf..0361493 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -728,8 +728,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: string | number, b: boolean | number) return a == b diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index ee3ae97..9d22789 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -7,7 +7,6 @@ #include -LUAU_FASTFLAG(LuauEqConstraint) LUAU_FASTFLAG(LuauLowerBoundsCalculation) using namespace Luau; @@ -183,8 +182,6 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") // We'll need to not only report an error on `a == b`, but also to refine both operands as `never` in the `==` branch. TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: string, b: boolean?) if a == b then @@ -208,8 +205,6 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") // Just needs to fully support equality refinement. Which is annoying without type states. TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") { - ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; - CheckResult result = check(R"( type T = {x: string, y: number} | {x: nil, y: nil} @@ -471,4 +466,35 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_many_things_but_first_of_it CHECK_EQ("boolean", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauNormalizeFlagIsConservative", true}, + }; + + CheckResult result = check(R"( + local function f(o) + local t = {} + t[o] = true + + local function foo(o) + o:m1() + t[o] = nil + end + + local function bar(o) + o:m2() + t[o] = true + end + + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: We're missing generics a... and b... + CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 8c13049..85a3334 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauLowerBoundsCalculation) @@ -268,18 +267,10 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauDiscriminableUnions2) - { - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); - } + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -378,8 +369,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?, b: boolean?) if a == b then @@ -392,28 +381,15 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "nil"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "nil"); // a == b - - CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b - } + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if a == 1 then @@ -426,24 +402,12 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number"); // a == 1 - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 } TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local function f(a: (string | number)?) if "hello" == a then @@ -462,8 +426,6 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: (string | number)?) if a ~= nil then @@ -476,21 +438,12 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "nil"); // a == nil - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil } TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; ScopedFastFlag sff2{"LuauWeakEqConstraint", true}; CheckResult result = check(R"( @@ -509,8 +462,6 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local function f(a: any, b: {x: number}?) if a ~= b then @@ -521,22 +472,12 @@ TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_e LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}"); // a ~= b - } + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b } TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( local t: {string} = {"hello"} @@ -554,18 +495,8 @@ TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b - } - else - { - // This is technically not wrong, but it's also wrong at the same time. - // The refinement code is none the wiser about the fact we pulled a string out of an array, so it has no choice but to narrow as just string. - CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") @@ -594,16 +525,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions2) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); } @@ -1009,10 +931,6 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( type T = {tag: "missing", x: nil} | {tag: "exists", x: string} @@ -1033,10 +951,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") TEST_CASE_FIXTURE(Fixture, "discriminate_tag") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( type Cat = {tag: "Cat", name: string, catfood: string} type Dog = {tag: "Dog", name: string, dogfood: string} @@ -1070,11 +984,6 @@ TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauAssertStripsFalsyTypes", true}, - }; - CheckResult result = check(R"( local function is_true(b: true) end local function is_false(b: false) end @@ -1093,11 +1002,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauAssertStripsFalsyTypes", true}, - }; - CheckResult result = check(R"( type Ok = { ok: true, value: T } type Err = { ok: false, error: E } @@ -1117,8 +1021,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersection_table") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type T = {} & {f: ((string) -> string)?} local function f(t: T, x) @@ -1133,10 +1035,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersect TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { - ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( type T = {tag: "Part", x: Part} | {tag: "Folder", x: Folder} @@ -1171,14 +1069,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions2) - LUAU_REQUIRE_NO_ERRORS(result); - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 79eeb82..d90dfbb 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -139,6 +139,8 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") { + ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; + CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "bang" @@ -325,8 +327,6 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") { ScopedFastFlag sff[]{ - {"LuauEqConstraint", true}, - {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, {"LuauWeakEqConstraint", false}, }; @@ -350,11 +350,8 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauEqConstraint", true}, {"LuauWidenIfSupertypeIsFree2", true}, {"LuauWeakEqConstraint", false}, - {"LuauDoNotAccidentallyDependOnPointerOrdering", true}, }; CheckResult result = check(R"( @@ -390,7 +387,6 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") { ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -419,6 +415,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") { ScopedFastFlag sff[]{ {"LuauWidenIfSupertypeIsFree2", true}, + {"LuauWeakEqConstraint", true}, }; CheckResult result = check(R"( @@ -456,10 +453,6 @@ TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" then @@ -474,10 +467,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" or a == "bye" then @@ -492,10 +481,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" then @@ -510,10 +495,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") { - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - }; - CheckResult result = check(R"( local a: string = "hi" if a == "hi" or a == "bye" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 5078b0b..c924484 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2279,8 +2279,6 @@ local y = #x TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") { - ScopedFastFlag sff{"LuauTerminateCyclicMetatableIndexLookup", true}; - // t :: t1 where t1 = {metatable {__index: t1, __tostring: (t1) -> string}} CheckResult result = check(R"( local mt = {} @@ -2313,8 +2311,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "give_up_after_one_metatable_index_look_up") TEST_CASE_FIXTURE(Fixture, "confusing_indexing") { - ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; - CheckResult result = check(R"( type T = {} & {p: number | string} local function f(t: T) @@ -2971,8 +2967,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") { - ScopedFastFlag sff{"LuauCheckImplicitNumbericKeys", true}; - CheckResult result = check(R"( local t: { [string]: number } = { 5, 6, 7 } )"); @@ -2984,4 +2978,32 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); } +TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") +{ + ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; + ScopedFastFlag luauSubtypingAddOptPropsToUnsealedTables{"LuauSubtypingAddOptPropsToUnsealedTables", true}; + + CheckResult result = check(R"( + type X = { { x: boolean?, y: boolean? } } + + local l1: {[string]: X} = { key = { { x = true }, { y = true } } } + local l2: {[any]: X} = { key = { { x = true }, { y = true } } } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") +{ + ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; + + CheckResult result = check(R"( + type X = {[any]: string | boolean} + + local x: X = { key = "str" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 48cd1c3..1d144db 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -15,7 +15,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) -LUAU_FASTFLAG(LuauEqConstraint) using namespace Luau; @@ -308,7 +307,6 @@ TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") int limit = 600; #endif - ScopedFastFlag sff{"LuauTableUseCounterInstead", true}; ScopedFastInt sfi{"LuauCheckRecursionLimit", limit}; CheckResult result = check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"); @@ -1011,8 +1009,6 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") { - ScopedFastFlag substituteFollowNewTypes{"LuauSubstituteFollowNewTypes", true}; - CheckResult result = check(R"( local obj = {} diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 277f388..d19d80c 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTFLAG(LuauEqConstraint) using namespace Luau; From c4e05eb7c1d1a7753a54b01f2738c48c19a026d7 Mon Sep 17 00:00:00 2001 From: Rob Blanckaert Date: Thu, 26 May 2022 13:33:48 -0700 Subject: [PATCH 11/19] Sync to upstream/release/529 --- Analysis/include/Luau/Frontend.h | 16 +- Analysis/include/Luau/Instantiation.h | 53 + Analysis/include/Luau/TypeArena.h | 2 +- Analysis/include/Luau/TypeInfer.h | 39 - Analysis/include/Luau/Unifier.h | 3 + Analysis/include/Luau/UnifierSharedState.h | 1 - Analysis/include/Luau/VisitTypeVar.h | 196 --- Analysis/src/Autocomplete.cpp | 31 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 2 +- Analysis/src/Frontend.cpp | 108 +- Analysis/src/Instantiation.cpp | 128 ++ Analysis/src/Module.cpp | 12 +- Analysis/src/Normalize.cpp | 265 +--- Analysis/src/Quantify.cpp | 3 +- Analysis/src/ToString.cpp | 42 +- Analysis/src/TypeArena.cpp | 2 +- Analysis/src/TypeInfer.cpp | 233 +--- Analysis/src/Unifier.cpp | 172 +-- Ast/include/Luau/TimeTrace.h | 2 +- Ast/src/Parser.cpp | 13 + CMakeLists.txt | 15 +- CodeGen/include/Luau/AssemblyBuilderX64.h | 169 +++ CodeGen/include/Luau/Condition.h | 46 + CodeGen/include/Luau/Label.h | 18 + CodeGen/include/Luau/OperandX64.h | 136 +++ CodeGen/include/Luau/RegisterX64.h | 116 ++ CodeGen/src/AssemblyBuilderX64.cpp | 1003 ++++++++++++++++ {Compiler => Common}/include/Luau/Bytecode.h | 0 {Ast => Common}/include/Luau/Common.h | 0 Compiler/include/Luau/BytecodeBuilder.h | 2 +- Compiler/src/BytecodeBuilder.cpp | 124 +- Compiler/src/Compiler.cpp | 179 ++- Compiler/src/CostModel.cpp | 124 +- Compiler/src/CostModel.h | 3 + Makefile | 30 +- Sources.cmake | 25 +- VM/include/lua.h | 1 + VM/src/lapi.cpp | 24 +- VM/src/lbuiltins.cpp | 4 +- VM/src/lbytecode.h | 5 +- VM/src/lcommon.h | 6 +- VM/src/ldo.cpp | 17 +- VM/src/ldo.h | 1 + VM/src/lvmexecute.cpp | 22 +- tests/AssemblyBuilderX64.test.cpp | 410 +++++++ tests/Autocomplete.test.cpp | 22 +- tests/Compiler.test.cpp | 1129 +++++++++++------- tests/Conformance.test.cpp | 176 ++- tests/CostModel.test.cpp | 6 +- tests/JsonEncoder.test.cpp | 77 +- tests/NonstrictMode.test.cpp | 5 +- tests/Normalize.test.cpp | 36 +- tests/Parser.test.cpp | 11 + tests/RuntimeLimits.test.cpp | 2 +- tests/ToDot.test.cpp | 10 +- tests/TypeInfer.builtins.test.cpp | 14 +- tests/TypeInfer.generics.test.cpp | 2 +- tests/TypeInfer.intersectionTypes.test.cpp | 2 +- tests/TypeInfer.modules.test.cpp | 4 +- tests/TypeInfer.provisional.test.cpp | 5 +- tests/TypeInfer.refinements.test.cpp | 7 +- tests/TypeInfer.singletons.test.cpp | 27 +- tests/TypeInfer.tables.test.cpp | 37 +- tests/TypeInfer.test.cpp | 11 +- tests/TypeInfer.typePacks.cpp | 8 +- tests/TypeVar.test.cpp | 3 +- tests/VisitTypeVar.test.cpp | 2 - tests/conformance/errors.lua | 80 ++ tests/conformance/nextvar.lua | 25 + tests/conformance/userdata.lua | 45 + tools/natvis/CodeGen.natvis | 50 + tools/patchtests.py | 76 ++ 72 files changed, 3937 insertions(+), 1738 deletions(-) create mode 100644 Analysis/include/Luau/Instantiation.h create mode 100644 Analysis/src/Instantiation.cpp create mode 100644 CodeGen/include/Luau/AssemblyBuilderX64.h create mode 100644 CodeGen/include/Luau/Condition.h create mode 100644 CodeGen/include/Luau/Label.h create mode 100644 CodeGen/include/Luau/OperandX64.h create mode 100644 CodeGen/include/Luau/RegisterX64.h create mode 100644 CodeGen/src/AssemblyBuilderX64.cpp rename {Compiler => Common}/include/Luau/Bytecode.h (100%) rename {Ast => Common}/include/Luau/Common.h (100%) create mode 100644 tests/AssemblyBuilderX64.test.cpp create mode 100644 tests/conformance/userdata.lua create mode 100644 tools/natvis/CodeGen.natvis create mode 100644 tools/patchtests.py diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 37e3cfd..d7c9ca4 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -12,9 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauSeparateTypechecks) -LUAU_FASTFLAG(LuauDirtySourceModule) - namespace Luau { @@ -60,17 +57,12 @@ struct SourceNode { bool hasDirtySourceModule() const { - LUAU_ASSERT(FFlag::LuauDirtySourceModule); - return dirtySourceModule; } bool hasDirtyModule(bool forAutocomplete) const { - if (FFlag::LuauSeparateTypechecks) - return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; - else - return dirtyModule; + return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; } ModuleName name; @@ -90,10 +82,6 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // When true, we run typechecking twice, once in the regular mode, and once in strict mode - // in order to get more precise type information (e.g. for autocomplete). - bool typecheckTwice_DEPRECATED = false; - // Run typechecking only in mode required for autocomplete (strict mode in order to get more precise type information) bool forAutocomplete = false; }; @@ -171,7 +159,7 @@ struct Frontend void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); private: - std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete_DEPRECATED); + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root, bool forAutocomplete); diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h new file mode 100644 index 0000000..e05ceeb --- /dev/null +++ b/Analysis/include/Luau/Instantiation.h @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Substitution.h" +#include "Luau/TypeVar.h" +#include "Luau/Unifiable.h" + +namespace Luau +{ + +struct TypeArena; +struct TxnLog; + +// A substitution which replaces generic types in a given set by free types. +struct ReplaceGenerics : Substitution +{ + ReplaceGenerics( + const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) + : Substitution(log, arena) + , level(level) + , generics(generics) + , genericPacks(genericPacks) + { + } + + TypeLevel level; + std::vector generics; + std::vector genericPacks; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// A substitution which replaces generic functions by monomorphic functions +struct Instantiation : Substitution +{ + Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) + : Substitution(log, arena) + , level(level) + { + } + + TypeLevel level; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 7c74158..559c55c 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -39,4 +39,4 @@ struct TypeArena void freeze(TypeArena& arena); void unfreeze(TypeArena& arena); -} +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index fcaf5ba..183cc05 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -34,45 +34,6 @@ const AstStat* getFallthrough(const AstStat* node); struct UnifierOptions; struct Unifier; -// A substitution which replaces generic types in a given set by free types. -struct ReplaceGenerics : Substitution -{ - ReplaceGenerics( - const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) - : Substitution(log, arena) - , level(level) - , generics(generics) - , genericPacks(genericPacks) - { - } - - TypeLevel level; - std::vector generics; - std::vector genericPacks; - bool ignoreChildren(TypeId ty) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - -// A substitution which replaces generic functions by monomorphic functions -struct Instantiation : Substitution -{ - Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) - : Substitution(log, arena) - , level(level) - { - } - - TypeLevel level; - bool ignoreChildren(TypeId ty) override; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - // A substitution which replaces free types by any struct Anyification : Substitution { diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 0e24c8b..627b52c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -32,6 +32,9 @@ struct Widen : Substitution TypeId clean(TypeId ty) override; TypePackId clean(TypePackId ty) override; bool ignoreChildren(TypeId ty) override; + + TypeId operator()(TypeId ty); + TypePackId operator()(TypePackId ty); }; // TODO: Use this more widely. diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index 1a0b8b7..d4315d4 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -42,7 +42,6 @@ struct UnifierSharedState InternalErrorReporter* iceHandler; - DenseHashSet seenAny{nullptr}; DenseHashMap skipCacheForType{nullptr}; DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; DenseHashMap, TypeErrorData, TypeIdPairHash> cachedUnifyError{{nullptr, nullptr}}; diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 2e98f52..f383991 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -8,7 +8,6 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" -LUAU_FASTFLAG(LuauUseVisitRecursionLimit) LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) @@ -62,168 +61,6 @@ inline void unsee(DenseHashSet& seen, const void* tv) // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements } -template -void visit(TypePackId tp, F& f, Set& seen); - -template -void visit(TypeId ty, F& f, Set& seen) -{ - if (visit_detail::hasSeen(seen, ty)) - { - f.cycle(ty); - return; - } - - if (auto btv = get(ty)) - { - if (apply(ty, *btv, seen, f)) - visit(btv->boundTo, f, seen); - } - - else if (auto ftv = get(ty)) - apply(ty, *ftv, seen, f); - - else if (auto gtv = get(ty)) - apply(ty, *gtv, seen, f); - - else if (auto etv = get(ty)) - apply(ty, *etv, seen, f); - - else if (auto ctv = get(ty)) - { - if (apply(ty, *ctv, seen, f)) - { - for (TypeId part : ctv->parts) - visit(part, f, seen); - } - } - - else if (auto ptv = get(ty)) - apply(ty, *ptv, seen, f); - - else if (auto ftv = get(ty)) - { - if (apply(ty, *ftv, seen, f)) - { - visit(ftv->argTypes, f, seen); - visit(ftv->retType, f, seen); - } - } - - else if (auto ttv = get(ty)) - { - // Some visitors want to see bound tables, that's why we visit the original type - if (apply(ty, *ttv, seen, f)) - { - if (ttv->boundTo) - { - visit(*ttv->boundTo, f, seen); - } - else - { - for (auto& [_name, prop] : ttv->props) - visit(prop.type, f, seen); - - if (ttv->indexer) - { - visit(ttv->indexer->indexType, f, seen); - visit(ttv->indexer->indexResultType, f, seen); - } - } - } - } - - else if (auto mtv = get(ty)) - { - if (apply(ty, *mtv, seen, f)) - { - visit(mtv->table, f, seen); - visit(mtv->metatable, f, seen); - } - } - - else if (auto ctv = get(ty)) - { - if (apply(ty, *ctv, seen, f)) - { - for (const auto& [name, prop] : ctv->props) - visit(prop.type, f, seen); - - if (ctv->parent) - visit(*ctv->parent, f, seen); - - if (ctv->metatable) - visit(*ctv->metatable, f, seen); - } - } - - else if (auto atv = get(ty)) - apply(ty, *atv, seen, f); - - else if (auto utv = get(ty)) - { - if (apply(ty, *utv, seen, f)) - { - for (TypeId optTy : utv->options) - visit(optTy, f, seen); - } - } - - else if (auto itv = get(ty)) - { - if (apply(ty, *itv, seen, f)) - { - for (TypeId partTy : itv->parts) - visit(partTy, f, seen); - } - } - - visit_detail::unsee(seen, ty); -} - -template -void visit(TypePackId tp, F& f, Set& seen) -{ - if (visit_detail::hasSeen(seen, tp)) - { - f.cycle(tp); - return; - } - - if (auto btv = get(tp)) - { - if (apply(tp, *btv, seen, f)) - visit(btv->boundTo, f, seen); - } - - else if (auto ftv = get(tp)) - apply(tp, *ftv, seen, f); - - else if (auto gtv = get(tp)) - apply(tp, *gtv, seen, f); - - else if (auto etv = get(tp)) - apply(tp, *etv, seen, f); - - else if (auto pack = get(tp)) - { - apply(tp, *pack, seen, f); - - for (TypeId ty : pack->head) - visit(ty, f, seen); - - if (pack->tail) - visit(*pack->tail, f, seen); - } - else if (auto pack = get(tp)) - { - apply(tp, *pack, seen, f); - visit(pack->ty, f, seen); - } - - visit_detail::unsee(seen, tp); -} - } // namespace visit_detail template @@ -513,37 +350,4 @@ struct TypeVarOnceVisitor : GenericTypeVarVisitor> } }; -// Clip with FFlagLuauUseVisitRecursionLimit -template -void DEPRECATED_visitTypeVar(TID ty, F& f, std::unordered_set& seen) -{ - visit_detail::visit(ty, f, seen); -} - -// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit -template -void DEPRECATED_visitTypeVar(TID ty, F& f) -{ - if (FFlag::LuauUseVisitRecursionLimit) - f.traverse(ty); - else - { - std::unordered_set seen; - visit_detail::visit(ty, f, seen); - } -} - -// Delete and inline when clipping FFlagLuauUseVisitRecursionLimit -template -void DEPRECATED_visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) -{ - if (FFlag::LuauUseVisitRecursionLimit) - f.traverse(ty); - else - { - seen.clear(); - visit_detail::visit(ty, f, seen); - } -} - } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 19d06cf..b988ed3 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1700,31 +1700,18 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - if (FFlag::LuauSeparateTypechecks) - { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(moduleName, opts); - } - else - { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. - frontend.check(moduleName); - } + // FIXME: We can improve performance here by parsing without checking. + // The old type graph is probably fine. (famous last words!) + FrontendOptions opts; + opts.forAutocomplete = true; + frontend.check(moduleName, opts); const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; - TypeChecker& typeChecker = - (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = - (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete; + ModulePtr module = frontend.moduleResolverForAutocomplete.getModule(moduleName); if (!module) return {}; @@ -1752,9 +1739,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = - (frontend.options.typecheckTwice_DEPRECATED || FFlag::LuauSeparateTypechecks ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - + TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete; ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); OwningAutocompleteResult autocompleteResult = { diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index be3fcd7..9a2259f 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -143,7 +143,7 @@ declare coroutine: { create: ((A...) -> R...) -> thread, resume: (thread, A...) -> (boolean, R...), running: () -> thread, - status: (thread) -> string, + status: (thread) -> "dead" | "running" | "normal" | "suspended", -- FIXME: This technically returns a function, but we can't represent this yet. wrap: ((A...) -> R...) -> any, yield: (A...) -> R..., diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 56c0ac2..1d33f13 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -20,9 +20,7 @@ LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauSeparateTypechecks, false) LUAU_FASTFLAGVARIABLE(LuauAutocompleteDynamicLimits, false) -LUAU_FASTFLAGVARIABLE(LuauDirtySourceModule, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) namespace Luau @@ -361,32 +359,21 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalsecond.hasDirtyModule(frontendOptions.forAutocomplete)) { // No recheck required. - if (FFlag::LuauSeparateTypechecks) + if (frontendOptions.forAutocomplete) { - if (frontendOptions.forAutocomplete) - { - auto it2 = moduleResolverForAutocomplete.modules.find(name); - if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); - } - else - { - auto it2 = moduleResolver.modules.find(name); - if (it2 == moduleResolver.modules.end() || it2->second == nullptr) - throw std::runtime_error("Frontend::modules does not have data for " + name); - } - - return CheckResult{accumulateErrors( - sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; + auto it2 = moduleResolverForAutocomplete.modules.find(name); + if (it2 == moduleResolverForAutocomplete.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); } else { auto it2 = moduleResolver.modules.find(name); if (it2 == moduleResolver.modules.end() || it2->second == nullptr) throw std::runtime_error("Frontend::modules does not have data for " + name); - - return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; } + + return CheckResult{ + accumulateErrors(sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; } std::vector buildQueue; @@ -428,7 +415,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& buildQueue, CheckResult& chec bool cyclic = false; { - auto [sourceNode, _] = getSourceNode(checkResult, root, forAutocomplete); + auto [sourceNode, _] = getSourceNode(checkResult, root); if (sourceNode) stack.push_back(sourceNode); } @@ -627,7 +603,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec } } - auto [sourceNode, _] = getSourceNode(checkResult, dep, forAutocomplete); + auto [sourceNode, _] = getSourceNode(checkResult, dep); if (sourceNode) { stack.push_back(sourceNode); @@ -671,7 +647,7 @@ LintResult Frontend::lint(const ModuleName& name, std::optional* markedDirty) { - if (FFlag::LuauSeparateTypechecks) - { - if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) - return; - } - else - { - if (!moduleResolver.modules.count(name)) - return; - } + if (!moduleResolver.modules.count(name) && !moduleResolverForAutocomplete.modules.count(name)) + return; std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) @@ -783,32 +751,12 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked if (markedDirty) markedDirty->push_back(next); - if (FFlag::LuauDirtySourceModule) - { - LUAU_ASSERT(FFlag::LuauSeparateTypechecks); + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + continue; - if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) - continue; - - sourceNode.dirtySourceModule = true; - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - } - else if (FFlag::LuauSeparateTypechecks) - { - if (sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) - continue; - - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - } - else - { - if (sourceNode.dirtyModule) - continue; - - sourceNode.dirtyModule = true; - } + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; if (0 == reverseDeps.count(name)) continue; @@ -835,14 +783,13 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. -std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name, bool forAutocomplete_DEPRECATED) +std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); auto it = sourceNodes.find(name); - if (it != sourceNodes.end() && - (FFlag::LuauDirtySourceModule ? !it->second.hasDirtySourceModule() : !it->second.hasDirtyModule(forAutocomplete_DEPRECATED))) + if (it != sourceNodes.end() && !it->second.hasDirtySourceModule()) { auto moduleIt = sourceModules.find(name); if (moduleIt != sourceModules.end()) @@ -885,21 +832,12 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.name = name; sourceNode.requires.clear(); sourceNode.requireLocations.clear(); + sourceNode.dirtySourceModule = false; - if (FFlag::LuauDirtySourceModule) - sourceNode.dirtySourceModule = false; - - if (FFlag::LuauSeparateTypechecks) - { - if (it == sourceNodes.end()) - { - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - } - } - else + if (it == sourceNodes.end()) { sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; } for (const auto& [moduleName, location] : requireTrace.requires) diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp new file mode 100644 index 0000000..4a12027 --- /dev/null +++ b/Analysis/src/Instantiation.cpp @@ -0,0 +1,128 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" +#include "Luau/Instantiation.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeArena.h" + +LUAU_FASTFLAG(LuauNoMethodLocations) + +namespace Luau +{ + +bool Instantiation::isDirty(TypeId ty) +{ + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (ftv->hasNoGenerics) + return false; + + return true; + } + else + { + return false; + } +} + +bool Instantiation::isDirty(TypePackId tp) +{ + return false; +} + +bool Instantiation::ignoreChildren(TypeId ty) +{ + if (log->getMutable(ty)) + return true; + else + return false; +} + +TypeId Instantiation::clean(TypeId ty) +{ + const FunctionTypeVar* ftv = log->getMutable(ty); + LUAU_ASSERT(ftv); + + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + TypeId result = addType(std::move(clone)); + + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +TypePackId Instantiation::clean(TypePackId tp) +{ + LUAU_ASSERT(false); + return tp; +} + +bool ReplaceGenerics::ignoreChildren(TypeId ty) +{ + if (const FunctionTypeVar* ftv = log->getMutable(ty)) + { + if (ftv->hasNoGenerics) + return true; + + // We aren't recursing in the case of a generic function which + // binds the same generics. This can happen if, for example, there's recursive types. + // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. + // It's OK to use vector equality here, since we always generate fresh generics + // whenever we quantify, so the vectors overlap if and only if they are equal. + return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); + } + else + { + return false; + } +} + +bool ReplaceGenerics::isDirty(TypeId ty) +{ + if (const TableTypeVar* ttv = log->getMutable(ty)) + return ttv->state == TableState::Generic; + else if (log->getMutable(ty)) + return std::find(generics.begin(), generics.end(), ty) != generics.end(); + else + return false; +} + +bool ReplaceGenerics::isDirty(TypePackId tp) +{ + if (log->getMutable(tp)) + return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); + else + return false; +} + +TypeId ReplaceGenerics::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = log->getMutable(ty)) + { + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; + if (!FFlag::LuauNoMethodLocations) + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + return addType(std::move(clone)); + } + else + return addType(FreeTypeVar{level}); +} + +TypePackId ReplaceGenerics::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + return addTypePack(TypePackVar(FreeTypePack{level})); +} + +} // namespace Luau diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 074a41e..6591d60 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -56,8 +56,18 @@ bool isWithinComment(const SourceModule& sourceModule, Position pos) struct ForceNormal : TypeVarOnceVisitor { + const TypeArena* typeArena = nullptr; + + ForceNormal(const TypeArena* typeArena) + : typeArena(typeArena) + { + } + bool visit(TypeId ty) override { + if (ty->owningArena != typeArena) + return false; + asMutable(ty)->normal = true; return true; } @@ -100,7 +110,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice) normalize(*moduleScope->varargPack, interfaceTypes, ice); } - ForceNormal forceNormal; + ForceNormal forceNormal{&interfaceTypes}; for (auto& [name, tf] : moduleScope->exportedTypeBindings) { diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 30fd4af..fb31df1 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); namespace Luau { @@ -325,245 +326,6 @@ struct Normalize final : TypeVarVisitor int iterationLimit = 0; bool limitExceeded = false; - // TODO: Clip with FFlag::LuauUseVisitRecursionLimit - bool operator()(TypeId ty, const BoundTypeVar& btv, std::unordered_set& seen) - { - // A type could be considered normal when it is in the stack, but we will eventually find out it is not normal as normalization progresses. - // So we need to avoid eagerly saying that this bound type is normal if the thing it is bound to is in the stack. - if (seen.find(asMutable(btv.boundTo)) != seen.end()) - return false; - - // It should never be the case that this TypeVar is normal, but is bound to a non-normal type, except in nontrivial cases. - LUAU_ASSERT(!ty->normal || ty->normal == btv.boundTo->normal); - - asMutable(ty)->normal = btv.boundTo->normal; - return !ty->normal; - } - - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const PrimitiveTypeVar& ptv) - { - return visit(ty, ptv); - } - bool operator()(TypeId ty, const GenericTypeVar& gtv) - { - return visit(ty, gtv); - } - bool operator()(TypeId ty, const ErrorTypeVar& etv) - { - return visit(ty, etv); - } - bool operator()(TypeId ty, const ConstrainedTypeVar& ctvRef, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - ConstrainedTypeVar* ctv = const_cast(&ctvRef); - - std::vector parts = std::move(ctv->parts); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId part : parts) - visit_detail::visit(part, *this, seen); - - std::vector newParts = normalizeUnion(parts); - - const bool normal = areNormal(newParts, seen, ice); - - if (newParts.size() == 1) - *asMutable(ty) = BoundTypeVar{newParts[0]}; - else - *asMutable(ty) = UnionTypeVar{std::move(newParts)}; - - asMutable(ty)->normal = normal; - - return false; - } - - bool operator()(TypeId ty, const FunctionTypeVar& ftv, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - visit_detail::visit(ftv.argTypes, *this, seen); - visit_detail::visit(ftv.retType, *this, seen); - - asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); - - return false; - } - - bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - bool normal = true; - - auto checkNormal = [&](TypeId t) { - // if t is on the stack, it is possible that this type is normal. - // If t is not normal and it is not on the stack, this type is definitely not normal. - if (!t->normal && seen.find(asMutable(t)) == seen.end()) - normal = false; - }; - - if (ttv.boundTo) - { - visit_detail::visit(*ttv.boundTo, *this, seen); - asMutable(ty)->normal = (*ttv.boundTo)->normal; - return false; - } - - for (const auto& [_name, prop] : ttv.props) - { - visit_detail::visit(prop.type, *this, seen); - checkNormal(prop.type); - } - - if (ttv.indexer) - { - visit_detail::visit(ttv.indexer->indexType, *this, seen); - checkNormal(ttv.indexer->indexType); - visit_detail::visit(ttv.indexer->indexResultType, *this, seen); - checkNormal(ttv.indexer->indexResultType); - } - - asMutable(ty)->normal = normal; - - return false; - } - - bool operator()(TypeId ty, const MetatableTypeVar& mtv, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - visit_detail::visit(mtv.table, *this, seen); - visit_detail::visit(mtv.metatable, *this, seen); - - asMutable(ty)->normal = mtv.table->normal && mtv.metatable->normal; - - return false; - } - - bool operator()(TypeId ty, const ClassTypeVar& ctv) - { - return visit(ty, ctv); - } - bool operator()(TypeId ty, const AnyTypeVar& atv) - { - return visit(ty, atv); - } - bool operator()(TypeId ty, const UnionTypeVar& utvRef, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - UnionTypeVar* utv = &const_cast(utvRef); - std::vector options = std::move(utv->options); - - // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar - for (TypeId option : options) - visit_detail::visit(option, *this, seen); - - std::vector newOptions = normalizeUnion(options); - - const bool normal = areNormal(newOptions, seen, ice); - - LUAU_ASSERT(!newOptions.empty()); - - if (newOptions.size() == 1) - *asMutable(ty) = BoundTypeVar{newOptions[0]}; - else - utv->options = std::move(newOptions); - - asMutable(ty)->normal = normal; - - return false; - } - - bool operator()(TypeId ty, const IntersectionTypeVar& itvRef, std::unordered_set& seen) - { - CHECK_ITERATION_LIMIT(false); - - if (ty->normal) - return false; - - IntersectionTypeVar* itv = &const_cast(itvRef); - - std::vector oldParts = std::move(itv->parts); - - for (TypeId part : oldParts) - visit_detail::visit(part, *this, seen); - - std::vector tables; - for (TypeId part : oldParts) - { - part = follow(part); - if (get(part)) - tables.push_back(part); - else - { - Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD - combineIntoIntersection(replacer, itv, part); - } - } - - // Don't allocate a new table if there's just one in the intersection. - if (tables.size() == 1) - itv->parts.push_back(tables[0]); - else if (!tables.empty()) - { - const TableTypeVar* first = get(tables[0]); - LUAU_ASSERT(first); - - TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); - TableTypeVar* ttv = getMutable(newTable); - for (TypeId part : tables) - { - // Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need - // to be rewritten to point at 'newTable' in the clone. - Replacer replacer{&arena, part, newTable}; - combineIntoTable(replacer, ttv, part); - } - - itv->parts.push_back(newTable); - } - - asMutable(ty)->normal = areNormal(itv->parts, seen, ice); - - if (itv->parts.size() == 1) - { - TypeId part = itv->parts[0]; - *asMutable(ty) = BoundTypeVar{part}; - } - - return false; - } - - // TODO: Clip with FFlag::LuauUseVisitRecursionLimit - template - bool operator()(TypePackId, const T&) - { - return true; - } - - // TODO: Clip with FFlag::LuauUseVisitRecursionLimit - template - void cycle(TID) - { - } - bool visit(TypeId ty, const FreeTypeVar&) override { LUAU_ASSERT(!ty->normal); @@ -968,6 +730,9 @@ struct Normalize final : TypeVarVisitor */ TypeId combine(Replacer& replacer, TypeId a, TypeId b) { + if (FFlag::LuauNormalizeCombineEqFix) + b = follow(b); + if (FFlag::LuauNormalizeCombineTableFix && a == b) return a; @@ -986,7 +751,7 @@ struct Normalize final : TypeVarVisitor } else if (auto ttv = getMutable(a)) { - if (FFlag::LuauNormalizeCombineTableFix && !get(follow(b))) + if (FFlag::LuauNormalizeCombineTableFix && !get(FFlag::LuauNormalizeCombineEqFix ? b : follow(b))) return arena.addType(IntersectionTypeVar{{a, b}}); combineIntoTable(replacer, ttv, b); return a; @@ -1009,15 +774,7 @@ std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorRepo (void)clone(ty, arena, state); Normalize n{arena, ice}; - if (FFlag::LuauNormalizeFlagIsConservative) - { - DEPRECATED_visitTypeVar(ty, n); - } - else - { - std::unordered_set seen; - DEPRECATED_visitTypeVar(ty, n, seen); - } + n.traverse(ty); return {ty, !n.limitExceeded}; } @@ -1041,15 +798,7 @@ std::pair normalize(TypePackId tp, TypeArena& arena, InternalE (void)clone(tp, arena, state); Normalize n{arena, ice}; - if (FFlag::LuauNormalizeFlagIsConservative) - { - DEPRECATED_visitTypeVar(tp, n); - } - else - { - std::unordered_set seen; - DEPRECATED_visitTypeVar(tp, n, seen); - } + n.traverse(tp); return {tp, !n.limitExceeded}; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 018d563..c0f677d 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -119,8 +119,7 @@ struct Quantifier final : TypeVarOnceVisitor void quantify(TypeId ty, TypeLevel level) { Quantifier q{level}; - DenseHashSet seen{nullptr}; - DEPRECATED_visitTypeVarOnce(ty, q, seen); + q.traverse(ty); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f90f701..a4a3ec4 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -48,46 +48,6 @@ struct FindCyclicTypes final : TypeVarVisitor cycleTPs.insert(tp); } - // TODO: Clip all the operator()s when we clip FFlagLuauUseVisitRecursionLimit - - template - bool operator()(TypeId ty, const T&) - { - return visit(ty); - } - - bool operator()(TypeId ty, const TableTypeVar& ttv) = delete; - - bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) - { - if (!visited.insert(ty).second) - return false; - - if (ttv.name || ttv.syntheticName) - { - for (TypeId itp : ttv.instantiatedTypeParams) - DEPRECATED_visitTypeVar(itp, *this, seen); - - for (TypePackId itp : ttv.instantiatedTypePackParams) - DEPRECATED_visitTypeVar(itp, *this, seen); - - return exhaustive; - } - - return true; - } - - bool operator()(TypeId, const ClassTypeVar&) - { - return false; - } - - template - bool operator()(TypePackId tp, const T&) - { - return visit(tp); - } - bool visit(TypeId ty) override { return visited.insert(ty).second; @@ -128,7 +88,7 @@ void findCyclicTypes(std::set& cycles, std::set& cycleTPs, T { FindCyclicTypes fct; fct.exhaustive = exhaustive; - DEPRECATED_visitTypeVar(ty, fct); + fct.traverse(ty); cycles = std::move(fct.cycles); cycleTPs = std::move(fct.cycleTPs); diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index 673b002..0c89d13 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -85,4 +85,4 @@ void unfreeze(TypeArena& arena) arena.typePacks.unfreeze(); } -} +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 208b3f2..11813c7 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -3,6 +3,7 @@ #include "Luau/Clone.h" #include "Luau/Common.h" +#include "Luau/Instantiation.h" #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" #include "Luau/Parser.h" @@ -10,13 +11,13 @@ #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" -#include "Luau/TopoSortStatements.h" -#include "Luau/TypePack.h" -#include "Luau/ToString.h" -#include "Luau/TypeUtils.h" -#include "Luau/ToString.h" -#include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" +#include "Luau/TopoSortStatements.h" +#include "Luau/ToString.h" +#include "Luau/ToString.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/TypeVar.h" #include #include @@ -26,14 +27,11 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) -LUAU_FASTFLAGVARIABLE(LuauUseVisitRecursionLimit, false) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauSeparateTypechecks) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAGVARIABLE(LuauDoNotRelyOnNextBinding, false) LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) -LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) @@ -43,7 +41,6 @@ LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) -LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); @@ -51,6 +48,7 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); +LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) namespace Luau { @@ -305,12 +303,8 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule.reset(new Module()); currentModule->type = module.type; - - if (FFlag::LuauSeparateTypechecks) - { - currentModule->allocator = module.allocator; - currentModule->names = module.names; - } + currentModule->allocator = module.allocator; + currentModule->names = module.names; iceHandler->moduleName = module.name; @@ -338,21 +332,14 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo if (prepareModuleScope) prepareModuleScope(module.name, currentModule->getModuleScope()); - if (FFlag::LuauSeparateTypechecks) - { - try - { - checkBlock(moduleScope, *module.root); - } - catch (const TimeLimitError&) - { - currentModule->timeout = true; - } - } - else + try { checkBlock(moduleScope, *module.root); } + catch (const TimeLimitError&) + { + currentModule->timeout = true; + } if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); @@ -443,7 +430,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) else ice("Unknown AstStat"); - if (FFlag::LuauSeparateTypechecks && finishTime && TimeTrace::getClock() > *finishTime) + if (finishTime && TimeTrace::getClock() > *finishTime) throw TimeLimitError(); } @@ -868,9 +855,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) TypeId right = nullptr; - Location loc = 0 == assign.values.size - ? assign.location - : i < assign.values.size ? assign.values.data[i]->location : assign.values.data[assign.values.size - 1]->location; + Location loc = 0 == assign.values.size ? assign.location + : i < assign.values.size ? assign.values.data[i]->location + : assign.values.data[assign.values.size - 1]->location; if (valueIter != valueEnd) { @@ -1825,7 +1812,7 @@ std::optional TypeChecker::findMetatableEntry(TypeId type, std::string e } std::optional TypeChecker::getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) + const ScopePtr& scope, TypeId type, const std::string& name, const Location& location, bool addErrors) { type = follow(type); @@ -1843,13 +1830,25 @@ std::optional TypeChecker::getIndexTypeFromType( if (TableTypeVar* tableType = getMutableTableType(type)) { - const auto& it = tableType->props.find(name); - if (it != tableType->props.end()) + if (auto it = tableType->props.find(name); it != tableType->props.end()) return it->second.type; else if (auto indexer = tableType->indexer) { - tryUnify(stringType, indexer->indexType, location); - return indexer->indexResultType; + // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. + ErrorVec errors = tryUnify(stringType, indexer->indexType, location); + + if (FFlag::LuauReportErrorsOnIndexerKeyMismatch) + { + if (errors.empty()) + return indexer->indexResultType; + + if (addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; + } + else + return indexer->indexResultType; } else if (tableType->state == TableState::Free) { @@ -1858,8 +1857,7 @@ std::optional TypeChecker::getIndexTypeFromType( return result; } - auto found = findTablePropertyRespectingMeta(type, name, location); - if (found) + if (auto found = findTablePropertyRespectingMeta(type, name, location)) return *found; } else if (const ClassTypeVar* cls = get(type)) @@ -2512,8 +2510,9 @@ TypeId TypeChecker::checkRelationalOperation( if (!matches) { - reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + reportError( + expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } @@ -2522,8 +2521,9 @@ TypeId TypeChecker::checkRelationalOperation( { if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { - reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + reportError( + expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } @@ -3636,10 +3636,7 @@ void TypeChecker::checkArgumentList( } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - if (FFlag::LuauWidenIfSupertypeIsFree2) - state.tryUnify(varPack, tail); - else - state.tryUnify(tail, varPack); + state.tryUnify(varPack, tail); return; } @@ -3707,7 +3704,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } TypePackId retPack; - if (FFlag::LuauLowerBoundsCalculation || !FFlag::LuauWidenIfSupertypeIsFree2) + if (FFlag::LuauLowerBoundsCalculation) { retPack = freshTypePack(scope->level); } @@ -3868,9 +3865,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope Widen widen{¤tModule->internalTypes}; for (; it != endIt; ++it) { - TypeId t = *it; - TypeId widened = widen.substitute(t).value_or(t); // Surely widening is infallible - adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widened}})); + adjustedArgTypes.push_back(addType(ConstrainedTypeVar{level, {widen(*it)}})); } TypePackId adjustedArgPack = addTypePack(TypePack{std::move(adjustedArgTypes), it.tail()}); @@ -3885,14 +3880,11 @@ std::optional> TypeChecker::checkCallOverload(const Scope else { TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - if (FFlag::LuauWidenIfSupertypeIsFree2) - { - UnifierOptions options; - options.isFunctionCall = true; - unify(r, fn, expr.location, options); - } - else - unify(fn, r, expr.location); + + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, expr.location, options); + return {{retPack}}; } } @@ -4375,122 +4367,6 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s } } -bool Instantiation::isDirty(TypeId ty) -{ - if (const FunctionTypeVar* ftv = log->getMutable(ty)) - { - if (ftv->hasNoGenerics) - return false; - - return true; - } - else - { - return false; - } -} - -bool Instantiation::isDirty(TypePackId tp) -{ - return false; -} - -bool Instantiation::ignoreChildren(TypeId ty) -{ - if (log->getMutable(ty)) - return true; - else - return false; -} - -TypeId Instantiation::clean(TypeId ty) -{ - const FunctionTypeVar* ftv = log->getMutable(ty); - LUAU_ASSERT(ftv); - - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - TypeId result = addType(std::move(clone)); - - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; - - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; -} - -TypePackId Instantiation::clean(TypePackId tp) -{ - LUAU_ASSERT(false); - return tp; -} - -bool ReplaceGenerics::ignoreChildren(TypeId ty) -{ - if (const FunctionTypeVar* ftv = log->getMutable(ty)) - { - if (ftv->hasNoGenerics) - return true; - - // We aren't recursing in the case of a generic function which - // binds the same generics. This can happen if, for example, there's recursive types. - // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. - // It's OK to use vector equality here, since we always generate fresh generics - // whenever we quantify, so the vectors overlap if and only if they are equal. - return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); - } - else - { - return false; - } -} - -bool ReplaceGenerics::isDirty(TypeId ty) -{ - if (const TableTypeVar* ttv = log->getMutable(ty)) - return ttv->state == TableState::Generic; - else if (log->getMutable(ty)) - return std::find(generics.begin(), generics.end(), ty) != generics.end(); - else - return false; -} - -bool ReplaceGenerics::isDirty(TypePackId tp) -{ - if (log->getMutable(tp)) - return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); - else - return false; -} - -TypeId ReplaceGenerics::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = log->getMutable(ty)) - { - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - return addType(std::move(clone)); - } - else - return addType(FreeTypeVar{level}); -} - -TypePackId ReplaceGenerics::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - return addTypePack(TypePackVar(FreeTypePack{level})); -} - bool Anyification::isDirty(TypeId ty) { if (ty->persistent) @@ -5295,7 +5171,7 @@ TypeId ApplyTypeFunction::clean(TypeId ty) { TypeId& arg = typeArguments[ty]; if (FFlag::LuauApplyTypeFunctionFix) - { + { LUAU_ASSERT(arg); return arg; } @@ -5309,7 +5185,7 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) { TypePackId& arg = typePackArguments[tp]; if (FFlag::LuauApplyTypeFunctionFix) - { + { LUAU_ASSERT(arg); return arg; } @@ -5837,9 +5713,6 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. auto predicate = [&](TypeId option) -> std::optional { - if (sense && isUndecidable(option)) - return FFlag::LuauWeakEqConstraint ? option : eqP.type; - if (!sense && isNil(eqP.type)) return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 9308e9f..414b05f 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -21,8 +21,6 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) -LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) namespace Luau @@ -149,8 +147,7 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel return; PromoteTypeLevels ptl{log, typeArena, minLevel}; - DenseHashSet seen{nullptr}; - DEPRECATED_visitTypeVarOnce(ty, ptl, seen); + ptl.traverse(ty); } void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) @@ -160,8 +157,7 @@ void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLev return; PromoteTypeLevels ptl{log, typeArena, minLevel}; - DenseHashSet seen{nullptr}; - DEPRECATED_visitTypeVarOnce(tp, ptl, seen); + ptl.traverse(tp); } struct SkipCacheForType final : TypeVarOnceVisitor @@ -172,49 +168,6 @@ struct SkipCacheForType final : TypeVarOnceVisitor { } - // TODO cycle() and operator() can be clipped with FFlagLuauUseVisitRecursionLimit - void cycle(TypeId) override {} - void cycle(TypePackId) override {} - - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const BoundTypeVar& btv) - { - return visit(ty, btv); - } - bool operator()(TypeId ty, const GenericTypeVar& gtv) - { - return visit(ty, gtv); - } - bool operator()(TypeId ty, const TableTypeVar& ttv) - { - return visit(ty, ttv); - } - bool operator()(TypePackId tp, const FreeTypePack& ftp) - { - return visit(tp, ftp); - } - bool operator()(TypePackId tp, const BoundTypePack& ftp) - { - return visit(tp, ftp); - } - bool operator()(TypePackId tp, const GenericTypePack& ftp) - { - return visit(tp, ftp); - } - template - bool operator()(TypeId ty, const T& t) - { - return visit(ty); - } - template - bool operator()(TypePackId tp, const T&) - { - return visit(tp); - } - bool visit(TypeId, const FreeTypeVar&) override { result = true; @@ -341,6 +294,16 @@ bool Widen::ignoreChildren(TypeId ty) return !log->is(ty); } +TypeId Widen::operator()(TypeId ty) +{ + return substitute(ty).value_or(ty); +} + +TypePackId Widen::operator()(TypePackId tp) +{ + return substitute(tp).value_or(tp); +} + static std::optional hasUnificationTooComplex(const ErrorVec& errors) { auto isUnificationTooComplex = [](const TypeError& te) { @@ -475,6 +438,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursFailed) { promoteTypeLevels(log, types, superLevel, subTy); + + Widen widen{types}; log.replace(superTy, BoundTypeVar(widen(subTy))); } @@ -612,9 +577,6 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId std::optional unificationTooComplex; std::optional firstFailedOption; - size_t count = uv->options.size(); - size_t i = 0; - for (TypeId type : uv->options) { Unifier innerState = makeChildUnifier(); @@ -630,60 +592,44 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId failed = true; } - - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2) - { - } - else - { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } - - ++i; - } } // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. - if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter2) - { - auto tryBind = [this, subTy](TypeId superOption) { - superOption = log.follow(superOption); + auto tryBind = [this, subTy](TypeId superOption) { + superOption = log.follow(superOption); - // just skip if the superOption is not free-ish. - auto ttv = log.getMutable(superOption); - if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) - return; + // just skip if the superOption is not free-ish. + auto ttv = log.getMutable(superOption); + if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) + return; - // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype - // test is successful. - if (auto subUnion = get(subTy)) - { - if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) - return; - } - - // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. - // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. - if (log.haveSeen(subTy, superOption)) - { - // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) - log.bindTable(superOption, subTy); - else - log.replace(superOption, *subTy); - } - }; - - if (auto utv = log.getMutable(superTy)) + // If superOption is already present in subTy, do nothing. Nothing new has been learned, but the subtype + // test is successful. + if (auto subUnion = get(subTy)) { - for (TypeId ty : utv) - tryBind(ty); + if (end(subUnion) != std::find(begin(subUnion), end(subUnion), superOption)) + return; } - else - tryBind(superTy); + + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. + // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. + if (log.haveSeen(subTy, superOption)) + { + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); + } + }; + + if (auto utv = log.getMutable(superTy)) + { + for (TypeId ty : utv) + tryBind(ty); } + else + tryBind(superTy); if (unificationTooComplex) reportError(*unificationTooComplex); @@ -883,7 +829,7 @@ bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) auto skipCacheFor = [this](TypeId ty) { SkipCacheForType visitor{sharedState.skipCacheForType, types}; - DEPRECATED_visitTypeVarOnce(ty, visitor, sharedState.seenAny); + visitor.traverse(ty); sharedState.skipCacheForType[ty] = visitor.result; @@ -1088,6 +1034,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (!log.getMutable(superTp)) { + Widen widen{types}; log.replace(superTp, Unifiable::Bound(widen(subTp))); } } @@ -1671,28 +1618,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } -TypeId Unifier::widen(TypeId ty) -{ - if (!FFlag::LuauWidenIfSupertypeIsFree2) - return ty; - - Widen widen{types}; - std::optional result = widen.substitute(ty); - // TODO: what does it mean for substitution to fail to widen? - return result.value_or(ty); -} - -TypePackId Unifier::widen(TypePackId tp) -{ - if (!FFlag::LuauWidenIfSupertypeIsFree2) - return tp; - - Widen widen{types}; - std::optional result = widen.substitute(tp); - // TODO: what does it mean for substitution to fail to widen? - return result.value_or(tp); -} - TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); @@ -1809,10 +1734,7 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) { - if (FFlag::LuauWidenIfSupertypeIsFree2) - tryUnify_(*subProp, freeProp.type); - else - tryUnify_(freeProp.type, *subProp); + tryUnify_(*subProp, freeProp.type); /* * TypeVars are commonly cyclic, so it is entirely possible diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 9f7b2bd..be28282 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Common.h" +#include "Luau/Common.h" #include diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index c053e6b..eaf1991 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,6 +11,8 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) + namespace Luau { @@ -1589,6 +1591,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { return parseFunctionTypeAnnotation(allowPack); } + else if (FFlag::LuauParserFunctionKeywordAsTypeHelp && lexer.current().type == Lexeme::ReservedFunction) + { + Location location = lexer.current().location; + + nextLexeme(); + + return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, + "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " + "...any'"), + {}}; + } else { Location location = lexer.current().location; diff --git a/CMakeLists.txt b/CMakeLists.txt index ea35230..c624a13 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,9 +19,11 @@ if(LUAU_STATIC_CRT) endif() project(Luau LANGUAGES CXX C) +add_library(Luau.Common INTERFACE) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Analysis STATIC) +add_library(Luau.CodeGen STATIC) add_library(Luau.VM STATIC) add_library(isocline STATIC) @@ -48,8 +50,11 @@ endif() include(Sources.cmake) +target_include_directories(Luau.Common INTERFACE Common/include) + target_compile_features(Luau.Ast PUBLIC cxx_std_17) target_include_directories(Luau.Ast PUBLIC Ast/include) +target_link_libraries(Luau.Ast PUBLIC Luau.Common) target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) @@ -59,8 +64,13 @@ target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) +target_compile_features(Luau.CodeGen PRIVATE cxx_std_17) +target_include_directories(Luau.CodeGen PUBLIC CodeGen/include) +target_link_libraries(Luau.CodeGen PUBLIC Luau.Common) + target_compile_features(Luau.VM PRIVATE cxx_std_11) target_include_directories(Luau.VM PUBLIC VM/include) +target_link_libraries(Luau.VM PUBLIC Luau.Common) target_include_directories(isocline PUBLIC extern/isocline/include) @@ -101,6 +111,7 @@ endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) @@ -120,6 +131,7 @@ endif() if(MSVC) target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis) target_link_options(Luau.Analysis INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Analysis.natvis) + target_link_options(Luau.CodeGen INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/CodeGen.natvis) target_link_options(Luau.VM INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/VM.natvis) endif() @@ -127,6 +139,7 @@ endif() if(MSVC_IDE) target_sources(Luau.Ast PRIVATE tools/natvis/Ast.natvis) target_sources(Luau.Analysis PRIVATE tools/natvis/Analysis.natvis) + target_sources(Luau.CodeGen PRIVATE tools/natvis/CodeGen.natvis) target_sources(Luau.VM PRIVATE tools/natvis/VM.natvis) endif() @@ -154,7 +167,7 @@ endif() if(LUAU_BUILD_TESTS) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) - target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) + target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Conformance PRIVATE extern) diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h new file mode 100644 index 0000000..c5979d3 --- /dev/null +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -0,0 +1,169 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Condition.h" +#include "Luau/Label.h" +#include "Luau/OperandX64.h" +#include "Luau/RegisterX64.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class AssemblyBuilderX64 +{ +public: + explicit AssemblyBuilderX64(bool logText); + ~AssemblyBuilderX64(); + + // Base two operand instructions with 9 opcode selection + void add(OperandX64 lhs, OperandX64 rhs); + void sub(OperandX64 lhs, OperandX64 rhs); + void cmp(OperandX64 lhs, OperandX64 rhs); + void and_(OperandX64 lhs, OperandX64 rhs); + void or_(OperandX64 lhs, OperandX64 rhs); + void xor_(OperandX64 lhs, OperandX64 rhs); + + // Binary shift instructions with special rhs handling + void sal(OperandX64 lhs, OperandX64 rhs); + void sar(OperandX64 lhs, OperandX64 rhs); + void shl(OperandX64 lhs, OperandX64 rhs); + void shr(OperandX64 lhs, OperandX64 rhs); + + // Two operand mov instruction has additional specialized encodings + void mov(OperandX64 lhs, OperandX64 rhs); + void mov64(RegisterX64 lhs, int64_t imm); + + // Base one operand instruction with 2 opcode selection + void div(OperandX64 op); + void idiv(OperandX64 op); + void mul(OperandX64 op); + void neg(OperandX64 op); + void not_(OperandX64 op); + + void test(OperandX64 lhs, OperandX64 rhs); + void lea(OperandX64 lhs, OperandX64 rhs); + + void push(OperandX64 op); + void pop(OperandX64 op); + void ret(); + + // Control flow + void jcc(Condition cond, Label& label); + void jmp(Label& label); + void jmp(OperandX64 op); + + // AVX + void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vaddss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + + void vsqrtpd(OperandX64 dst, OperandX64 src); + void vsqrtps(OperandX64 dst, OperandX64 src); + void vsqrtsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vsqrtss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + + void vmovsd(OperandX64 dst, OperandX64 src); + void vmovsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmovss(OperandX64 dst, OperandX64 src); + void vmovss(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vmovapd(OperandX64 dst, OperandX64 src); + void vmovaps(OperandX64 dst, OperandX64 src); + void vmovupd(OperandX64 dst, OperandX64 src); + void vmovups(OperandX64 dst, OperandX64 src); + + // Run final checks + void finalize(); + + // Places a label at current location and returns it + Label setLabel(); + + // Assigns label position to the current location + void setLabel(Label& label); + + // Constant allocation (uses rip-relative addressing) + OperandX64 i64(int64_t value); + OperandX64 f32(float value); + OperandX64 f64(double value); + OperandX64 f32x4(float x, float y, float z, float w); + + // Resulting data and code that need to be copied over one after the other + // The *end* of 'data' has to be aligned to 16 bytes, this will also align 'code' + std::vector data; + std::vector code; + + std::string text; + +private: + // Instruction archetypes + void placeBinary(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t codeimm8, uint8_t codeimm, uint8_t codeimmImm8, uint8_t code8rev, + uint8_t coderev, uint8_t code8, uint8_t code, uint8_t opreg); + void placeBinaryRegMemAndImm(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code, uint8_t codeImm8, uint8_t opreg); + void placeBinaryRegAndRegMem(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); + void placeBinaryRegMemAndReg(OperandX64 lhs, OperandX64 rhs, uint8_t code8, uint8_t code); + + void placeUnaryModRegMem(const char* name, OperandX64 op, uint8_t code8, uint8_t code, uint8_t opreg); + + void placeShift(const char* name, OperandX64 lhs, OperandX64 rhs, uint8_t opreg); + + void placeJcc(const char* name, Label& label, uint8_t cc); + + void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); + void placeAvx(const char* name, OperandX64 dst, OperandX64 src, uint8_t code, uint8_t coderev, bool setW, uint8_t mode, uint8_t prefix); + void placeAvx(const char* name, OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t code, bool setW, uint8_t mode, uint8_t prefix); + + // Instruction components + void placeRegAndModRegMem(OperandX64 lhs, OperandX64 rhs); + void placeModRegMem(OperandX64 rhs, uint8_t regop); + void placeRex(RegisterX64 op); + void placeRex(OperandX64 op); + void placeRex(RegisterX64 lhs, OperandX64 rhs); + void placeVex(OperandX64 dst, OperandX64 src1, OperandX64 src2, bool setW, uint8_t mode, uint8_t prefix); + void placeImm8Or32(int32_t imm); + void placeImm8(int32_t imm); + void placeImm32(int32_t imm); + void placeImm64(int64_t imm); + void placeLabel(Label& label); + void place(uint8_t byte); + + void commit(); + LUAU_NOINLINE void extend(); + uint32_t getCodeSize(); + + // Data + size_t allocateData(size_t size, size_t align); + + // Logging of assembly in text form (Intel asm with VS disassembly formatting) + LUAU_NOINLINE void log(const char* opcode); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2); + LUAU_NOINLINE void log(const char* opcode, OperandX64 op1, OperandX64 op2, OperandX64 op3); + LUAU_NOINLINE void log(Label label); + LUAU_NOINLINE void log(const char* opcode, Label label); + void log(OperandX64 op); + void logAppend(const char* fmt, ...); + + const char* getSizeName(SizeX64 size); + const char* getRegisterName(RegisterX64 reg); + + uint32_t nextLabel = 1; + std::vector(a) -> a" == toString(idType)); +} + +#if 1 +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") +{ + AstStatBlock* block = parse(R"( + local function a(c) + local function d(e) + return c + end + + return d + end + + local b = a(5) + )"); + + cgb.visit(block); + + ToStringOptions opts; + + ConstraintSolver cs{&arena, cgb.rootScope}; + + cs.run(); + + TypeId idType = requireBinding(cgb.rootScope, "b"); + + CHECK("(a) -> number" == toString(idType, opts)); +} +#endif + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 03f3e15..232ec2d 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -17,6 +17,8 @@ static const char* mainModuleName = "MainModule"; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -249,7 +251,10 @@ std::optional Fixture::getType(const std::string& name) ModulePtr module = getMainModule(); REQUIRE(module); - return lookupName(module->getModuleScope(), name); + if (FFlag::DebugLuauDeferredConstraintResolution) + return linearSearchForBinding(module->getModuleScope2(), name.c_str()); + else + return lookupName(module->getModuleScope(), name); } TypeId Fixture::requireType(const std::string& name) @@ -421,6 +426,12 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); } +ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() + : Fixture() + , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} +{ +} + ModuleName fromString(std::string_view name) { return ModuleName(name); @@ -460,4 +471,27 @@ std::optional lookupName(ScopePtr scope, const std::string& name) return std::nullopt; } +std::optional linearSearchForBinding(Scope2* scope, const char* name) +{ + while (scope) + { + for (const auto& [n, ty] : scope->bindings) + { + if (n.astName() == name) + return ty; + } + + scope = scope->parent; + } + + return std::nullopt; +} + +void dump(const std::vector& constraints) +{ + ToStringOptions opts; + for (const auto& c : constraints) + printf("%s\n", toString(c, opts).c_str()); +} + } // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h index 901f7d4..ffcd4b9 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Config.h" +#include "Luau/ConstraintGraphBuilder.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" @@ -156,6 +157,16 @@ struct BuiltinsFixture : Fixture BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); }; +struct ConstraintGraphBuilderFixture : Fixture +{ + TypeArena arena; + ConstraintGraphBuilder cgb{&arena}; + + ScopedFastFlag forceTheFlag; + + ConstraintGraphBuilderFixture(); +}; + ModuleName fromString(std::string_view name); template @@ -175,9 +186,12 @@ bool isInArena(TypeId t, const TypeArena& arena); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void dump(const std::string& name, TypeId ty); +void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) +std::optional linearSearchForBinding(Scope2* scope, const char* name); + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 33b81be..c055466 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1031,8 +1031,6 @@ return false; TEST_CASE("check_without_builtin_next") { - ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; - TestFileResolver fileResolver; TestConfigResolver configResolver; Frontend frontend(&fileResolver, &configResolver); diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp new file mode 100644 index 0000000..1a323c8 --- /dev/null +++ b/tests/NotNull.test.cpp @@ -0,0 +1,116 @@ +#include "Luau/NotNull.h" + +#include "doctest.h" + +#include +#include +#include + +using Luau::NotNull; + +namespace +{ + +struct Test +{ + int x; + float y; + + static int count; + Test() + { + ++count; + } + + ~Test() + { + --count; + } +}; + +int Test::count = 0; + +} + +int foo(NotNull p) +{ + return *p; +} + +void bar(int* q) +{} + +TEST_SUITE_BEGIN("NotNull"); + +TEST_CASE("basic_stuff") +{ + NotNull a = NotNull{new int(55)}; // Does runtime test + NotNull b{new int(55)}; // As above + // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not good. + + // a = nullptr; // nope + + NotNull d = a; // No runtime test. a is known not to be null. + + int e = *d; + *d = 1; + CHECK(e == 55); + + const NotNull f = d; + *f = 5; // valid: there is a difference between const NotNull and NotNull + // f = a; // nope + + CHECK_EQ(a, d); + CHECK(a != b); + + NotNull g(a); + CHECK(g == a); + + // *g = 123; // nope + + (void)f; + + NotNull t{new Test}; + t->x = 5; + t->y = 3.14f; + + const NotNull u = t; + // u->x = 44; // nope + int v = u->x; + CHECK(v == 5); + + bar(a); + + // a++; // nope + // a[41]; // nope + // a + 41; // nope + // a - 41; // nope + + delete a; + delete b; + delete t; + + CHECK_EQ(0, Test::count); +} + +TEST_CASE("hashable") +{ + std::unordered_map, const char*> map; + NotNull a{new int(8)}; + NotNull b{new int(10)}; + + std::string hello = "hello"; + std::string world = "world"; + + map[a] = hello.c_str(); + map[b] = world.c_str(); + + CHECK_EQ(2, map.size()); + CHECK_EQ(hello.c_str(), map[a]); + CHECK_EQ(world.c_str(), map[b]); + + delete a; + delete b; +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b854bc5..4d9fad1 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -505,7 +505,6 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function id(x) return x end )"); @@ -518,7 +517,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function map(arr, fn) local t = {} @@ -537,7 +535,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(a: number, b: string) end local function test(...: T...): U... @@ -554,7 +551,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") TEST_CASE("toStringNamedFunction_unit_f") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; TypePackVar empty{TypePack{}}; FunctionTypeVar ftv{&empty, &empty, {}, false}; CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); @@ -562,7 +558,6 @@ TEST_CASE("toStringNamedFunction_unit_f") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: a, ...): (a, a, b...) return x, x, ... @@ -577,7 +572,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): ...number return 1, 2, 3 @@ -592,7 +586,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): (string, ...number) return 'a', 1, 2, 3 @@ -607,7 +600,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local f: (number, y: number) -> number )"); @@ -620,7 +612,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_ar TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: T, g: (T) -> U)): () end @@ -636,8 +627,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; - CheckResult result = check(R"( local function test(a, b : string, ... : number) return a end )"); @@ -665,7 +654,6 @@ TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () @@ -682,7 +670,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") { - ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index d90129d..6f4191e 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -470,8 +470,6 @@ caused by: TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") { - ScopedFastFlag luauClassDefinitionModuleInError{"LuauClassDefinitionModuleInError", true}; - CheckResult result = check(R"( local i = ChildClass.New() type ChildClass = { x: number } diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index a3cae3d..4444cd6 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -78,8 +78,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") { - ScopedFastFlag luauDoNotRelyOnNextBinding{"LuauDoNotRelyOnNextBinding", true}; - CheckResult result = check(R"( local foo = "bar" for i, v in foo do diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp index fcf3787..aa0731c 100644 --- a/tests/Variant.test.cpp +++ b/tests/Variant.test.cpp @@ -13,6 +13,25 @@ struct Foo int x = 42; }; +struct Bar +{ + explicit Bar(int x) + : prop(x * 2) + { + ++count; + } + + ~Bar() + { + --count; + } + + int prop; + static int count; +}; + +int Bar::count = 0; + TEST_SUITE_BEGIN("Variant"); TEST_CASE("DefaultCtor") @@ -46,6 +65,29 @@ TEST_CASE("Create") CHECK(get_if(&v3)->x == 3); } +TEST_CASE("Emplace") +{ + { + Variant v1; + + CHECK(0 == Bar::count); + int& i = v1.emplace(5); + CHECK(5 == i); + + CHECK(0 == Bar::count); + + CHECK(get_if(&v1) == &i); + + Bar& bar = v1.emplace(11); + CHECK(22 == bar.prop); + CHECK(1 == Bar::count); + + CHECK(get_if(&v1) == &bar); + } + + CHECK(0 == Bar::count); +} + TEST_CASE("NonPOD") { // initialize (copy) diff --git a/tools/natvis/CodeGen.natvis b/tools/natvis/CodeGen.natvis index 47ff0db..5ff6e14 100644 --- a/tools/natvis/CodeGen.natvis +++ b/tools/natvis/CodeGen.natvis @@ -2,7 +2,7 @@ - noreg + noreg rip al @@ -36,14 +36,20 @@ - {reg} - {mem.size,en} ptr[{mem.base} + {mem.index}*{(int)mem.scale,d} + {disp}] + {base} + {memSize,en} ptr[{base} + {index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{index}*{(int)scale,d} + {imm}] + {memSize,en} ptr[{base} + {imm}] + {memSize,en} ptr[{imm}] {imm} - reg - mem + base imm - disp + memSize + base + index + scale + imm From 7df94088512a660cad2fb2dcce443d72a72b33ef Mon Sep 17 00:00:00 2001 From: Rodactor Date: Thu, 9 Jun 2022 18:31:12 -0700 Subject: [PATCH 14/19] Sync to origin/release/531 --- Analysis/include/Luau/TypePack.h | 13 +- Analysis/include/Luau/TypeVar.h | 11 +- Analysis/include/Luau/Variant.h | 1 + Analysis/src/Autocomplete.cpp | 41 ++-- Analysis/src/BuiltinDefinitions.cpp | 36 +--- Analysis/src/Clone.cpp | 5 - Analysis/src/EmbeddedBuiltinDefinitions.cpp | 8 +- Analysis/src/Instantiation.cpp | 4 - Analysis/src/Quantify.cpp | 35 ---- Analysis/src/Scope.cpp | 5 +- Analysis/src/Substitution.cpp | 1 - Analysis/src/TxnLog.cpp | 56 +++++- Analysis/src/TypeInfer.cpp | 113 ++++++------ Analysis/src/TypePack.cpp | 21 +++ Analysis/src/TypeVar.cpp | 21 +++ Ast/src/Parser.cpp | 22 ++- Compiler/src/BytecodeBuilder.cpp | 10 +- Compiler/src/Compiler.cpp | 195 +++++--------------- VM/src/lapi.cpp | 106 +++++------ VM/src/lbuiltins.cpp | 12 +- VM/src/ldo.cpp | 25 ++- VM/src/lvmexecute.cpp | 19 +- tests/Autocomplete.test.cpp | 37 ++-- tests/Compiler.test.cpp | 13 -- tests/Conformance.test.cpp | 59 +++++- tests/Module.test.cpp | 20 ++ tests/Parser.test.cpp | 17 ++ tests/TypeInfer.aliases.test.cpp | 6 - tests/TypeInfer.loops.test.cpp | 8 - tests/TypeInfer.refinements.test.cpp | 32 +++- tests/TypeInfer.singletons.test.cpp | 2 - tests/TypePack.test.cpp | 16 ++ tests/TypeVar.test.cpp | 20 ++ tests/conformance/apicalls.lua | 11 ++ tests/conformance/pcall.lua | 17 ++ 35 files changed, 544 insertions(+), 474 deletions(-) diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index bbc65f9..c1de242 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -48,13 +48,24 @@ struct TypePackVar explicit TypePackVar(const TypePackVariant& ty); explicit TypePackVar(TypePackVariant&& ty); TypePackVar(TypePackVariant&& ty, bool persistent); + bool operator==(const TypePackVar& rhs) const; + TypePackVar& operator=(TypePackVariant&& tp); + TypePackVar& operator=(const TypePackVar& rhs); + + // Re-assignes the content of the pack, but doesn't change the owning arena and can't make pack persistent. + void reassign(const TypePackVar& rhs) + { + ty = rhs.ty; + } + TypePackVariant ty; + bool persistent = false; - // Pointer to the type arena that allocated this type. + // Pointer to the type arena that allocated this pack. TypeArena* owningArena = nullptr; }; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b3c455c..b59e7c6 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -334,7 +334,6 @@ struct TableTypeVar // We need to know which is which when we stringify types. std::optional syntheticName; - std::map methodDefinitionLocations; // TODO: Remove with FFlag::LuauNoMethodLocations std::vector instantiatedTypeParams; std::vector instantiatedTypePackParams; ModuleName definitionModuleName; @@ -465,6 +464,14 @@ struct TypeVar final { } + // Re-assignes the content of the type, but doesn't change the owning arena and can't make type persistent. + void reassign(const TypeVar& rhs) + { + ty = rhs.ty; + normal = rhs.normal; + documentationSymbol = rhs.documentationSymbol; + } + TypeVariant ty; // Kludge: A persistent TypeVar is one that belongs to the global scope. @@ -486,6 +493,8 @@ struct TypeVar final TypeVar& operator=(const TypeVariant& rhs); TypeVar& operator=(TypeVariant&& rhs); + + TypeVar& operator=(const TypeVar& rhs); }; using SeenSet = std::set>; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index c9c97c9..f637222 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace Luau { diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index b988ed3..a8319c5 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,8 +14,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false); -LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) +LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2) static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -248,7 +247,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ ty = follow(ty); auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); @@ -267,7 +266,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) { if (std::optional firstRetTy = first(ftv->retType)) return checkTypeMatch(typeArena, *firstRetTy, expectedType); @@ -308,7 +307,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } } - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; else return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; @@ -325,7 +324,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) { - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) rootTy = follow(rootTy); ty = follow(ty); @@ -335,7 +334,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId seen.insert(ty); auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); if (indexType == PropIndexType::Key) return false; @@ -368,7 +367,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } }; auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { - LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix2); if (indexType == PropIndexType::Key) return false; @@ -382,10 +381,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId return calledWithSelf == ftv->hasSelf; } - if (std::optional firstArgTy = first(ftv->argTypes)) + // If a call is made with ':', it is invalid if a function has incompatible first argument or no arguments at all + // If a call is made with '.', but it was declared with 'self', it is considered invalid if first argument is compatible + if (calledWithSelf || ftv->hasSelf) { - if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) - return calledWithSelf; + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) + return calledWithSelf; + } } return !calledWithSelf; @@ -427,7 +431,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryKind::Property, type, prop.deprecated, - FFlag::LuauSelfCallAutocompleteFix ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), + FFlag::LuauSelfCallAutocompleteFix2 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), typeCorrect, containingClass, &prop, @@ -462,8 +466,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, - FFlag::LuauFixAutocompleteClassSecurityLevel ? containingClass : cls); + autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); } else if (auto tbl = get(ty)) fillProps(tbl->props); @@ -471,7 +474,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen); - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) { if (auto mtable = get(mt->metatable)) fillMetatableProps(mtable); @@ -537,7 +540,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen; - if (!FFlag::LuauSelfCallAutocompleteFix) + if (!FFlag::LuauSelfCallAutocompleteFix2) innerSeen = seen; if (isNil(*iter)) @@ -563,7 +566,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ++iter; } } - else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix) + else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix2) { if (pt->metatable) { @@ -571,7 +574,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId fillMetatableProps(mtable); } } - else if (FFlag::LuauSelfCallAutocompleteFix && get(get(ty))) + else if (FFlag::LuauSelfCallAutocompleteFix2 && get(get(ty))) { autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen); } @@ -1501,7 +1504,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (!FFlag::LuauSelfCallAutocompleteFix && isString(ty)) + if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty)) return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), finder.ancestry}; else diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 5ed6de6..98737b4 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -179,44 +179,13 @@ void registerBuiltinTypes(TypeChecker& typeChecker) LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen()); LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); - TypeId numberType = typeChecker.numberType; - TypeId booleanType = typeChecker.booleanType; TypeId nilType = typeChecker.nilType; TypeArena& arena = typeChecker.globalTypes; - TypePackId oneNumberPack = arena.addTypePack({numberType}); - TypePackId oneBooleanPack = arena.addTypePack({booleanType}); - - TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); - - TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ - listOfAtLeastOneNumber, - oneNumberPack, - }); - - TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); - TypeId mathLibType = getGlobalBinding(typeChecker, "math"); - if (TableTypeVar* ttv = getMutable(mathLibType)) - { - ttv->props["min"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.min"); - ttv->props["max"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.max"); - } - - TypeId bit32LibType = getGlobalBinding(typeChecker, "bit32"); - if (TableTypeVar* ttv = getMutable(bit32LibType)) - { - ttv->props["band"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.band"); - ttv->props["bor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bor"); - ttv->props["bxor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bxor"); - ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); - } - TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); @@ -231,7 +200,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - // next(t: Table, i: K | nil) -> (K, V) + // next(t: Table, i: K?) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); @@ -241,8 +210,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - // NOTE we are missing 'i: K | nil' argument in the first return types' argument. - // pairs(t: Table) -> ((Table) -> (K, V), Table, nil) + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 19e3383..9180f30 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -9,7 +9,6 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { @@ -241,8 +240,6 @@ void TypeCloner::operator()(const TableTypeVar& t) arg = clone(arg, dest, cloneState); ttv->definitionModuleName = t.definitionModuleName; - if (!FFlag::LuauNoMethodLocations) - ttv->methodDefinitionLocations = t.methodDefinitionLocations; ttv->tags = t.tags; } @@ -406,8 +403,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index f184b74..2407e3e 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -7,7 +7,10 @@ namespace Luau static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( declare bit32: { - -- band, bor, bxor, and btest are declared in C++ + band: (...number) -> number, + bor: (...number) -> number, + bxor: (...number) -> number, + btest: (number, ...number) -> boolean, rrotate: (number, number) -> number, lrotate: (number, number) -> number, lshift: (number, number) -> number, @@ -50,7 +53,8 @@ declare math: { asin: (number) -> number, atan2: (number, number) -> number, - -- min and max are declared in C++. + min: (number, ...number) -> number, + max: (number, ...number) -> number, pi: number, huge: number, diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 4a12027..f145a51 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -4,8 +4,6 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" -LUAU_FASTFLAG(LuauNoMethodLocations) - namespace Luau { @@ -110,8 +108,6 @@ TypeId ReplaceGenerics::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; return addType(std::move(clone)); } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 8f2cc8e..2177537 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -32,41 +32,6 @@ struct Quantifier final : TypeVarOnceVisitor LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); } - void cycle(TypeId) override {} - void cycle(TypePackId) override {} - - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - - template - bool operator()(TypeId ty, const T& t) - { - return true; - } - - template - bool operator()(TypePackId, const T&) - { - return true; - } - - bool operator()(TypeId ty, const ConstrainedTypeVar&) - { - return true; - } - - bool operator()(TypeId ty, const TableTypeVar& ttv) - { - return visit(ty, ttv); - } - - bool operator()(TypePackId tp, const FreeTypePack& ftp) - { - return visit(tp, ftp); - } - /// @return true if outer encloses inner bool subsumes(Scope2* outer, Scope2* inner) { diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 0a362a5..011e28d 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -2,8 +2,6 @@ #include "Luau/Scope.h" -LUAU_FASTFLAG(LuauTwoPassAliasDefinitionFix); - namespace Luau { @@ -19,8 +17,7 @@ Scope::Scope(const ScopePtr& parent, int subLevel) , returnType(parent->returnType) , level(parent->level.incr()) { - if (FFlag::LuauTwoPassAliasDefinitionFix) - level = level.incr(); + level = level.incr(); level.subLevel = subLevel; } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 50c516d..5a22dee 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -10,7 +10,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index e45c0cb..4c6d54e 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) + namespace Luau { @@ -80,18 +82,32 @@ void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) { - TypeArena* owningArena = ty->owningArena; - TypeVar* mtv = asMutable(ty); - *mtv = rep.get()->pending; - mtv->owningArena = owningArena; + if (FFlag::LuauNonCopyableTypeVarFields) + { + asMutable(ty)->reassign(rep.get()->pending); + } + else + { + TypeArena* owningArena = ty->owningArena; + TypeVar* mtv = asMutable(ty); + *mtv = rep.get()->pending; + mtv->owningArena = owningArena; + } } for (auto& [tp, rep] : typePackChanges) { - TypeArena* owningArena = tp->owningArena; - TypePackVar* mpv = asMutable(tp); - *mpv = rep.get()->pending; - mpv->owningArena = owningArena; + if (FFlag::LuauNonCopyableTypeVarFields) + { + asMutable(tp)->reassign(rep.get()->pending); + } + else + { + TypeArena* owningArena = tp->owningArena; + TypePackVar* mpv = asMutable(tp); + *mpv = rep.get()->pending; + mpv->owningArena = owningArena; + } } clear(); @@ -178,8 +194,13 @@ PendingType* TxnLog::queue(TypeId ty) // about this type, we don't want to mutate the parent's state. auto& pending = typeVarChanges[ty]; if (!pending) + { pending = std::make_unique(*ty); + if (FFlag::LuauNonCopyableTypeVarFields) + pending->pending.owningArena = nullptr; + } + return pending.get(); } @@ -191,8 +212,13 @@ PendingTypePack* TxnLog::queue(TypePackId tp) // about this type, we don't want to mutate the parent's state. auto& pending = typePackChanges[tp]; if (!pending) + { pending = std::make_unique(*tp); + if (FFlag::LuauNonCopyableTypeVarFields) + pending->pending.owningArena = nullptr; + } + return pending.get(); } @@ -229,14 +255,24 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { PendingType* newTy = queue(ty); - newTy->pending = replacement; + + if (FFlag::LuauNonCopyableTypeVarFields) + newTy->pending.reassign(replacement); + else + newTy->pending = replacement; + return newTy; } PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { PendingTypePack* newTp = queue(tp); - newTp->pending = replacement; + + if (FFlag::LuauNonCopyableTypeVarFields) + newTp->pending.reassign(replacement); + else + newTp->pending = replacement; + return newTp; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4931bc5..447cd02 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,21 +33,20 @@ LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) +LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); -LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) -LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) +LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) namespace Luau { @@ -358,8 +357,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.cachedUnifyError.clear(); unifierState.skipCacheForType.clear(); - if (FFlag::LuauTwoPassAliasDefinitionFix) - duplicateTypeAliases.clear(); + duplicateTypeAliases.clear(); return std::move(currentModule); } @@ -610,7 +608,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == kParseNameError) + if (typealias->name == kParseNameError) continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -619,7 +617,16 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = *errorRecoveryType(anyType); + if (FFlag::LuauNonCopyableTypeVarFields) + { + TypeVar* mty = asMutable(follow(type)); + mty->reassign(*errorRecoveryType(anyType)); + } + else + { + *asMutable(type) = *errorRecoveryType(anyType); + } + reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -1131,45 +1138,43 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } - if (FFlag::LuauTypecheckIter) + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) { - if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions + // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments + // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + for (TypeId var : varTypes) + unify(anyType, var, forin.location); + + return check(loopScope, *forin.body); + } + + if (const TableTypeVar* iterTable = get(iterTy)) + { + // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer + // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting + if (iterTable->indexer) { - // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions - // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments - // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + if (varTypes.size() > 0) + unify(iterTable->indexer->indexType, varTypes[0], forin.location); + + if (varTypes.size() > 1) + unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); + + for (size_t i = 2; i < varTypes.size(); ++i) + unify(nilType, varTypes[i], forin.location); + } + else + { + TypeId varTy = errorRecoveryType(loopScope); + for (TypeId var : varTypes) - unify(anyType, var, forin.location); + unify(varTy, var, forin.location); - return check(loopScope, *forin.body); + reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); } - else if (const TableTypeVar* iterTable = get(iterTy)) - { - // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer - // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting - if (iterTable->indexer) - { - if (varTypes.size() > 0) - unify(iterTable->indexer->indexType, varTypes[0], forin.location); - if (varTypes.size() > 1) - unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); - - for (size_t i = 2; i < varTypes.size(); ++i) - unify(nilType, varTypes[i], forin.location); - } - else - { - TypeId varTy = errorRecoveryType(loopScope); - - for (TypeId var : varTypes) - unify(varTy, var, forin.location); - - reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); - } - - return check(loopScope, *forin.body); - } + return check(loopScope, *forin.body); } const FunctionTypeVar* iterFunc = get(iterTy); @@ -1334,7 +1339,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. - if (FFlag::LuauTwoPassAliasDefinitionFix && name == kParseNameError) + if (name == kParseNameError) return; std::optional binding; @@ -1353,8 +1358,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; - if (FFlag::LuauTwoPassAliasDefinitionFix) - duplicateTypeAliases.insert({typealias.exported, name}); + duplicateTypeAliases.insert({typealias.exported, name}); } else { @@ -1378,7 +1382,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be // interesting. - if (FFlag::LuauTwoPassAliasDefinitionFix && duplicateTypeAliases.find({typealias.exported, name})) + if (duplicateTypeAliases.find({typealias.exported, name})) return; if (!binding) @@ -1422,9 +1426,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; @@ -1462,9 +1463,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } TypeId& bindingType = bindingsMap[name].type; - bool ok = unify(ty, bindingType, typealias.location); - if (FFlag::LuauTwoPassAliasDefinitionFix && ok) + if (unify(ty, bindingType, typealias.location)) bindingType = ty; if (FFlag::LuauLowerBoundsCalculation) @@ -1532,7 +1532,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) ftv->hasSelf = true; } } @@ -3099,8 +3099,6 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T property.type = freshTy(); property.location = indexName->indexLocation; - if (!FFlag::LuauNoMethodLocations) - ttv->methodDefinitionLocations[name] = funName.location; return property.type; } else if (funName.is()) @@ -4393,8 +4391,6 @@ TypeId Anyification::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; @@ -4705,8 +4701,11 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) if (isNil(ty)) return sense ? std::nullopt : std::optional(ty); - // at this point, anything else is kept if sense is true, or eliminated otherwise - return sense ? std::optional(ty) : std::nullopt; + // at this point, anything else is kept if sense is true, or replaced by nil + if (FFlag::LuauFalsyPredicateReturnsNilInstead) + return sense ? ty : nilType; + else + return sense ? std::optional(ty) : std::nullopt; }; } diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 3050323..82451bd 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) + namespace Luau { @@ -36,6 +38,25 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) return *this; } +TypePackVar& TypePackVar::operator=(const TypePackVar& rhs) +{ + if (FFlag::LuauNonCopyableTypeVarFields) + { + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + } + else + { + ty = rhs.ty; + persistent = rhs.persistent; + owningArena = rhs.owningArena; + } + + return *this; +} + TypePackIterator::TypePackIterator(TypePackId typePack) : TypePackIterator(typePack, TxnLog::empty()) { diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 12cbed9..33bfe25 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,6 +24,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau { @@ -644,6 +645,26 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs) return *this; } +TypeVar& TypeVar::operator=(const TypeVar& rhs) +{ + if (FFlag::LuauNonCopyableTypeVarFields) + { + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + } + else + { + ty = rhs.ty; + persistent = rhs.persistent; + normal = rhs.normal; + owningArena = rhs.owningArena; + } + + return *this; +} + TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index eaf1991..95bce3e 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -12,6 +12,7 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) +LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) namespace Luau { @@ -1118,8 +1119,12 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() { - if (options.allowTypeAnnotations && lexer.current().type == ':') + if (options.allowTypeAnnotations && + (lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow))) { + if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow) + report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); + nextLexeme(); unsigned int oldRecursionCount = recursionCounter; @@ -1350,8 +1355,12 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray paramTypes = copy(params); + bool returnTypeIntroducer = + FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) + if (params.size() == 1 && !varargAnnotation && monomorphic && + (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) { if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; @@ -1359,7 +1368,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {params[0], {}}; } - if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack) + if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; AstArray> paramNames = copy(names); @@ -1373,8 +1382,13 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray' instead of ':'"); + lexer.next(); + } // Users occasionally write '()' as the 'unit' type when they actually want to use 'nil', here we'll try to give a more specific error - if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) + else if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) { report(Location(begin.location, lexer.previousLocation()), "Expected '->' after '()' when parsing function type; did you mean 'nil'?"); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 3aa12d9..597b2f0 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauCompileNestedClosureO2) - namespace Luau { @@ -390,17 +388,15 @@ int32_t BytecodeBuilder::addConstantClosure(uint32_t fid) int16_t BytecodeBuilder::addChildFunction(uint32_t fid) { - if (FFlag::LuauCompileNestedClosureO2) - if (int16_t* cache = protoMap.find(fid)) - return *cache; + if (int16_t* cache = protoMap.find(fid)) + return *cache; uint32_t id = uint32_t(protos.size()); if (id >= kMaxClosureCount) return -1; - if (FFlag::LuauCompileNestedClosureO2) - protoMap[fid] = int16_t(id); + protoMap[fid] = int16_t(id); protos.push_back(fid); return int16_t(id); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index eea56c6..7431cde 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -16,7 +16,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileIter, false) LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) @@ -26,8 +25,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileNestedClosureO2, false) - namespace Luau { @@ -172,30 +169,6 @@ struct Compiler return node->as(); } - bool canInlineFunctionBody(AstStat* stat) - { - if (FFlag::LuauCompileNestedClosureO2) - return true; // TODO: remove this function - - struct CanInlineVisitor : AstVisitor - { - bool result = true; - - bool visit(AstExprFunction* node) override - { - result = false; - - // short-circuit to avoid analyzing nested closure bodies - return false; - } - }; - - CanInlineVisitor canInline; - stat->visit(&canInline); - - return canInline.result; - } - uint32_t compileFunction(AstExprFunction* func) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -268,7 +241,7 @@ struct Compiler f.upvals = upvals; // record information for inlining - if (options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && !getfenvUsed && !setfenvUsed) + if (options.optimizationLevel >= 2 && !func->vararg && !getfenvUsed && !setfenvUsed) { f.canInline = true; f.stackSize = stackSize; @@ -827,110 +800,62 @@ struct Compiler if (pid < 0) CompileError::raise(expr->location, "Exceeded closure limit; simplify the code to compile"); - if (FFlag::LuauCompileNestedClosureO2) - { - captures.clear(); - captures.reserve(f->upvals.size()); - - for (AstLocal* uv : f->upvals) - { - LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - - if (int reg = getLocalReg(uv); reg >= 0) - { - // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals - Variable* ul = variables.find(uv); - bool immutable = !ul || !ul->written; - - captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)}); - } - else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown) - { - // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register - uint8_t reg = allocReg(expr, 1); - compileExprConstant(expr, uc, reg); - - captures.push_back({LCT_VAL, reg}); - } - else - { - LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1); - - // get upvalue from parent frame - // note: this will add uv to the current upvalue list if necessary - uint8_t uid = getUpval(uv); - - captures.push_back({LCT_UPVAL, uid}); - } - } - - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - int16_t shared = -1; - - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - shared = int16_t(cid); - } - - if (shared >= 0) - bytecode.emitAD(LOP_DUPCLOSURE, target, shared); - else - bytecode.emitAD(LOP_NEWCLOSURE, target, pid); - - for (const Capture& c : captures) - bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0); - - return; - } - - bool shared = false; - - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, int16_t(cid)); - shared = true; - } - } - - if (!shared) - bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + // we use a scratch vector to reduce allocations; this is safe since compileExprFunction is not reentrant + captures.clear(); + captures.reserve(f->upvals.size()); for (AstLocal* uv : f->upvals) { LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - Variable* ul = variables.find(uv); - bool immutable = !ul || !ul->written; - - if (uv->functionDepth == expr->functionDepth - 1) + if (int reg = getLocalReg(uv); reg >= 0) { - // get local variable - int reg = getLocalReg(uv); - LUAU_ASSERT(reg >= 0); + // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; - bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), uint8_t(reg), 0); + captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)}); + } + else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown) + { + // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register + uint8_t reg = allocReg(expr, 1); + compileExprConstant(expr, uc, reg); + + captures.push_back({LCT_VAL, reg}); } else { + LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1); + // get upvalue from parent frame // note: this will add uv to the current upvalue list if necessary uint8_t uid = getUpval(uv); - bytecode.emitABC(LOP_CAPTURE, LCT_UPVAL, uid, 0); + captures.push_back({LCT_UPVAL, uid}); } } + + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + int16_t shared = -1; + + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) + { + int32_t cid = bytecode.addConstantClosure(f->id); + + if (cid >= 0 && cid < 32768) + shared = int16_t(cid); + } + + if (shared >= 0) + bytecode.emitAD(LOP_DUPCLOSURE, target, shared); + else + bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + + for (const Capture& c : captures) + bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0); } LuauOpcode getUnaryOp(AstExprUnary::Op op) @@ -2511,30 +2436,6 @@ struct Compiler pushLocal(stat->vars.data[i], uint8_t(vars + i)); } - bool canUnrollForBody(AstStatFor* stat) - { - if (FFlag::LuauCompileNestedClosureO2) - return true; // TODO: remove this function - - struct CanUnrollVisitor : AstVisitor - { - bool result = true; - - bool visit(AstExprFunction* node) override - { - result = false; - - // short-circuit to avoid analyzing nested closure bodies - return false; - } - }; - - CanUnrollVisitor canUnroll; - stat->body->visit(&canUnroll); - - return canUnroll.result; - } - bool tryCompileUnrolledFor(AstStatFor* stat, int thresholdBase, int thresholdMaxBoost) { Constant one = {Constant::Type_Number}; @@ -2560,12 +2461,6 @@ struct Compiler return false; } - if (!canUnrollForBody(stat)) - { - bytecode.addDebugRemark("loop unroll failed: unsupported loop body"); - return false; - } - if (Variable* lv = variables.find(stat->var); lv && lv->written) { bytecode.addDebugRemark("loop unroll failed: mutable loop variable"); @@ -2730,12 +2625,12 @@ struct Compiler uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); LUAU_ASSERT(vars == regs + 3); - // Optimization: when we iterate through pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration - // index These instructions dynamically check if generator is equal to next/inext and bail out They assume that the generator produces 2 - // variables, which is why we allocate at least 2 above (see vars assignment) - LuauOpcode skipOp = FFlag::LuauCompileIter ? LOP_FORGPREP : LOP_JUMP; + LuauOpcode skipOp = LOP_FORGPREP; LuauOpcode loopOp = LOP_FORGLOOP; + // Optimization: when we iterate via pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration index + // These instructions dynamically check if generator is equal to next/inext and bail out + // They assume that the generator produces 2 variables, which is why we allocate at least 2 above (see vars assignment) if (options.optimizationLevel >= 1 && stat->vars.size <= 2) { if (stat->values.size == 1 && stat->values.data[0]->is()) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f86371d..3c3b7bd 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,6 +14,26 @@ #include +/* + * This file contains most implementations of core Lua APIs from lua.h. + * + * These implementations should use api_check macros to verify that stack and type contracts hold; it's the callers + * responsibility to, for example, pass a valid table index to lua_rawgetfield. Generally errors should only be raised + * for conditions caller can't predict such as an out-of-memory error. + * + * The caller is expected to handle stack reservation (by using less than LUA_MINSTACK slots or by calling lua_checkstack). + * To ensure this is handled correctly, use api_incr_top(L) when pushing values to the stack. + * + * Functions that push any collectable objects to the stack *should* call luaC_checkthreadsleep. Failure to do this can result + * in stack references that point to dead objects since sleeping threads don't get rescanned. + * + * Functions that push newly created objects to the stack *should* call luaC_checkGC in addition to luaC_checkthreadsleep. + * Failure to do this can result in OOM since GC may never run. + * + * Note that luaC_checkGC may scan the thread and put it back to sleep; functions that call both before pushing objects must + * therefore call luaC_checkGC before luaC_checkthreadsleep to guarantee the object is pushed to an awake thread. + */ + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -221,15 +241,13 @@ void lua_insert(lua_State* L, int idx) void lua_replace(lua_State* L, int idx) { - /* explicit test for incompatible code */ - if (idx == LUA_ENVIRONINDEX && L->ci == L->base_ci) - luaG_runerror(L, "no calling environment"); api_checknelems(L, 1); luaC_checkthreadsleep(L); StkId o = index2addr(L, idx); api_checkvalidindex(L, o); if (idx == LUA_ENVIRONINDEX) { + api_check(L, L->ci != L->base_ci); Closure* func = curr_func(L); api_check(L, ttistable(L->top - 1)); func->env = hvalue(L->top - 1); @@ -443,9 +461,7 @@ const float* lua_tovector(lua_State* L, int idx) { StkId o = index2addr(L, idx); if (!ttisvector(o)) - { return NULL; - } return vvalue(o); } @@ -460,11 +476,6 @@ int lua_objlen(lua_State* L, int idx) return uvalue(o)->len; case LUA_TTABLE: return luaH_getn(hvalue(o)); - case LUA_TNUMBER: - { - int l = (luaV_tostring(L, o) ? tsvalue(o)->len : 0); - return l; - } default: return 0; } @@ -752,10 +763,9 @@ void lua_setsafeenv(lua_State* L, int objindex, int enabled) int lua_getmetatable(lua_State* L, int objindex) { - const TValue* obj; + luaC_checkthreadsleep(L); Table* mt = NULL; - int res; - obj = index2addr(L, objindex); + const TValue* obj = index2addr(L, objindex); switch (ttype(obj)) { case LUA_TTABLE: @@ -768,21 +778,18 @@ int lua_getmetatable(lua_State* L, int objindex) mt = L->global->mt[ttype(obj)]; break; } - if (mt == NULL) - res = 0; - else + if (mt) { sethvalue(L, L->top, mt); api_incr_top(L); - res = 1; } - return res; + return mt != NULL; } void lua_getfenv(lua_State* L, int idx) { - StkId o; - o = index2addr(L, idx); + luaC_checkthreadsleep(L); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); switch (ttype(o)) { @@ -806,9 +813,8 @@ void lua_getfenv(lua_State* L, int idx) void lua_settable(lua_State* L, int idx) { - StkId t; api_checknelems(L, 2); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); L->top -= 2; /* pop index and value */ @@ -817,22 +823,20 @@ void lua_settable(lua_State* L, int idx) void lua_setfield(lua_State* L, int idx, const char* k) { - StkId t; - TValue key; api_checknelems(L, 1); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); + TValue key; setsvalue(L, &key, luaS_new(L, k)); luaV_settable(L, t, &key, L->top - 1); - L->top--; /* pop value */ + L->top--; return; } void lua_rawset(lua_State* L, int idx) { - StkId t; api_checknelems(L, 2); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -844,9 +848,8 @@ void lua_rawset(lua_State* L, int idx) void lua_rawseti(lua_State* L, int idx, int n) { - StkId o; api_checknelems(L, 1); - o = index2addr(L, idx); + StkId o = index2addr(L, idx); api_check(L, ttistable(o)); if (hvalue(o)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -858,14 +861,11 @@ void lua_rawseti(lua_State* L, int idx, int n) int lua_setmetatable(lua_State* L, int objindex) { - TValue* obj; - Table* mt; api_checknelems(L, 1); - obj = index2addr(L, objindex); + TValue* obj = index2addr(L, objindex); api_checkvalidindex(L, obj); - if (ttisnil(L->top - 1)) - mt = NULL; - else + Table* mt = NULL; + if (!ttisnil(L->top - 1)) { api_check(L, ttistable(L->top - 1)); mt = hvalue(L->top - 1); @@ -900,10 +900,9 @@ int lua_setmetatable(lua_State* L, int objindex) int lua_setfenv(lua_State* L, int idx) { - StkId o; int res = 1; api_checknelems(L, 1); - o = index2addr(L, idx); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); api_check(L, ttistable(L->top - 1)); switch (ttype(o)) @@ -970,24 +969,21 @@ static void f_call(lua_State* L, void* ud) int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) { - struct CallS c; - int status; - ptrdiff_t func; api_checknelems(L, nargs + 1); api_check(L, L->status == 0); checkresults(L, nargs, nresults); - if (errfunc == 0) - func = 0; - else + ptrdiff_t func = 0; + if (errfunc != 0) { StkId o = index2addr(L, errfunc); api_checkvalidindex(L, o); func = savestack(L, o); } + struct CallS c; c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + int status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); adjustresults(L, nresults); return status; @@ -1247,12 +1243,10 @@ const char* lua_getupvalue(lua_State* L, int funcindex, int n) const char* lua_setupvalue(lua_State* L, int funcindex, int n) { - const char* name; - TValue* val; - StkId fi; - fi = index2addr(L, funcindex); api_checknelems(L, 1); - name = aux_upvalue(fi, n, &val); + StkId fi = index2addr(L, funcindex); + TValue* val; + const char* name = aux_upvalue(fi, n, &val); if (name) { L->top--; @@ -1319,14 +1313,16 @@ void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)) void lua_clonefunction(lua_State* L, int idx) { + luaC_checkGC(L); + luaC_checkthreadsleep(L); StkId p = index2addr(L, idx); api_check(L, isLfunction(p)); - - luaC_checkthreadsleep(L); - Closure* cl = clvalue(p); - Closure* newcl = luaF_newLclosure(L, 0, L->gt, cl->l.p); - setclvalue(L, L->top - 1, newcl); + Closure* newcl = luaF_newLclosure(L, cl->nupvalues, L->gt, cl->l.p); + for (int i = 0; i < cl->nupvalues; ++i) + setobj2n(L, &newcl->l.uprefs[i], &cl->l.uprefs[i]); + setclvalue(L, L->top, newcl); + api_incr_top(L); } lua_Callbacks* lua_callbacks(lua_State* L) diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index cc6e560..deaf140 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1018,18 +1018,20 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { -#if LUA_VECTOR_SIZE == 4 - if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) -#else if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) -#endif { double x = nvalue(arg0); double y = nvalue(args); double z = nvalue(args + 1); #if LUA_VECTOR_SIZE == 4 - double w = nvalue(args + 2); + double w = 0.0; + if (nparams >= 4) + { + if (!ttisnumber(args + 2)) + return -1; + w = nvalue(args + 2); + } setvvalue(res, float(x), float(y), float(z), float(w)); #else setvvalue(res, float(x), float(y), float(z), 0.0f); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index a71fce5..0642cb6 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -202,22 +202,29 @@ void luaD_growstack(lua_State* L, int n) CallInfo* luaD_growCI(lua_State* L) { - if (L->size_ci > LUAI_MAXCALLS) /* overflow while handling overflow? */ - luaD_throw(L, LUA_ERRERR); - else - { - luaD_reallocCI(L, 2 * L->size_ci); - if (L->size_ci > LUAI_MAXCALLS) - luaG_runerror(L, "stack overflow"); - } + /* allow extra stack space to handle stack overflow in xpcall */ + const int hardlimit = LUAI_MAXCALLS + (LUAI_MAXCALLS >> 3); + + if (L->size_ci >= hardlimit) + luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ + + int request = L->size_ci * 2; + luaD_reallocCI(L, L->size_ci >= LUAI_MAXCALLS ? hardlimit : request < LUAI_MAXCALLS ? request : LUAI_MAXCALLS); + + if (L->size_ci > LUAI_MAXCALLS) + luaG_runerror(L, "stack overflow"); + return ++L->ci; } void luaD_checkCstack(lua_State* L) { + /* allow extra stack space to handle stack overflow in xpcall */ + const int hardlimit = LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3); + if (L->nCcalls == LUAI_MAXCCALLS) luaG_runerror(L, "C stack overflow"); - else if (L->nCcalls >= (LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3))) + else if (L->nCcalls >= hardlimit) luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index f9fd657..e0a9647 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauIter, false) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -2214,7 +2212,7 @@ static void luau_execute(lua_State* L) { /* will be called during FORGLOOP */ } - else if (FFlag::LuauIter) + else { Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); @@ -2259,17 +2257,6 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); uint32_t aux = *pc; - if (!FFlag::LuauIter) - { - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); - - // note that we need to increment pc by 1 to exit the loop since we need to skip over aux - pc += stop ? 1 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - // fast-path: builtin table iteration if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) { @@ -2362,7 +2349,7 @@ static void luau_execute(lua_State* L) /* ra+1 is already the table */ setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } - else if (FFlag::LuauIter && !ttisfunction(ra)) + else if (!ttisfunction(ra)) { VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); } @@ -2434,7 +2421,7 @@ static void luau_execute(lua_State* L) /* ra+1 is already the table */ setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } - else if (FFlag::LuauIter && !ttisfunction(ra)) + else if (!ttisfunction(ra)) { VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index cc5b31c..dea1ab1 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2764,8 +2764,6 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - check(R"( type tag = "cat" | "dog" local function f(a: tag) end @@ -2838,8 +2836,6 @@ f(@1) TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - check(R"( type tag = "strange\t\"cat\"" | 'nice\t"dog"' local function f(x: tag) end @@ -2873,7 +2869,7 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; loadDefinition(R"( declare class Foo @@ -2913,7 +2909,7 @@ t.@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local t = {} @@ -2929,7 +2925,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end @@ -2961,7 +2957,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local s = "hello" @@ -2980,7 +2976,7 @@ s:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local s = "hello" @@ -2989,17 +2985,15 @@ s.@1 auto ac = autocomplete('1'); - REQUIRE(ac.entryMap.count("byte")); - CHECK(ac.entryMap["byte"].wrongIndexType == true); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == false); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == true); } -TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_non_self_calls_are_fine") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( string.@1 @@ -3013,11 +3007,24 @@ string.@1 CHECK(ac.entryMap["char"].wrongIndexType == false); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + + check(R"( +table.@1 + )"); + + ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("remove")); + CHECK(ac.entryMap["remove"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("getn")); + CHECK(ac.entryMap["getn"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("insert")); + CHECK(ac.entryMap["insert"].wrongIndexType == false); } -TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_self_calls_are_invalid") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( string:@1 diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 2013965..6eee254 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,7 +261,6 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { - ScopedFastFlag sff("LuauCompileIter", true); ScopedFastFlag sff2("LuauCompileIterNoPairs", false); // basic for loop: variable directly refers to internal iteration index (R2) @@ -350,8 +349,6 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { - ScopedFastFlag sff("LuauCompileIter", true); - // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 @@ -2323,8 +2320,6 @@ return result TEST_CASE("DebugLineInfoFor") { - ScopedFastFlag sff("LuauCompileIter", true); - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -4355,8 +4350,6 @@ L1: RETURN R0 0 TEST_CASE("LoopUnrollControlFlow") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - ScopedFastInt sfis[] = { {"LuauCompileLoopUnrollThreshold", 50}, {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, @@ -4475,8 +4468,6 @@ RETURN R0 0 TEST_CASE("LoopUnrollNestedClosure") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // if the body has functions that refer to loop variables, we unroll the loop and use MOVE+CAPTURE for upvalues CHECK_EQ("\n" + compileFunction(R"( for i=1,2 do @@ -4756,8 +4747,6 @@ RETURN R1 1 TEST_CASE("InlineBasicProhibited") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // we can't inline variadic functions CHECK_EQ("\n" + compileFunction(R"( local function foo(...) @@ -4833,8 +4822,6 @@ RETURN R1 1 TEST_CASE("InlineNestedClosures") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // we can inline functions that contain/return functions CHECK_EQ("\n" + compileFunction(R"( local function foo(x) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index f7f2b4a..96a2775 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -741,7 +741,7 @@ TEST_CASE("ApiTables") lua_pop(L, 1); } -TEST_CASE("ApiFunctionCalls") +TEST_CASE("ApiCalls") { StateRef globalState = runConformance("apicalls.lua"); lua_State* L = globalState.get(); @@ -790,6 +790,58 @@ TEST_CASE("ApiFunctionCalls") CHECK(lua_equal(L2, -1, -2) == 1); lua_pop(L2, 2); } + + // lua_clonefunction + fenv + { + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + + // clone & override env + lua_clonefunction(L, -1); + lua_newtable(L); + lua_pushnumber(L, 42); + lua_setfield(L, -2, "pi"); + lua_setfenv(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + } + + // lua_clonefunction + upvalues + { + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 1); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + // two clones + lua_clonefunction(L, -1); + lua_clonefunction(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 2); + lua_pop(L, 1); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 4); + lua_pop(L, 1); + } } static bool endsWith(const std::string& str, const std::string& suffix) @@ -1113,11 +1165,6 @@ TEST_CASE("UserdataApi") TEST_CASE("Iter") { - ScopedFastFlag sffs[] = { - {"LuauCompileIter", true}, - {"LuauIter", true}, - }; - runConformance("iter.lua"); } diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index c7e18ef..89b13ab 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -300,4 +300,24 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } +TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + fileResolver.source["Module/A"] = R"( +export type A = B +type B = A + )"; + + FrontendOptions opts; + opts.retainFullTypeGraphs = false; + CheckResult result = frontend.check("Module/A", opts); + LUAU_REQUIRE_ERRORS(result); + + auto mod = frontend.moduleResolver.getModule("Module/A"); + auto it = mod->getModuleScope()->exportedTypeBindings.find("A"); + REQUIRE(it != mod->getModuleScope()->exportedTypeBindings.end()); + CHECK(toString(it->second.type) == "any"); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 87b1263..878023e 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2622,6 +2622,23 @@ type Z = { a: string | T..., b: number } REQUIRE_EQ(3, result.errors.size()); } +TEST_CASE_FIXTURE(Fixture, "recover_function_return_type_annotations") +{ + ScopedFastFlag sff{"LuauReturnTypeTokenConfusion", true}; + ParseResult result = tryParse(R"( +type Custom = { x: A, y: B, z: C } +type Packed = { x: (A...) -> () } +type F = (number): Custom +type G = Packed<(number): (string, number, boolean)> +local function f(x: number) -> Custom +end + )"); + REQUIRE_EQ(3, result.errors.size()); + CHECK_EQ(result.errors[0].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[1].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[2].getMessage(), "Function return type annotations are written after ':' instead of '->'"); +} + TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation") { ScopedFastFlag sff{"LuauParserFunctionKeywordAsTypeHelp", true}; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 7562a4d..86cc970 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -615,8 +615,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ */ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") { - ScopedFastFlag sff[] = {{"LuauTwoPassAliasDefinitionFix", true}}; - CheckResult result = check(R"( local function x() local y: FutureType = {}::any @@ -633,10 +631,6 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") { - ScopedFastFlag sff[] = { - {"LuauTwoPassAliasDefinitionFix", true}, - }; - CheckResult result = check(R"( local B = {} B.bar = 4 diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 4444cd6..1c6fe1d 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -486,8 +486,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t: {string} = {} local key @@ -506,8 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t: {string} = {} local extra @@ -522,8 +518,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t = {} for k, v in t do @@ -539,8 +533,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t = {} setmetatable(t, { __iter = function(o) return next, o.children end }) diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 6785f27..207b3cf 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -932,6 +932,8 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { + ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; + CheckResult result = check(R"( type T = {tag: "missing", x: nil} | {tag: "exists", x: string} @@ -947,7 +949,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); } TEST_CASE_FIXTURE(Fixture, "discriminate_tag") @@ -1191,7 +1193,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") { - const std::string code = R"( + CheckResult result = check(R"( function f(a) if type(a) == "boolean" then local a1 = a @@ -1201,10 +1203,30 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") local a3 = a end end - )"; - CheckResult result = check(code); + )"); + LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") +{ + ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; + + CheckResult result = check(R"( + local function f(t: {number}) + local x = t[1] + if not x then + local foo = x + else + local bar = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 28}))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 14a5a6a..a90f434 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -139,8 +139,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "bang" diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index c493157..8a5a65f 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -197,4 +197,20 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") CHECK_EQ(4, std::distance(b, e)); } +TEST_CASE("content_reassignment") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; + + TypeArena arena; + + TypePackId futureError = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + asMutable(futureError)->reassign(myError); + + CHECK(get(futureError) != nullptr); + CHECK(!futureError->persistent); + CHECK(futureError->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index bb2d94b..4f8fc50 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -416,4 +416,24 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") CHECK(!isBoolean(&union_)); } +TEST_CASE("content_reassignment") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; + myAny.normal = true; + myAny.documentationSymbol = "@global/any"; + + TypeArena arena; + + TypeId futureAny = arena.addType(FreeTypeVar{TypeLevel{}}); + asMutable(futureAny)->reassign(myAny); + + CHECK(get(futureAny) != nullptr); + CHECK(!futureAny->persistent); + CHECK(futureAny->normal); + CHECK(futureAny->documentationSymbol == "@global/any"); + CHECK(futureAny->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 7a4058b..2741662 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -11,4 +11,15 @@ function create_with_tm(x) return setmetatable({ a = x }, m) end +local gen = 0 +function incuv() + gen += 1 + return gen +end + +pi = 3.1415926 +function getpi() + return pi +end + return('OK') diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index 84ac2ba..969209f 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,21 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) +-- stack overflow needs to happen at the call limit +local calllimit = 20000 +function recurse(n) return n <= 1 and 1 or recurse(n-1) + 1 end + +-- we use one frame for top-level function and one frame is the service frame for coroutines +assert(recurse(calllimit - 2) == calllimit - 2) + +-- note that when calling through pcall, pcall eats one more frame +checkresults({ true, calllimit - 3 }, pcall(recurse, calllimit - 3)) +checkerror(pcall(recurse, calllimit - 2)) + +-- xpcall handler runs in context of the stack frame, but this works just fine since we allow extra stack consumption past stack overflow +checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse("ko") end, calllimit - 2)) + +-- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling" +checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2)) + return 'OK' From 316838f253b64a2e469649beb5dd67c51a948f27 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 16 Jun 2022 17:52:23 -0700 Subject: [PATCH 15/19] Sync to upstream/release/531 --- Analysis/include/Luau/TypePack.h | 13 +- Analysis/include/Luau/TypeVar.h | 11 +- Analysis/include/Luau/Variant.h | 1 + Analysis/src/Autocomplete.cpp | 41 ++-- Analysis/src/BuiltinDefinitions.cpp | 36 +--- Analysis/src/Clone.cpp | 5 - Analysis/src/EmbeddedBuiltinDefinitions.cpp | 8 +- Analysis/src/Instantiation.cpp | 4 - Analysis/src/Quantify.cpp | 35 ---- Analysis/src/Scope.cpp | 5 +- Analysis/src/Substitution.cpp | 1 - Analysis/src/TxnLog.cpp | 56 +++++- Analysis/src/TypeInfer.cpp | 113 ++++++------ Analysis/src/TypePack.cpp | 21 +++ Analysis/src/TypeVar.cpp | 21 +++ Ast/src/Parser.cpp | 22 ++- Compiler/src/BytecodeBuilder.cpp | 10 +- Compiler/src/Compiler.cpp | 195 +++++--------------- VM/src/lapi.cpp | 106 +++++------ VM/src/lbuiltins.cpp | 12 +- VM/src/ldo.cpp | 25 ++- VM/src/lvmexecute.cpp | 19 +- tests/Autocomplete.test.cpp | 37 ++-- tests/Compiler.test.cpp | 13 -- tests/Conformance.test.cpp | 59 +++++- tests/Module.test.cpp | 20 ++ tests/Parser.test.cpp | 17 ++ tests/TypeInfer.aliases.test.cpp | 6 - tests/TypeInfer.loops.test.cpp | 8 - tests/TypeInfer.refinements.test.cpp | 32 +++- tests/TypeInfer.singletons.test.cpp | 2 - tests/TypePack.test.cpp | 16 ++ tests/TypeVar.test.cpp | 20 ++ tests/conformance/apicalls.lua | 11 ++ tests/conformance/pcall.lua | 17 ++ 35 files changed, 544 insertions(+), 474 deletions(-) diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index bbc65f9..c1de242 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -48,13 +48,24 @@ struct TypePackVar explicit TypePackVar(const TypePackVariant& ty); explicit TypePackVar(TypePackVariant&& ty); TypePackVar(TypePackVariant&& ty, bool persistent); + bool operator==(const TypePackVar& rhs) const; + TypePackVar& operator=(TypePackVariant&& tp); + TypePackVar& operator=(const TypePackVar& rhs); + + // Re-assignes the content of the pack, but doesn't change the owning arena and can't make pack persistent. + void reassign(const TypePackVar& rhs) + { + ty = rhs.ty; + } + TypePackVariant ty; + bool persistent = false; - // Pointer to the type arena that allocated this type. + // Pointer to the type arena that allocated this pack. TypeArena* owningArena = nullptr; }; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b3c455c..b59e7c6 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -334,7 +334,6 @@ struct TableTypeVar // We need to know which is which when we stringify types. std::optional syntheticName; - std::map methodDefinitionLocations; // TODO: Remove with FFlag::LuauNoMethodLocations std::vector instantiatedTypeParams; std::vector instantiatedTypePackParams; ModuleName definitionModuleName; @@ -465,6 +464,14 @@ struct TypeVar final { } + // Re-assignes the content of the type, but doesn't change the owning arena and can't make type persistent. + void reassign(const TypeVar& rhs) + { + ty = rhs.ty; + normal = rhs.normal; + documentationSymbol = rhs.documentationSymbol; + } + TypeVariant ty; // Kludge: A persistent TypeVar is one that belongs to the global scope. @@ -486,6 +493,8 @@ struct TypeVar final TypeVar& operator=(const TypeVariant& rhs); TypeVar& operator=(TypeVariant&& rhs); + + TypeVar& operator=(const TypeVar& rhs); }; using SeenSet = std::set>; diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h index c9c97c9..f637222 100644 --- a/Analysis/include/Luau/Variant.h +++ b/Analysis/include/Luau/Variant.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace Luau { diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index b988ed3..a8319c5 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,8 +14,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteClassSecurityLevel, false); -LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) +LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2) static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -248,7 +247,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ ty = follow(ty); auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); @@ -267,7 +266,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) { if (std::optional firstRetTy = first(ftv->retType)) return checkTypeMatch(typeArena, *firstRetTy, expectedType); @@ -308,7 +307,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } } - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; else return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; @@ -325,7 +324,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) { - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) rootTy = follow(rootTy); ty = follow(ty); @@ -335,7 +334,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId seen.insert(ty); auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix2); if (indexType == PropIndexType::Key) return false; @@ -368,7 +367,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } }; auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { - LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix); + LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix2); if (indexType == PropIndexType::Key) return false; @@ -382,10 +381,15 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId return calledWithSelf == ftv->hasSelf; } - if (std::optional firstArgTy = first(ftv->argTypes)) + // If a call is made with ':', it is invalid if a function has incompatible first argument or no arguments at all + // If a call is made with '.', but it was declared with 'self', it is considered invalid if first argument is compatible + if (calledWithSelf || ftv->hasSelf) { - if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) - return calledWithSelf; + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) + return calledWithSelf; + } } return !calledWithSelf; @@ -427,7 +431,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryKind::Property, type, prop.deprecated, - FFlag::LuauSelfCallAutocompleteFix ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), + FFlag::LuauSelfCallAutocompleteFix2 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), typeCorrect, containingClass, &prop, @@ -462,8 +466,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, - FFlag::LuauFixAutocompleteClassSecurityLevel ? containingClass : cls); + autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); } else if (auto tbl = get(ty)) fillProps(tbl->props); @@ -471,7 +474,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen); - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) { if (auto mtable = get(mt->metatable)) fillMetatableProps(mtable); @@ -537,7 +540,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen; - if (!FFlag::LuauSelfCallAutocompleteFix) + if (!FFlag::LuauSelfCallAutocompleteFix2) innerSeen = seen; if (isNil(*iter)) @@ -563,7 +566,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ++iter; } } - else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix) + else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix2) { if (pt->metatable) { @@ -571,7 +574,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId fillMetatableProps(mtable); } } - else if (FFlag::LuauSelfCallAutocompleteFix && get(get(ty))) + else if (FFlag::LuauSelfCallAutocompleteFix2 && get(get(ty))) { autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen); } @@ -1501,7 +1504,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (!FFlag::LuauSelfCallAutocompleteFix && isString(ty)) + if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty)) return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), finder.ancestry}; else diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 5ed6de6..98737b4 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -179,44 +179,13 @@ void registerBuiltinTypes(TypeChecker& typeChecker) LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen()); LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); - TypeId numberType = typeChecker.numberType; - TypeId booleanType = typeChecker.booleanType; TypeId nilType = typeChecker.nilType; TypeArena& arena = typeChecker.globalTypes; - TypePackId oneNumberPack = arena.addTypePack({numberType}); - TypePackId oneBooleanPack = arena.addTypePack({booleanType}); - - TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); - - TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ - listOfAtLeastOneNumber, - oneNumberPack, - }); - - TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); - TypeId mathLibType = getGlobalBinding(typeChecker, "math"); - if (TableTypeVar* ttv = getMutable(mathLibType)) - { - ttv->props["min"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.min"); - ttv->props["max"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.max"); - } - - TypeId bit32LibType = getGlobalBinding(typeChecker, "bit32"); - if (TableTypeVar* ttv = getMutable(bit32LibType)) - { - ttv->props["band"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.band"); - ttv->props["bor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bor"); - ttv->props["bxor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bxor"); - ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); - } - TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); @@ -231,7 +200,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - // next(t: Table, i: K | nil) -> (K, V) + // next(t: Table, i: K?) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); @@ -241,8 +210,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - // NOTE we are missing 'i: K | nil' argument in the first return types' argument. - // pairs(t: Table) -> ((Table) -> (K, V), Table, nil) + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 19e3383..9180f30 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -9,7 +9,6 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { @@ -241,8 +240,6 @@ void TypeCloner::operator()(const TableTypeVar& t) arg = clone(arg, dest, cloneState); ttv->definitionModuleName = t.definitionModuleName; - if (!FFlag::LuauNoMethodLocations) - ttv->methodDefinitionLocations = t.methodDefinitionLocations; ttv->tags = t.tags; } @@ -406,8 +403,6 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index f184b74..2407e3e 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -7,7 +7,10 @@ namespace Luau static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( declare bit32: { - -- band, bor, bxor, and btest are declared in C++ + band: (...number) -> number, + bor: (...number) -> number, + bxor: (...number) -> number, + btest: (number, ...number) -> boolean, rrotate: (number, number) -> number, lrotate: (number, number) -> number, lshift: (number, number) -> number, @@ -50,7 +53,8 @@ declare math: { asin: (number) -> number, atan2: (number, number) -> number, - -- min and max are declared in C++. + min: (number, ...number) -> number, + max: (number, ...number) -> number, pi: number, huge: number, diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 4a12027..f145a51 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -4,8 +4,6 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" -LUAU_FASTFLAG(LuauNoMethodLocations) - namespace Luau { @@ -110,8 +108,6 @@ TypeId ReplaceGenerics::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; return addType(std::move(clone)); } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 8f2cc8e..2177537 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -32,41 +32,6 @@ struct Quantifier final : TypeVarOnceVisitor LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); } - void cycle(TypeId) override {} - void cycle(TypePackId) override {} - - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - - template - bool operator()(TypeId ty, const T& t) - { - return true; - } - - template - bool operator()(TypePackId, const T&) - { - return true; - } - - bool operator()(TypeId ty, const ConstrainedTypeVar&) - { - return true; - } - - bool operator()(TypeId ty, const TableTypeVar& ttv) - { - return visit(ty, ttv); - } - - bool operator()(TypePackId tp, const FreeTypePack& ftp) - { - return visit(tp, ftp); - } - /// @return true if outer encloses inner bool subsumes(Scope2* outer, Scope2* inner) { diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 0a362a5..011e28d 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -2,8 +2,6 @@ #include "Luau/Scope.h" -LUAU_FASTFLAG(LuauTwoPassAliasDefinitionFix); - namespace Luau { @@ -19,8 +17,7 @@ Scope::Scope(const ScopePtr& parent, int subLevel) , returnType(parent->returnType) , level(parent->level.incr()) { - if (FFlag::LuauTwoPassAliasDefinitionFix) - level = level.incr(); + level = level.incr(); level.subLevel = subLevel; } diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 50c516d..5a22dee 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -10,7 +10,6 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAG(LuauNoMethodLocations) namespace Luau { diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index e45c0cb..4c6d54e 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) + namespace Luau { @@ -80,18 +82,32 @@ void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) { - TypeArena* owningArena = ty->owningArena; - TypeVar* mtv = asMutable(ty); - *mtv = rep.get()->pending; - mtv->owningArena = owningArena; + if (FFlag::LuauNonCopyableTypeVarFields) + { + asMutable(ty)->reassign(rep.get()->pending); + } + else + { + TypeArena* owningArena = ty->owningArena; + TypeVar* mtv = asMutable(ty); + *mtv = rep.get()->pending; + mtv->owningArena = owningArena; + } } for (auto& [tp, rep] : typePackChanges) { - TypeArena* owningArena = tp->owningArena; - TypePackVar* mpv = asMutable(tp); - *mpv = rep.get()->pending; - mpv->owningArena = owningArena; + if (FFlag::LuauNonCopyableTypeVarFields) + { + asMutable(tp)->reassign(rep.get()->pending); + } + else + { + TypeArena* owningArena = tp->owningArena; + TypePackVar* mpv = asMutable(tp); + *mpv = rep.get()->pending; + mpv->owningArena = owningArena; + } } clear(); @@ -178,8 +194,13 @@ PendingType* TxnLog::queue(TypeId ty) // about this type, we don't want to mutate the parent's state. auto& pending = typeVarChanges[ty]; if (!pending) + { pending = std::make_unique(*ty); + if (FFlag::LuauNonCopyableTypeVarFields) + pending->pending.owningArena = nullptr; + } + return pending.get(); } @@ -191,8 +212,13 @@ PendingTypePack* TxnLog::queue(TypePackId tp) // about this type, we don't want to mutate the parent's state. auto& pending = typePackChanges[tp]; if (!pending) + { pending = std::make_unique(*tp); + if (FFlag::LuauNonCopyableTypeVarFields) + pending->pending.owningArena = nullptr; + } + return pending.get(); } @@ -229,14 +255,24 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { PendingType* newTy = queue(ty); - newTy->pending = replacement; + + if (FFlag::LuauNonCopyableTypeVarFields) + newTy->pending.reassign(replacement); + else + newTy->pending = replacement; + return newTy; } PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { PendingTypePack* newTp = queue(tp); - newTp->pending = replacement; + + if (FFlag::LuauNonCopyableTypeVarFields) + newTp->pending.reassign(replacement); + else + newTp->pending = replacement; + return newTp; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4931bc5..447cd02 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,21 +33,20 @@ LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) +LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); -LUAU_FASTFLAGVARIABLE(LuauTypecheckIter, false); LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) -LUAU_FASTFLAGVARIABLE(LuauNoMethodLocations, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) +LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) namespace Luau { @@ -358,8 +357,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.cachedUnifyError.clear(); unifierState.skipCacheForType.clear(); - if (FFlag::LuauTwoPassAliasDefinitionFix) - duplicateTypeAliases.clear(); + duplicateTypeAliases.clear(); return std::move(currentModule); } @@ -610,7 +608,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == kParseNameError) + if (typealias->name == kParseNameError) continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -619,7 +617,16 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = *errorRecoveryType(anyType); + if (FFlag::LuauNonCopyableTypeVarFields) + { + TypeVar* mty = asMutable(follow(type)); + mty->reassign(*errorRecoveryType(anyType)); + } + else + { + *asMutable(type) = *errorRecoveryType(anyType); + } + reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -1131,45 +1138,43 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } - if (FFlag::LuauTypecheckIter) + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) { - if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions + // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments + // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + for (TypeId var : varTypes) + unify(anyType, var, forin.location); + + return check(loopScope, *forin.body); + } + + if (const TableTypeVar* iterTable = get(iterTy)) + { + // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer + // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting + if (iterTable->indexer) { - // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions - // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments - // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types + if (varTypes.size() > 0) + unify(iterTable->indexer->indexType, varTypes[0], forin.location); + + if (varTypes.size() > 1) + unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); + + for (size_t i = 2; i < varTypes.size(); ++i) + unify(nilType, varTypes[i], forin.location); + } + else + { + TypeId varTy = errorRecoveryType(loopScope); + for (TypeId var : varTypes) - unify(anyType, var, forin.location); + unify(varTy, var, forin.location); - return check(loopScope, *forin.body); + reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); } - else if (const TableTypeVar* iterTable = get(iterTy)) - { - // TODO: note that this doesn't cleanly handle iteration over mixed tables and tables without an indexer - // this behavior is more or less consistent with what we do for pairs(), but really both are pretty wrong and need revisiting - if (iterTable->indexer) - { - if (varTypes.size() > 0) - unify(iterTable->indexer->indexType, varTypes[0], forin.location); - if (varTypes.size() > 1) - unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); - - for (size_t i = 2; i < varTypes.size(); ++i) - unify(nilType, varTypes[i], forin.location); - } - else - { - TypeId varTy = errorRecoveryType(loopScope); - - for (TypeId var : varTypes) - unify(varTy, var, forin.location); - - reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); - } - - return check(loopScope, *forin.body); - } + return check(loopScope, *forin.body); } const FunctionTypeVar* iterFunc = get(iterTy); @@ -1334,7 +1339,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. - if (FFlag::LuauTwoPassAliasDefinitionFix && name == kParseNameError) + if (name == kParseNameError) return; std::optional binding; @@ -1353,8 +1358,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; - if (FFlag::LuauTwoPassAliasDefinitionFix) - duplicateTypeAliases.insert({typealias.exported, name}); + duplicateTypeAliases.insert({typealias.exported, name}); } else { @@ -1378,7 +1382,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be // interesting. - if (FFlag::LuauTwoPassAliasDefinitionFix && duplicateTypeAliases.find({typealias.exported, name})) + if (duplicateTypeAliases.find({typealias.exported, name})) return; if (!binding) @@ -1422,9 +1426,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; - - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; @@ -1462,9 +1463,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } TypeId& bindingType = bindingsMap[name].type; - bool ok = unify(ty, bindingType, typealias.location); - if (FFlag::LuauTwoPassAliasDefinitionFix && ok) + if (unify(ty, bindingType, typealias.location)) bindingType = ty; if (FFlag::LuauLowerBoundsCalculation) @@ -1532,7 +1532,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - if (FFlag::LuauSelfCallAutocompleteFix) + if (FFlag::LuauSelfCallAutocompleteFix2) ftv->hasSelf = true; } } @@ -3099,8 +3099,6 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T property.type = freshTy(); property.location = indexName->indexLocation; - if (!FFlag::LuauNoMethodLocations) - ttv->methodDefinitionLocations[name] = funName.location; return property.type; } else if (funName.is()) @@ -4393,8 +4391,6 @@ TypeId Anyification::clean(TypeId ty) if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; - if (!FFlag::LuauNoMethodLocations) - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; @@ -4705,8 +4701,11 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) if (isNil(ty)) return sense ? std::nullopt : std::optional(ty); - // at this point, anything else is kept if sense is true, or eliminated otherwise - return sense ? std::optional(ty) : std::nullopt; + // at this point, anything else is kept if sense is true, or replaced by nil + if (FFlag::LuauFalsyPredicateReturnsNilInstead) + return sense ? ty : nilType; + else + return sense ? std::optional(ty) : std::nullopt; }; } diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 3050323..82451bd 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) + namespace Luau { @@ -36,6 +38,25 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) return *this; } +TypePackVar& TypePackVar::operator=(const TypePackVar& rhs) +{ + if (FFlag::LuauNonCopyableTypeVarFields) + { + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + } + else + { + ty = rhs.ty; + persistent = rhs.persistent; + owningArena = rhs.owningArena; + } + + return *this; +} + TypePackIterator::TypePackIterator(TypePackId typePack) : TypePackIterator(typePack, TxnLog::empty()) { diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 12cbed9..33bfe25 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,6 +24,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) +LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau { @@ -644,6 +645,26 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs) return *this; } +TypeVar& TypeVar::operator=(const TypeVar& rhs) +{ + if (FFlag::LuauNonCopyableTypeVarFields) + { + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); + + reassign(rhs); + } + else + { + ty = rhs.ty; + persistent = rhs.persistent; + normal = rhs.normal; + owningArena = rhs.owningArena; + } + + return *this; +} + TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index eaf1991..95bce3e 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -12,6 +12,7 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) +LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) namespace Luau { @@ -1118,8 +1119,12 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() { - if (options.allowTypeAnnotations && lexer.current().type == ':') + if (options.allowTypeAnnotations && + (lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow))) { + if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow) + report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); + nextLexeme(); unsigned int oldRecursionCount = recursionCounter; @@ -1350,8 +1355,12 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray paramTypes = copy(params); + bool returnTypeIntroducer = + FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) + if (params.size() == 1 && !varargAnnotation && monomorphic && + (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) { if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; @@ -1359,7 +1368,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {params[0], {}}; } - if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack) + if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; AstArray> paramNames = copy(names); @@ -1373,8 +1382,13 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray' instead of ':'"); + lexer.next(); + } // Users occasionally write '()' as the 'unit' type when they actually want to use 'nil', here we'll try to give a more specific error - if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) + else if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) { report(Location(begin.location, lexer.previousLocation()), "Expected '->' after '()' when parsing function type; did you mean 'nil'?"); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 3aa12d9..597b2f0 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauCompileNestedClosureO2) - namespace Luau { @@ -390,17 +388,15 @@ int32_t BytecodeBuilder::addConstantClosure(uint32_t fid) int16_t BytecodeBuilder::addChildFunction(uint32_t fid) { - if (FFlag::LuauCompileNestedClosureO2) - if (int16_t* cache = protoMap.find(fid)) - return *cache; + if (int16_t* cache = protoMap.find(fid)) + return *cache; uint32_t id = uint32_t(protos.size()); if (id >= kMaxClosureCount) return -1; - if (FFlag::LuauCompileNestedClosureO2) - protoMap[fid] = int16_t(id); + protoMap[fid] = int16_t(id); protos.push_back(fid); return int16_t(id); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index eea56c6..7431cde 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -16,7 +16,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileIter, false) LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) @@ -26,8 +25,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileNestedClosureO2, false) - namespace Luau { @@ -172,30 +169,6 @@ struct Compiler return node->as(); } - bool canInlineFunctionBody(AstStat* stat) - { - if (FFlag::LuauCompileNestedClosureO2) - return true; // TODO: remove this function - - struct CanInlineVisitor : AstVisitor - { - bool result = true; - - bool visit(AstExprFunction* node) override - { - result = false; - - // short-circuit to avoid analyzing nested closure bodies - return false; - } - }; - - CanInlineVisitor canInline; - stat->visit(&canInline); - - return canInline.result; - } - uint32_t compileFunction(AstExprFunction* func) { LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); @@ -268,7 +241,7 @@ struct Compiler f.upvals = upvals; // record information for inlining - if (options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && !getfenvUsed && !setfenvUsed) + if (options.optimizationLevel >= 2 && !func->vararg && !getfenvUsed && !setfenvUsed) { f.canInline = true; f.stackSize = stackSize; @@ -827,110 +800,62 @@ struct Compiler if (pid < 0) CompileError::raise(expr->location, "Exceeded closure limit; simplify the code to compile"); - if (FFlag::LuauCompileNestedClosureO2) - { - captures.clear(); - captures.reserve(f->upvals.size()); - - for (AstLocal* uv : f->upvals) - { - LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - - if (int reg = getLocalReg(uv); reg >= 0) - { - // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals - Variable* ul = variables.find(uv); - bool immutable = !ul || !ul->written; - - captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)}); - } - else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown) - { - // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register - uint8_t reg = allocReg(expr, 1); - compileExprConstant(expr, uc, reg); - - captures.push_back({LCT_VAL, reg}); - } - else - { - LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1); - - // get upvalue from parent frame - // note: this will add uv to the current upvalue list if necessary - uint8_t uid = getUpval(uv); - - captures.push_back({LCT_UPVAL, uid}); - } - } - - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - int16_t shared = -1; - - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - shared = int16_t(cid); - } - - if (shared >= 0) - bytecode.emitAD(LOP_DUPCLOSURE, target, shared); - else - bytecode.emitAD(LOP_NEWCLOSURE, target, pid); - - for (const Capture& c : captures) - bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0); - - return; - } - - bool shared = false; - - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, int16_t(cid)); - shared = true; - } - } - - if (!shared) - bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + // we use a scratch vector to reduce allocations; this is safe since compileExprFunction is not reentrant + captures.clear(); + captures.reserve(f->upvals.size()); for (AstLocal* uv : f->upvals) { LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - Variable* ul = variables.find(uv); - bool immutable = !ul || !ul->written; - - if (uv->functionDepth == expr->functionDepth - 1) + if (int reg = getLocalReg(uv); reg >= 0) { - // get local variable - int reg = getLocalReg(uv); - LUAU_ASSERT(reg >= 0); + // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; - bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), uint8_t(reg), 0); + captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)}); + } + else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown) + { + // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register + uint8_t reg = allocReg(expr, 1); + compileExprConstant(expr, uc, reg); + + captures.push_back({LCT_VAL, reg}); } else { + LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1); + // get upvalue from parent frame // note: this will add uv to the current upvalue list if necessary uint8_t uid = getUpval(uv); - bytecode.emitABC(LOP_CAPTURE, LCT_UPVAL, uid, 0); + captures.push_back({LCT_UPVAL, uid}); } } + + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + int16_t shared = -1; + + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) + { + int32_t cid = bytecode.addConstantClosure(f->id); + + if (cid >= 0 && cid < 32768) + shared = int16_t(cid); + } + + if (shared >= 0) + bytecode.emitAD(LOP_DUPCLOSURE, target, shared); + else + bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + + for (const Capture& c : captures) + bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0); } LuauOpcode getUnaryOp(AstExprUnary::Op op) @@ -2511,30 +2436,6 @@ struct Compiler pushLocal(stat->vars.data[i], uint8_t(vars + i)); } - bool canUnrollForBody(AstStatFor* stat) - { - if (FFlag::LuauCompileNestedClosureO2) - return true; // TODO: remove this function - - struct CanUnrollVisitor : AstVisitor - { - bool result = true; - - bool visit(AstExprFunction* node) override - { - result = false; - - // short-circuit to avoid analyzing nested closure bodies - return false; - } - }; - - CanUnrollVisitor canUnroll; - stat->body->visit(&canUnroll); - - return canUnroll.result; - } - bool tryCompileUnrolledFor(AstStatFor* stat, int thresholdBase, int thresholdMaxBoost) { Constant one = {Constant::Type_Number}; @@ -2560,12 +2461,6 @@ struct Compiler return false; } - if (!canUnrollForBody(stat)) - { - bytecode.addDebugRemark("loop unroll failed: unsupported loop body"); - return false; - } - if (Variable* lv = variables.find(stat->var); lv && lv->written) { bytecode.addDebugRemark("loop unroll failed: mutable loop variable"); @@ -2730,12 +2625,12 @@ struct Compiler uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); LUAU_ASSERT(vars == regs + 3); - // Optimization: when we iterate through pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration - // index These instructions dynamically check if generator is equal to next/inext and bail out They assume that the generator produces 2 - // variables, which is why we allocate at least 2 above (see vars assignment) - LuauOpcode skipOp = FFlag::LuauCompileIter ? LOP_FORGPREP : LOP_JUMP; + LuauOpcode skipOp = LOP_FORGPREP; LuauOpcode loopOp = LOP_FORGLOOP; + // Optimization: when we iterate via pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration index + // These instructions dynamically check if generator is equal to next/inext and bail out + // They assume that the generator produces 2 variables, which is why we allocate at least 2 above (see vars assignment) if (options.optimizationLevel >= 1 && stat->vars.size <= 2) { if (stat->values.size == 1 && stat->values.data[0]->is()) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f86371d..3c3b7bd 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,6 +14,26 @@ #include +/* + * This file contains most implementations of core Lua APIs from lua.h. + * + * These implementations should use api_check macros to verify that stack and type contracts hold; it's the callers + * responsibility to, for example, pass a valid table index to lua_rawgetfield. Generally errors should only be raised + * for conditions caller can't predict such as an out-of-memory error. + * + * The caller is expected to handle stack reservation (by using less than LUA_MINSTACK slots or by calling lua_checkstack). + * To ensure this is handled correctly, use api_incr_top(L) when pushing values to the stack. + * + * Functions that push any collectable objects to the stack *should* call luaC_checkthreadsleep. Failure to do this can result + * in stack references that point to dead objects since sleeping threads don't get rescanned. + * + * Functions that push newly created objects to the stack *should* call luaC_checkGC in addition to luaC_checkthreadsleep. + * Failure to do this can result in OOM since GC may never run. + * + * Note that luaC_checkGC may scan the thread and put it back to sleep; functions that call both before pushing objects must + * therefore call luaC_checkGC before luaC_checkthreadsleep to guarantee the object is pushed to an awake thread. + */ + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -221,15 +241,13 @@ void lua_insert(lua_State* L, int idx) void lua_replace(lua_State* L, int idx) { - /* explicit test for incompatible code */ - if (idx == LUA_ENVIRONINDEX && L->ci == L->base_ci) - luaG_runerror(L, "no calling environment"); api_checknelems(L, 1); luaC_checkthreadsleep(L); StkId o = index2addr(L, idx); api_checkvalidindex(L, o); if (idx == LUA_ENVIRONINDEX) { + api_check(L, L->ci != L->base_ci); Closure* func = curr_func(L); api_check(L, ttistable(L->top - 1)); func->env = hvalue(L->top - 1); @@ -443,9 +461,7 @@ const float* lua_tovector(lua_State* L, int idx) { StkId o = index2addr(L, idx); if (!ttisvector(o)) - { return NULL; - } return vvalue(o); } @@ -460,11 +476,6 @@ int lua_objlen(lua_State* L, int idx) return uvalue(o)->len; case LUA_TTABLE: return luaH_getn(hvalue(o)); - case LUA_TNUMBER: - { - int l = (luaV_tostring(L, o) ? tsvalue(o)->len : 0); - return l; - } default: return 0; } @@ -752,10 +763,9 @@ void lua_setsafeenv(lua_State* L, int objindex, int enabled) int lua_getmetatable(lua_State* L, int objindex) { - const TValue* obj; + luaC_checkthreadsleep(L); Table* mt = NULL; - int res; - obj = index2addr(L, objindex); + const TValue* obj = index2addr(L, objindex); switch (ttype(obj)) { case LUA_TTABLE: @@ -768,21 +778,18 @@ int lua_getmetatable(lua_State* L, int objindex) mt = L->global->mt[ttype(obj)]; break; } - if (mt == NULL) - res = 0; - else + if (mt) { sethvalue(L, L->top, mt); api_incr_top(L); - res = 1; } - return res; + return mt != NULL; } void lua_getfenv(lua_State* L, int idx) { - StkId o; - o = index2addr(L, idx); + luaC_checkthreadsleep(L); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); switch (ttype(o)) { @@ -806,9 +813,8 @@ void lua_getfenv(lua_State* L, int idx) void lua_settable(lua_State* L, int idx) { - StkId t; api_checknelems(L, 2); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); L->top -= 2; /* pop index and value */ @@ -817,22 +823,20 @@ void lua_settable(lua_State* L, int idx) void lua_setfield(lua_State* L, int idx, const char* k) { - StkId t; - TValue key; api_checknelems(L, 1); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); + TValue key; setsvalue(L, &key, luaS_new(L, k)); luaV_settable(L, t, &key, L->top - 1); - L->top--; /* pop value */ + L->top--; return; } void lua_rawset(lua_State* L, int idx) { - StkId t; api_checknelems(L, 2); - t = index2addr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -844,9 +848,8 @@ void lua_rawset(lua_State* L, int idx) void lua_rawseti(lua_State* L, int idx, int n) { - StkId o; api_checknelems(L, 1); - o = index2addr(L, idx); + StkId o = index2addr(L, idx); api_check(L, ttistable(o)); if (hvalue(o)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -858,14 +861,11 @@ void lua_rawseti(lua_State* L, int idx, int n) int lua_setmetatable(lua_State* L, int objindex) { - TValue* obj; - Table* mt; api_checknelems(L, 1); - obj = index2addr(L, objindex); + TValue* obj = index2addr(L, objindex); api_checkvalidindex(L, obj); - if (ttisnil(L->top - 1)) - mt = NULL; - else + Table* mt = NULL; + if (!ttisnil(L->top - 1)) { api_check(L, ttistable(L->top - 1)); mt = hvalue(L->top - 1); @@ -900,10 +900,9 @@ int lua_setmetatable(lua_State* L, int objindex) int lua_setfenv(lua_State* L, int idx) { - StkId o; int res = 1; api_checknelems(L, 1); - o = index2addr(L, idx); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); api_check(L, ttistable(L->top - 1)); switch (ttype(o)) @@ -970,24 +969,21 @@ static void f_call(lua_State* L, void* ud) int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) { - struct CallS c; - int status; - ptrdiff_t func; api_checknelems(L, nargs + 1); api_check(L, L->status == 0); checkresults(L, nargs, nresults); - if (errfunc == 0) - func = 0; - else + ptrdiff_t func = 0; + if (errfunc != 0) { StkId o = index2addr(L, errfunc); api_checkvalidindex(L, o); func = savestack(L, o); } + struct CallS c; c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + int status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); adjustresults(L, nresults); return status; @@ -1247,12 +1243,10 @@ const char* lua_getupvalue(lua_State* L, int funcindex, int n) const char* lua_setupvalue(lua_State* L, int funcindex, int n) { - const char* name; - TValue* val; - StkId fi; - fi = index2addr(L, funcindex); api_checknelems(L, 1); - name = aux_upvalue(fi, n, &val); + StkId fi = index2addr(L, funcindex); + TValue* val; + const char* name = aux_upvalue(fi, n, &val); if (name) { L->top--; @@ -1319,14 +1313,16 @@ void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)) void lua_clonefunction(lua_State* L, int idx) { + luaC_checkGC(L); + luaC_checkthreadsleep(L); StkId p = index2addr(L, idx); api_check(L, isLfunction(p)); - - luaC_checkthreadsleep(L); - Closure* cl = clvalue(p); - Closure* newcl = luaF_newLclosure(L, 0, L->gt, cl->l.p); - setclvalue(L, L->top - 1, newcl); + Closure* newcl = luaF_newLclosure(L, cl->nupvalues, L->gt, cl->l.p); + for (int i = 0; i < cl->nupvalues; ++i) + setobj2n(L, &newcl->l.uprefs[i], &cl->l.uprefs[i]); + setclvalue(L, L->top, newcl); + api_incr_top(L); } lua_Callbacks* lua_callbacks(lua_State* L) diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index cc6e560..deaf140 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1018,18 +1018,20 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { -#if LUA_VECTOR_SIZE == 4 - if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) -#else if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) -#endif { double x = nvalue(arg0); double y = nvalue(args); double z = nvalue(args + 1); #if LUA_VECTOR_SIZE == 4 - double w = nvalue(args + 2); + double w = 0.0; + if (nparams >= 4) + { + if (!ttisnumber(args + 2)) + return -1; + w = nvalue(args + 2); + } setvvalue(res, float(x), float(y), float(z), float(w)); #else setvvalue(res, float(x), float(y), float(z), 0.0f); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index a71fce5..0642cb6 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -202,22 +202,29 @@ void luaD_growstack(lua_State* L, int n) CallInfo* luaD_growCI(lua_State* L) { - if (L->size_ci > LUAI_MAXCALLS) /* overflow while handling overflow? */ - luaD_throw(L, LUA_ERRERR); - else - { - luaD_reallocCI(L, 2 * L->size_ci); - if (L->size_ci > LUAI_MAXCALLS) - luaG_runerror(L, "stack overflow"); - } + /* allow extra stack space to handle stack overflow in xpcall */ + const int hardlimit = LUAI_MAXCALLS + (LUAI_MAXCALLS >> 3); + + if (L->size_ci >= hardlimit) + luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ + + int request = L->size_ci * 2; + luaD_reallocCI(L, L->size_ci >= LUAI_MAXCALLS ? hardlimit : request < LUAI_MAXCALLS ? request : LUAI_MAXCALLS); + + if (L->size_ci > LUAI_MAXCALLS) + luaG_runerror(L, "stack overflow"); + return ++L->ci; } void luaD_checkCstack(lua_State* L) { + /* allow extra stack space to handle stack overflow in xpcall */ + const int hardlimit = LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3); + if (L->nCcalls == LUAI_MAXCCALLS) luaG_runerror(L, "C stack overflow"); - else if (L->nCcalls >= (LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3))) + else if (L->nCcalls >= hardlimit) luaD_throw(L, LUA_ERRERR); /* error while handling stack error */ } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index f9fd657..e0a9647 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauIter, false) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -2214,7 +2212,7 @@ static void luau_execute(lua_State* L) { /* will be called during FORGLOOP */ } - else if (FFlag::LuauIter) + else { Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); @@ -2259,17 +2257,6 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); uint32_t aux = *pc; - if (!FFlag::LuauIter) - { - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); - - // note that we need to increment pc by 1 to exit the loop since we need to skip over aux - pc += stop ? 1 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - // fast-path: builtin table iteration if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) { @@ -2362,7 +2349,7 @@ static void luau_execute(lua_State* L) /* ra+1 is already the table */ setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } - else if (FFlag::LuauIter && !ttisfunction(ra)) + else if (!ttisfunction(ra)) { VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); } @@ -2434,7 +2421,7 @@ static void luau_execute(lua_State* L) /* ra+1 is already the table */ setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } - else if (FFlag::LuauIter && !ttisfunction(ra)) + else if (!ttisfunction(ra)) { VM_PROTECT(luaG_typeerror(L, ra, "iterate over")); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index cc5b31c..dea1ab1 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2764,8 +2764,6 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - check(R"( type tag = "cat" | "dog" local function f(a: tag) end @@ -2838,8 +2836,6 @@ f(@1) TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_escape") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - check(R"( type tag = "strange\t\"cat\"" | 'nice\t"dog"' local function f(x: tag) end @@ -2873,7 +2869,7 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; loadDefinition(R"( declare class Foo @@ -2913,7 +2909,7 @@ t.@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local t = {} @@ -2929,7 +2925,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end @@ -2961,7 +2957,7 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local s = "hello" @@ -2980,7 +2976,7 @@ s:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( local s = "hello" @@ -2989,17 +2985,15 @@ s.@1 auto ac = autocomplete('1'); - REQUIRE(ac.entryMap.count("byte")); - CHECK(ac.entryMap["byte"].wrongIndexType == true); REQUIRE(ac.entryMap.count("char")); CHECK(ac.entryMap["char"].wrongIndexType == false); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == true); } -TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_non_self_calls_are_fine") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( string.@1 @@ -3013,11 +3007,24 @@ string.@1 CHECK(ac.entryMap["char"].wrongIndexType == false); REQUIRE(ac.entryMap.count("sub")); CHECK(ac.entryMap["sub"].wrongIndexType == false); + + check(R"( +table.@1 + )"); + + ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("remove")); + CHECK(ac.entryMap["remove"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("getn")); + CHECK(ac.entryMap["getn"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("insert")); + CHECK(ac.entryMap["insert"].wrongIndexType == false); } -TEST_CASE_FIXTURE(ACBuiltinsFixture, "string_library_self_calls_are_invalid") +TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") { - ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + ScopedFastFlag selfCallAutocompleteFix2{"LuauSelfCallAutocompleteFix2", true}; check(R"( string:@1 diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 2013965..6eee254 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,7 +261,6 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { - ScopedFastFlag sff("LuauCompileIter", true); ScopedFastFlag sff2("LuauCompileIterNoPairs", false); // basic for loop: variable directly refers to internal iteration index (R2) @@ -350,8 +349,6 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { - ScopedFastFlag sff("LuauCompileIter", true); - // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 @@ -2323,8 +2320,6 @@ return result TEST_CASE("DebugLineInfoFor") { - ScopedFastFlag sff("LuauCompileIter", true); - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -4355,8 +4350,6 @@ L1: RETURN R0 0 TEST_CASE("LoopUnrollControlFlow") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - ScopedFastInt sfis[] = { {"LuauCompileLoopUnrollThreshold", 50}, {"LuauCompileLoopUnrollThresholdMaxBoost", 300}, @@ -4475,8 +4468,6 @@ RETURN R0 0 TEST_CASE("LoopUnrollNestedClosure") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // if the body has functions that refer to loop variables, we unroll the loop and use MOVE+CAPTURE for upvalues CHECK_EQ("\n" + compileFunction(R"( for i=1,2 do @@ -4756,8 +4747,6 @@ RETURN R1 1 TEST_CASE("InlineBasicProhibited") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // we can't inline variadic functions CHECK_EQ("\n" + compileFunction(R"( local function foo(...) @@ -4833,8 +4822,6 @@ RETURN R1 1 TEST_CASE("InlineNestedClosures") { - ScopedFastFlag sff("LuauCompileNestedClosureO2", true); - // we can inline functions that contain/return functions CHECK_EQ("\n" + compileFunction(R"( local function foo(x) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index f7f2b4a..96a2775 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -741,7 +741,7 @@ TEST_CASE("ApiTables") lua_pop(L, 1); } -TEST_CASE("ApiFunctionCalls") +TEST_CASE("ApiCalls") { StateRef globalState = runConformance("apicalls.lua"); lua_State* L = globalState.get(); @@ -790,6 +790,58 @@ TEST_CASE("ApiFunctionCalls") CHECK(lua_equal(L2, -1, -2) == 1); lua_pop(L2, 2); } + + // lua_clonefunction + fenv + { + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "getpi"); + + // clone & override env + lua_clonefunction(L, -1); + lua_newtable(L); + lua_pushnumber(L, 42); + lua_setfield(L, -2, "pi"); + lua_setfenv(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3.1415926); + lua_pop(L, 1); + } + + // lua_clonefunction + upvalues + { + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 1); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "incuv"); + // two clones + lua_clonefunction(L, -1); + lua_clonefunction(L, -2); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 2); + lua_pop(L, 1); + + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 3); + lua_pop(L, 1); + + // this one calls original function again + lua_call(L, 0, 1); + CHECK(lua_tonumber(L, -1) == 4); + lua_pop(L, 1); + } } static bool endsWith(const std::string& str, const std::string& suffix) @@ -1113,11 +1165,6 @@ TEST_CASE("UserdataApi") TEST_CASE("Iter") { - ScopedFastFlag sffs[] = { - {"LuauCompileIter", true}, - {"LuauIter", true}, - }; - runConformance("iter.lua"); } diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index c7e18ef..89b13ab 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -300,4 +300,24 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } +TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + fileResolver.source["Module/A"] = R"( +export type A = B +type B = A + )"; + + FrontendOptions opts; + opts.retainFullTypeGraphs = false; + CheckResult result = frontend.check("Module/A", opts); + LUAU_REQUIRE_ERRORS(result); + + auto mod = frontend.moduleResolver.getModule("Module/A"); + auto it = mod->getModuleScope()->exportedTypeBindings.find("A"); + REQUIRE(it != mod->getModuleScope()->exportedTypeBindings.end()); + CHECK(toString(it->second.type) == "any"); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 87b1263..878023e 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2622,6 +2622,23 @@ type Z = { a: string | T..., b: number } REQUIRE_EQ(3, result.errors.size()); } +TEST_CASE_FIXTURE(Fixture, "recover_function_return_type_annotations") +{ + ScopedFastFlag sff{"LuauReturnTypeTokenConfusion", true}; + ParseResult result = tryParse(R"( +type Custom = { x: A, y: B, z: C } +type Packed = { x: (A...) -> () } +type F = (number): Custom +type G = Packed<(number): (string, number, boolean)> +local function f(x: number) -> Custom +end + )"); + REQUIRE_EQ(3, result.errors.size()); + CHECK_EQ(result.errors[0].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[1].getMessage(), "Return types in function type annotations are written after '->' instead of ':'"); + CHECK_EQ(result.errors[2].getMessage(), "Function return type annotations are written after ':' instead of '->'"); +} + TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation") { ScopedFastFlag sff{"LuauParserFunctionKeywordAsTypeHelp", true}; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 7562a4d..86cc970 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -615,8 +615,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ */ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") { - ScopedFastFlag sff[] = {{"LuauTwoPassAliasDefinitionFix", true}}; - CheckResult result = check(R"( local function x() local y: FutureType = {}::any @@ -633,10 +631,6 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") { - ScopedFastFlag sff[] = { - {"LuauTwoPassAliasDefinitionFix", true}, - }; - CheckResult result = check(R"( local B = {} B.bar = 4 diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 4444cd6..1c6fe1d 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -486,8 +486,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_fail_missing_instantitation_follow") TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t: {string} = {} local key @@ -506,8 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t: {string} = {} local extra @@ -522,8 +518,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t = {} for k, v in t do @@ -539,8 +533,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer") TEST_CASE_FIXTURE(BuiltinsFixture, "loop_iter_iter_metamethod") { - ScopedFastFlag sff{"LuauTypecheckIter", true}; - CheckResult result = check(R"( local t = {} setmetatable(t, { __iter = function(o) return next, o.children end }) diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 6785f27..207b3cf 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -932,6 +932,8 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { + ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; + CheckResult result = check(R"( type T = {tag: "missing", x: nil} | {tag: "exists", x: string} @@ -947,7 +949,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ(R"({| tag: "exists", x: string |} | {| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); } TEST_CASE_FIXTURE(Fixture, "discriminate_tag") @@ -1191,7 +1193,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") { - const std::string code = R"( + CheckResult result = check(R"( function f(a) if type(a) == "boolean" then local a1 = a @@ -1201,10 +1203,30 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") local a3 = a end end - )"; - CheckResult result = check(code); + )"); + LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") +{ + ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; + + CheckResult result = check(R"( + local function f(t: {number}) + local x = t[1] + if not x then + local foo = x + else + local bar = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("number", toString(requireTypeAtPosition({6, 28}))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 14a5a6a..a90f434 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -139,8 +139,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") { - ScopedFastFlag sff{"LuauTwoPassAliasDefinitionFix", true}; - CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "bang" diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index c493157..8a5a65f 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -197,4 +197,20 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") CHECK_EQ(4, std::distance(b, e)); } +TEST_CASE("content_reassignment") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; + + TypeArena arena; + + TypePackId futureError = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); + asMutable(futureError)->reassign(myError); + + CHECK(get(futureError) != nullptr); + CHECK(!futureError->persistent); + CHECK(futureError->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index bb2d94b..4f8fc50 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -416,4 +416,24 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") CHECK(!isBoolean(&union_)); } +TEST_CASE("content_reassignment") +{ + ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; + + TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; + myAny.normal = true; + myAny.documentationSymbol = "@global/any"; + + TypeArena arena; + + TypeId futureAny = arena.addType(FreeTypeVar{TypeLevel{}}); + asMutable(futureAny)->reassign(myAny); + + CHECK(get(futureAny) != nullptr); + CHECK(!futureAny->persistent); + CHECK(futureAny->normal); + CHECK(futureAny->documentationSymbol == "@global/any"); + CHECK(futureAny->owningArena == &arena); +} + TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 7a4058b..2741662 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -11,4 +11,15 @@ function create_with_tm(x) return setmetatable({ a = x }, m) end +local gen = 0 +function incuv() + gen += 1 + return gen +end + +pi = 3.1415926 +function getpi() + return pi +end + return('OK') diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index 84ac2ba..969209f 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,21 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) +-- stack overflow needs to happen at the call limit +local calllimit = 20000 +function recurse(n) return n <= 1 and 1 or recurse(n-1) + 1 end + +-- we use one frame for top-level function and one frame is the service frame for coroutines +assert(recurse(calllimit - 2) == calllimit - 2) + +-- note that when calling through pcall, pcall eats one more frame +checkresults({ true, calllimit - 3 }, pcall(recurse, calllimit - 3)) +checkerror(pcall(recurse, calllimit - 2)) + +-- xpcall handler runs in context of the stack frame, but this works just fine since we allow extra stack consumption past stack overflow +checkresults({ false, "ok" }, xpcall(recurse, function() return string.reverse("ko") end, calllimit - 2)) + +-- however, if xpcall handler itself runs out of extra stack space, we get "error in error handling" +checkresults({ false, "error in error handling" }, xpcall(recurse, function() return recurse(calllimit) end, calllimit - 2)) + return 'OK' From 88b3984711dff98bf51f1a152dcee874f79bd4a0 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 16 Jun 2022 17:54:42 -0700 Subject: [PATCH 16/19] Sync to upstream/release/532 --- Analysis/include/Luau/Constraint.h | 82 ++++ .../include/Luau/ConstraintGraphBuilder.h | 113 ++--- Analysis/include/Luau/ConstraintSolver.h | 103 +++-- .../include/Luau/ConstraintSolverLogger.h | 26 ++ Analysis/include/Luau/Frontend.h | 4 +- Analysis/include/Luau/Module.h | 1 + Analysis/include/Luau/Normalize.h | 1 + Analysis/include/Luau/NotNull.h | 41 +- Analysis/include/Luau/Quantify.h | 3 +- Analysis/include/Luau/RequireTracer.h | 2 +- Analysis/include/Luau/TypeChecker2.h | 13 + Analysis/include/Luau/TypeInfer.h | 41 +- Analysis/include/Luau/TypeVar.h | 37 +- Analysis/include/Luau/Unifier.h | 2 +- Analysis/include/Luau/VisitTypeVar.h | 2 +- Analysis/src/Autocomplete.cpp | 19 +- Analysis/src/BuiltinDefinitions.cpp | 68 +-- Analysis/src/Clone.cpp | 10 +- Analysis/src/Constraint.cpp | 14 + Analysis/src/ConstraintGraphBuilder.cpp | 406 +++++++++++++++--- Analysis/src/ConstraintSolver.cpp | 174 +++++--- Analysis/src/ConstraintSolverLogger.cpp | 139 ++++++ Analysis/src/Frontend.cpp | 34 +- Analysis/src/Instantiation.cpp | 2 +- Analysis/src/Linter.cpp | 4 +- Analysis/src/Normalize.cpp | 46 +- Analysis/src/Quantify.cpp | 153 ++++++- Analysis/src/RequireTracer.cpp | 14 +- Analysis/src/Substitution.cpp | 4 +- Analysis/src/ToDot.cpp | 2 +- Analysis/src/ToString.cpp | 94 +++- Analysis/src/TypeAttach.cpp | 7 +- Analysis/src/TypeChecker2.cpp | 160 +++++++ Analysis/src/TypeInfer.cpp | 253 ++++++----- Analysis/src/TypeUtils.cpp | 2 +- Analysis/src/TypeVar.cpp | 41 +- Analysis/src/Unifier.cpp | 13 +- Compiler/src/BytecodeBuilder.cpp | 14 + Compiler/src/Compiler.cpp | 76 +++- Sources.cmake | 8 +- VM/src/lobject.h | 2 +- VM/src/ltable.cpp | 8 +- VM/src/ltm.cpp | 4 +- VM/src/ltm.h | 4 +- tests/Autocomplete.test.cpp | 2 +- tests/Compiler.test.cpp | 252 ++++++++++- tests/ConstraintGraphBuilder.test.cpp | 61 ++- tests/Fixture.cpp | 3 +- tests/Frontend.test.cpp | 8 +- tests/Module.test.cpp | 2 +- tests/NonstrictMode.test.cpp | 53 ++- tests/Normalize.test.cpp | 42 +- tests/NotNull.test.cpp | 53 ++- tests/ToString.test.cpp | 36 ++ tests/TypeInfer.annotations.test.cpp | 2 +- tests/TypeInfer.functions.test.cpp | 60 ++- tests/TypeInfer.generics.test.cpp | 6 +- tests/TypeInfer.operators.test.cpp | 4 +- tests/TypeInfer.provisional.test.cpp | 5 +- tests/TypeInfer.refinements.test.cpp | 6 +- tests/TypeInfer.tables.test.cpp | 11 +- tests/TypeInfer.test.cpp | 89 +--- tests/TypeInfer.typePacks.cpp | 14 +- tools/natvis/Analysis.natvis | 66 +-- 64 files changed, 2276 insertions(+), 745 deletions(-) create mode 100644 Analysis/include/Luau/Constraint.h create mode 100644 Analysis/include/Luau/ConstraintSolverLogger.h create mode 100644 Analysis/include/Luau/TypeChecker2.h create mode 100644 Analysis/src/Constraint.cpp create mode 100644 Analysis/src/ConstraintSolverLogger.cpp create mode 100644 Analysis/src/TypeChecker2.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h new file mode 100644 index 0000000..c62166e --- /dev/null +++ b/Analysis/include/Luau/Constraint.h @@ -0,0 +1,82 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/NotNull.h" +#include "Luau/Variant.h" + +#include +#include + +namespace Luau +{ + +struct Scope2; +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +// subType <: superType +struct SubtypeConstraint +{ + TypeId subType; + TypeId superType; +}; + +// subPack <: superPack +struct PackSubtypeConstraint +{ + TypePackId subPack; + TypePackId superPack; +}; + +// subType ~ gen superType +struct GeneralizationConstraint +{ + TypeId generalizedType; + TypeId sourceType; + Scope2* scope; +}; + +// subType ~ inst superType +struct InstantiationConstraint +{ + TypeId subType; + TypeId superType; +}; + +using ConstraintV = Variant; +using ConstraintPtr = std::unique_ptr; + +struct Constraint +{ + Constraint(ConstraintV&& c, Location location); + + Constraint(const Constraint&) = delete; + Constraint& operator=(const Constraint&) = delete; + + ConstraintV c; + Location location; + std::vector> dependencies; +}; + +inline Constraint& asMutable(const Constraint& c) +{ + return const_cast(c); +} + +template +T* getMutable(Constraint& c) +{ + return ::Luau::get_if(&c.c); +} + +template +const T* get(const Constraint& c) +{ + return getMutable(asMutable(c)); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 4234f2f..da774a2 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -4,9 +4,12 @@ #include #include +#include #include "Luau/Ast.h" +#include "Luau/Constraint.h" #include "Luau/Module.h" +#include "Luau/NotNull.h" #include "Luau/Symbol.h" #include "Luau/TypeVar.h" #include "Luau/Variant.h" @@ -14,69 +17,6 @@ namespace Luau { -struct Scope2; - -// subType <: superType -struct SubtypeConstraint -{ - TypeId subType; - TypeId superType; -}; - -// subPack <: superPack -struct PackSubtypeConstraint -{ - TypePackId subPack; - TypePackId superPack; -}; - -// subType ~ gen superType -struct GeneralizationConstraint -{ - TypeId subType; - TypeId superType; - Scope2* scope; -}; - -// subType ~ inst superType -struct InstantiationConstraint -{ - TypeId subType; - TypeId superType; -}; - -using ConstraintV = Variant; -using ConstraintPtr = std::unique_ptr; - -struct Constraint -{ - Constraint(ConstraintV&& c); - Constraint(ConstraintV&& c, std::vector dependencies); - - Constraint(const Constraint&) = delete; - Constraint& operator=(const Constraint&) = delete; - - ConstraintV c; - std::vector dependencies; -}; - -inline Constraint& asMutable(const Constraint& c) -{ - return const_cast(c); -} - -template -T* getMutable(Constraint& c) -{ - return ::Luau::get_if(&c.c); -} - -template -const T* get(const Constraint& c) -{ - return getMutable(asMutable(c)); -} - struct Scope2 { // The parent scope of this scope. Null if there is no parent (i.e. this @@ -102,6 +42,11 @@ struct ConstraintGraphBuilder TypeArena* const arena; // The root scope of the module we're generating constraints for. Scope2* rootScope; + // A mapping of AST node to TypeId. + DenseHashMap astTypes{nullptr}; + // A mapping of AST node to TypePackId. + DenseHashMap astTypePacks{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; explicit ConstraintGraphBuilder(TypeArena* arena); @@ -128,8 +73,9 @@ struct ConstraintGraphBuilder * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param cv the constraint variant to add. + * @param location the location to attribute to the constraint. */ - void addConstraint(Scope2* scope, ConstraintV cv); + void addConstraint(Scope2* scope, ConstraintV cv, Location location); /** * Adds a constraint to a given scope. @@ -148,15 +94,48 @@ struct ConstraintGraphBuilder void visit(Scope2* scope, AstStat* stat); void visit(Scope2* scope, AstStatBlock* block); void visit(Scope2* scope, AstStatLocal* local); - void visit(Scope2* scope, AstStatLocalFunction* local); - void visit(Scope2* scope, AstStatReturn* local); + void visit(Scope2* scope, AstStatLocalFunction* function); + void visit(Scope2* scope, AstStatFunction* function); + void visit(Scope2* scope, AstStatReturn* ret); + void visit(Scope2* scope, AstStatAssign* assign); + void visit(Scope2* scope, AstStatIf* ifStatement); + + TypePackId checkExprList(Scope2* scope, const AstArray& exprs); TypePackId checkPack(Scope2* scope, AstArray exprs); TypePackId checkPack(Scope2* scope, AstExpr* expr); + /** + * Checks an expression that is expected to evaluate to one type. + * @param scope the scope the expression is contained within. + * @param expr the expression to check. + * @return the type of the expression. + */ TypeId check(Scope2* scope, AstExpr* expr); + + TypeId checkExprTable(Scope2* scope, AstExprTable* expr); + TypeId check(Scope2* scope, AstExprIndexName* indexName); + + std::pair checkFunctionSignature(Scope2* parent, AstExprFunction* fn); + + /** + * Checks the body of a function expression. + * @param scope the interior scope of the body of the function. + * @param fn the function expression to check. + */ + void checkFunctionBody(Scope2* scope, AstExprFunction* fn); }; -std::vector collectConstraints(Scope2* rootScope); +/** + * Collects a vector of borrowed constraints from the scope and all its child + * scopes. It is important to only call this function when you're done adding + * constraints to the scope or its descendants, lest the borrowed pointers + * become invalid due to a container reallocation. + * @param rootScope the root scope of the scope graph to collect constraints + * from. + * @return a list of pointers to constraints contained within the scope graph. + * None of these pointers should be null. + */ +std::vector> collectConstraints(Scope2* rootScope); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 85006e6..7e6d446 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -4,7 +4,8 @@ #include "Luau/Error.h" #include "Luau/Variant.h" -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Constraint.h" +#include "Luau/ConstraintSolverLogger.h" #include "Luau/TypeVar.h" #include @@ -20,39 +21,81 @@ struct ConstraintSolver { TypeArena* arena; InternalErrorReporter iceReporter; - // The entire set of constraints that the solver is trying to resolve. - std::vector constraints; + // The entire set of constraints that the solver is trying to resolve. It + // is important to not add elements to this vector, lest the underlying + // storage that we retain pointers to be mutated underneath us. + const std::vector> constraints; Scope2* rootScope; - std::vector errors; // This includes every constraint that has not been fully solved. // A constraint can be both blocked and unsolved, for instance. - std::unordered_set unsolvedConstraints; + std::vector> unsolvedConstraints; // A mapping of constraint pointer to how many things the constraint is // blocked on. Can be empty or 0 for constraints that are not blocked on // anything. - std::unordered_map blockedConstraints; + std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map> blocked; + std::unordered_map>> blocked; + + ConstraintSolverLogger logger; explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope); /** * Attempts to dispatch all pending constraints and reach a type solution - * that satisfies all of the constraints, recording any errors that are - * encountered. + * that satisfies all of the constraints. **/ void run(); bool done(); - bool tryDispatch(const Constraint* c); - bool tryDispatch(const SubtypeConstraint& c); - bool tryDispatch(const PackSubtypeConstraint& c); - bool tryDispatch(const GeneralizationConstraint& c); - bool tryDispatch(const InstantiationConstraint& c, const Constraint* constraint); + bool tryDispatch(NotNull c, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + void block(NotNull target, NotNull constraint); + /** + * Block a constraint on the resolution of a TypeVar. + * @returns false always. This is just to allow tryDispatch to return the result of block() + */ + bool block(TypeId target, NotNull constraint); + bool block(TypePackId target, NotNull constraint); + + void unblock(NotNull progressed); + void unblock(TypeId progressed); + void unblock(TypePackId progressed); + + /** + * @returns true if the TypeId is in a blocked state. + */ + bool isBlocked(TypeId ty); + + /** + * Returns whether the constraint is blocked on anything. + * @param constraint the constraint to check. + */ + bool isBlocked(NotNull constraint); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subType the sub-type to unify. + * @param superType the super-type to unify. + */ + void unify(TypeId subType, TypeId superType, Location location); + + /** + * Creates a new Unifier and performs a single unification operation. Commits + * the result. + * @param subPack the sub-type pack to unify. + * @param superPack the super-type pack to unify. + */ + void unify(TypePackId subPack, TypePackId superPack, Location location); + +private: /** * Marks a constraint as being blocked on a type or type pack. The constraint * solver will not attempt to dispatch blocked constraints until their @@ -60,10 +103,7 @@ struct ConstraintSolver * @param target the type or type pack pointer that the constraint is blocked on. * @param constraint the constraint to block. **/ - void block_(BlockedConstraintId target, const Constraint* constraint); - void block(const Constraint* target, const Constraint* constraint); - void block(TypeId target, const Constraint* constraint); - void block(TypePackId target, const Constraint* constraint); + void block_(BlockedConstraintId target, NotNull constraint); /** * Informs the solver that progress has been made on a type or type pack. The @@ -72,33 +112,6 @@ struct ConstraintSolver * @param progressed the type or type pack pointer that has progressed. **/ void unblock_(BlockedConstraintId progressed); - void unblock(const Constraint* progressed); - void unblock(TypeId progressed); - void unblock(TypePackId progressed); - - /** - * Returns whether the constraint is blocked on anything. - * @param constraint the constraint to check. - */ - bool isBlocked(const Constraint* constraint); - - void reportErrors(const std::vector& errors); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result and reports errors if necessary. - * @param subType the sub-type to unify. - * @param superType the super-type to unify. - */ - void unify(TypeId subType, TypeId superType); - - /** - * Creates a new Unifier and performs a single unification operation. Commits - * the result and reports errors if necessary. - * @param subPack the sub-type pack to unify. - * @param superPack the super-type pack to unify. - */ - void unify(TypePackId subPack, TypePackId superPack); }; void dump(Scope2* rootScope, struct ToStringOptions& opts); diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h new file mode 100644 index 0000000..2b195d7 --- /dev/null +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -0,0 +1,26 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/ToString.h" + +#include +#include +#include + +namespace Luau +{ + +struct ConstraintSolverLogger +{ + std::string compileOutput(); + void captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints); + void prepareStepSnapshot(const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints); + void commitPreparedStepSnapshot(); + +private: + std::vector snapshots; + std::optional preparedSnapshot; + ToStringOptions opts; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 58be0ff..f4226cc 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -66,7 +66,7 @@ struct SourceNode } ModuleName name; - std::unordered_set requires; + std::unordered_set requireSet; std::vector> requireLocations; bool dirtySourceModule = true; bool dirtyModule = true; @@ -186,7 +186,7 @@ public: std::unordered_map sourceNodes; std::unordered_map sourceModules; - std::unordered_map requires; + std::unordered_map requireTrace; Stats stats = {}; }; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index f6e077d..e979b3f 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -69,6 +69,7 @@ struct Module std::vector>> scope2s; // never empty DenseHashMap astTypes{nullptr}; + DenseHashMap astTypePacks{nullptr}; DenseHashMap astExpectedTypes{nullptr}; DenseHashMap astOriginalCallTypes{nullptr}; DenseHashMap astOverloadResolvedTypes{nullptr}; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 262b54b..d4c7698 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -10,6 +10,7 @@ namespace Luau struct InternalErrorReporter; bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice); +bool isSubtype(TypePackId superTy, TypePackId subTy, InternalErrorReporter& ice); std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h index 3d05fde..f6043e9 100644 --- a/Analysis/include/Luau/NotNull.h +++ b/Analysis/include/Luau/NotNull.h @@ -9,20 +9,22 @@ namespace Luau { /** A non-owning, non-null pointer to a T. - * - * A NotNull is notionally identical to a T* with the added restriction that it - * can never store nullptr. - * - * The sole conversion rule from T* to NotNull is the single-argument constructor, which - * is intentionally marked explicit. This constructor performs a runtime test to verify - * that the passed pointer is never nullptr. - * - * Pointer arithmetic, increment, decrement, and array indexing are all forbidden. - * - * An implicit coersion from NotNull to T* is afforded, as are the pointer indirection and member - * access operators. (*p and p->prop) * - * The explicit delete statement is permitted on a NotNull through this implicit conversion. + * A NotNull is notionally identical to a T* with the added restriction that + * it can never store nullptr. + * + * The sole conversion rule from T* to NotNull is the single-argument + * constructor, which is intentionally marked explicit. This constructor + * performs a runtime test to verify that the passed pointer is never nullptr. + * + * Pointer arithmetic, increment, decrement, and array indexing are all + * forbidden. + * + * An implicit coersion from NotNull to T* is afforded, as are the pointer + * indirection and member access operators. (*p and p->prop) + * + * The explicit delete statement is permitted (but not recommended) on a + * NotNull through this implicit conversion. */ template struct NotNull @@ -36,6 +38,11 @@ struct NotNull explicit NotNull(std::nullptr_t) = delete; void operator=(std::nullptr_t) = delete; + template + NotNull(NotNull other) + : ptr(other.get()) + {} + operator T*() const noexcept { return ptr; @@ -56,6 +63,12 @@ struct NotNull T& operator+(int) = delete; T& operator-(int) = delete; + T* get() const noexcept + { + return ptr; + } + +private: T* ptr; }; @@ -68,7 +81,7 @@ template struct hash> { size_t operator()(const Luau::NotNull& p) const { - return std::hash()(p.ptr); + return std::hash()(p.get()); } }; diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index b32d684..f46f0cb 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -6,9 +6,10 @@ namespace Luau { +struct TypeArena; struct Scope2; void quantify(TypeId ty, TypeLevel level); -void quantify(TypeId ty, Scope2* scope); +TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope); } // namespace Luau diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index c25545f..f69d133 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -19,7 +19,7 @@ struct RequireTraceResult { DenseHashMap exprs{nullptr}; - std::vector> requires; + std::vector> requireList; }; RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h new file mode 100644 index 0000000..a6c7a3e --- /dev/null +++ b/Analysis/include/Luau/TypeChecker2.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Module.h" + +namespace Luau +{ + +void check(const SourceModule& sourceModule, Module* module); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 183cc05..28adc9d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -138,25 +138,25 @@ struct TypeChecker void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); - ExprResult checkExpr( + WithPredicate checkExpr( const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); - ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprCall& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); - ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprLocal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprUnary& expr); TypeId checkRelationalOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); TypeId checkBinaryOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); - ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -179,11 +179,11 @@ struct TypeChecker void checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); - ExprResult checkExprPack(const ScopePtr& scope, const AstExpr& expr); - ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); + WithPredicate checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); @@ -191,7 +191,7 @@ struct TypeChecker const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); - ExprResult checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, + WithPredicate checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, const std::vector>& expectedTypes = {}); @@ -234,7 +234,7 @@ struct TypeChecker ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); - void unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -412,7 +412,6 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId optionalNumberType; const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index b59e7c6..ff7708d 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -84,6 +84,24 @@ using Tags = std::vector; using ModuleName = std::string; +/** A TypeVar that cannot be computed. + * + * BlockedTypeVars essentially serve as a way to encode partial ordering on the + * constraint graph. Until a BlockedTypeVar is unblocked by its owning + * constraint, nothing at all can be said about it. Constraints that need to + * process a BlockedTypeVar cannot be dispatched. + * + * Whenever a BlockedTypeVar is added to the graph, we also record a constraint + * that will eventually unblock it. + */ +struct BlockedTypeVar +{ + BlockedTypeVar(); + int index; + + static int nextIndex; +}; + struct PrimitiveTypeVar { enum Type @@ -231,29 +249,29 @@ struct FunctionDefinition // TODO: Do we actually need this? We'll find out later if we can delete this. // Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler. template -struct ExprResult +struct WithPredicate { T type; PredicateVec predicates; }; -using MagicFunction = std::function>( - struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, ExprResult)>; +using MagicFunction = std::function>( + struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; struct FunctionTypeVar { // Global monomorphic function - FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Global polymorphic function - FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local monomorphic function - FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); // Local polymorphic function - FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn = {}, bool hasSelf = false); TypeLevel level; @@ -263,7 +281,7 @@ struct FunctionTypeVar std::vector genericPacks; TypePackId argTypes; std::vector> argNames; - TypePackId retType; + TypePackId retTypes; std::optional definition; MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. bool hasSelf; @@ -442,7 +460,7 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; struct TypeVar final @@ -555,7 +573,6 @@ struct SingletonTypes const TypeId trueType; const TypeId falseType; const TypeId anyType; - const TypeId optionalNumberType; const TypePackId anyTypePack; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 627b52c..b51a485 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -110,7 +110,7 @@ private: void tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy); public: - void unifyLowerBound(TypePackId subTy, TypePackId superTy); + void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel); // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index f383991..642522c 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -209,7 +209,7 @@ struct GenericTypeVarVisitor if (visit(ty, *ftv)) { traverse(ftv->argTypes); - traverse(ftv->retType); + traverse(ftv->retTypes); } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index a8319c5..8a63901 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,7 +13,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); LUAU_FASTFLAG(LuauSelfCallAutocompleteFix2) static const std::unordered_set kStatementStartingKeywords = { @@ -268,14 +267,14 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { if (FFlag::LuauSelfCallAutocompleteFix2) { - if (std::optional firstRetTy = first(ftv->retType)) + if (std::optional firstRetTy = first(ftv->retTypes)) return checkTypeMatch(typeArena, *firstRetTy, expectedType); return false; } else { - auto [retHead, retTail] = flatten(ftv->retType); + auto [retHead, retTail] = flatten(ftv->retTypes); if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return true; @@ -454,7 +453,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } else if (auto indexFunction = get(followed)) { - std::optional indexFunctionResult = first(indexFunction->retType); + std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } @@ -493,7 +492,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); else if (auto indexFunction = get(followed)) { - std::optional indexFunctionResult = first(indexFunction->retType); + std::optional indexFunctionResult = first(indexFunction->retTypes); if (indexFunctionResult) autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); } @@ -742,7 +741,7 @@ static std::optional findTypeElementAt(AstType* astType, TypeId ty, Posi if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) return element; - if (auto element = findTypeElementAt(type->returnTypes, ftv->retType, position)) + if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) return element; } @@ -958,7 +957,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = get(follow(*it))) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) inferredType = *ty; } } @@ -1050,7 +1049,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, i)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } @@ -1067,7 +1066,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi { if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) { - if (auto ty = tryGetTypePackTypeAt(ftv->retType, ~0u)) + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); } } @@ -1266,7 +1265,7 @@ static bool autocompleteIfElseExpression( if (!parent) return false; - if (FFlag::LuauIfElseExprFixCompletionIssue && node->is()) + if (node->is()) { // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else // expression. diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 98737b4..2f57e23 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -19,16 +19,16 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) namespace Luau { -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -263,10 +263,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; (void)scope; @@ -287,10 +287,10 @@ static std::optional> magicFunctionSelect( if (size_t(offset) < v.size()) { std::vector result(v.begin() + offset, v.end()); - return ExprResult{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; } else if (tail) - return ExprResult{*tail}; + return WithPredicate{*tail}; } typechecker.reportError(TypeError{arg1->location, GenericError{"bad argument #1 to select (index out of range)"}}); @@ -298,16 +298,16 @@ static std::optional> magicFunctionSelect( else if (AstExprConstantString* str = arg1->as()) { if (str->value.size == 1 && str->value.data[0] == '#') - return ExprResult{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; + return WithPredicate{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; } return std::nullopt; } -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -343,7 +343,7 @@ static std::optional> magicFunctionSetMetaTable( if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - return ExprResult{}; + return WithPredicate{}; } if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) @@ -356,7 +356,7 @@ static std::optional> magicFunctionSetMetaTable( } } - return ExprResult{arena.addTypePack({mtTy})}; + return WithPredicate{arena.addTypePack({mtTy})}; } } else if (get(target) || get(target) || isTableIntersection(target)) @@ -367,13 +367,13 @@ static std::optional> magicFunctionSetMetaTable( typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}}); } - return ExprResult{arena.addTypePack({target})}; + return WithPredicate{arena.addTypePack({target})}; } -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, predicates] = exprResult; + auto [paramPack, predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -382,7 +382,7 @@ static std::optional> magicFunctionAssert( { std::optional fst = first(*tail); if (!fst) - return ExprResult{paramPack}; + return WithPredicate{paramPack}; head.push_back(*fst); } @@ -397,13 +397,13 @@ static std::optional> magicFunctionAssert( head[0] = *newhead; } - return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; + return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::optional> magicFunctionPack( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -436,7 +436,7 @@ static std::optional> magicFunctionPack( TypeId packedTable = arena.addType( TableTypeVar{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed}); - return ExprResult{arena.addTypePack({packedTable})}; + return WithPredicate{arena.addTypePack({packedTable})}; } static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) @@ -461,8 +461,8 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) return good; } -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { TypeArena& arena = typechecker.currentModule->internalTypes; @@ -476,7 +476,7 @@ static std::optional> magicFunctionRequire( return std::nullopt; if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) - return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; + return WithPredicate{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; } diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 9180f30..248262c 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -47,6 +47,7 @@ struct TypeCloner void operator()(const Unifiable::Generic& t); void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); + void operator()(const BlockedTypeVar& t); void operator()(const PrimitiveTypeVar& t); void operator()(const ConstrainedTypeVar& t); void operator()(const SingletonTypeVar& t); @@ -158,6 +159,11 @@ void TypeCloner::operator()(const Unifiable::Error& t) defaultClone(t); } +void TypeCloner::operator()(const BlockedTypeVar& t) +{ + defaultClone(t); +} + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); @@ -200,7 +206,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) ftv->tags = t.tags; ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, cloneState); + ftv->retTypes = clone(t.retTypes, dest, cloneState); ftv->hasNoGenerics = t.hasNoGenerics; } @@ -391,7 +397,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log) if (const FunctionTypeVar* ftv = get(ty)) { - FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; clone.genericPacks = ftv->genericPacks; clone.magicFunction = ftv->magicFunction; diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp new file mode 100644 index 0000000..6cb0e4e --- /dev/null +++ b/Analysis/src/Constraint.cpp @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Constraint.h" + +namespace Luau +{ + +Constraint::Constraint(ConstraintV&& c, Location location) + : c(std::move(c)) + , location(location) +{ +} + +} // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index c8f77dd..fa627e7 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -5,16 +5,7 @@ namespace Luau { -Constraint::Constraint(ConstraintV&& c) - : c(std::move(c)) -{ -} - -Constraint::Constraint(ConstraintV&& c, std::vector dependencies) - : c(std::move(c)) - , dependencies(dependencies) -{ -} +const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp std::optional Scope2::lookup(Symbol sym) { @@ -68,10 +59,10 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) return borrow; } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) +void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv, Location location) { LUAU_ASSERT(scope); - scope->constraints.emplace_back(new Constraint{std::move(cv)}); + scope->constraints.emplace_back(new Constraint{std::move(cv), location}); } void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) @@ -99,10 +90,18 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); + else if (auto f = stat->as()) + visit(scope, f); else if (auto f = stat->as()) visit(scope, f); else if (auto r = stat->as()) visit(scope, r); + else if (auto a = stat->as()) + visit(scope, a); + else if (auto e = stat->as()) + checkPack(scope, e->expr); + else if (auto i = stat->as()) + visit(scope, i); else LUAU_ASSERT(0); } @@ -121,12 +120,30 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) scope->bindings[local] = ty; } - for (size_t i = 0; i < local->vars.size; ++i) + for (size_t i = 0; i < local->values.size; ++i) { - if (i < local->values.size) + if (local->values.data[i]->is()) + { + // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. + // See the test TypeInfer/infer_locals_with_nil_value. + // Better flow awareness should make this obsolete. + } + else if (i == local->values.size - 1) + { + TypePackId exprPack = checkPack(scope, local->values.data[i]); + + if (i < local->vars.size) + { + std::vector tailValues{varTypes.begin() + i, varTypes.end()}; + TypePackId tailPack = arena->addTypePack(std::move(tailValues)); + addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}, local->location); + } + } + else { TypeId exprType = check(scope, local->values.data[i]); - addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); + if (i < varTypes.size()) + addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}, local->vars.data[i]->location); } } } @@ -138,7 +155,7 @@ void addConstraints(Constraint* constraint, Scope2* scope) scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); for (const auto& c : scope->constraints) - constraint->dependencies.push_back(c.get()); + constraint->dependencies.push_back(NotNull{c.get()}); for (Scope2* childScope : scope->children) addConstraints(constraint, childScope); @@ -155,31 +172,75 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function TypeId functionType = nullptr; auto ty = scope->lookup(function->name); - LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. - - functionType = freshType(scope); - scope->bindings[function->name] = functionType; - - Scope2* innerScope = childScope(function->func->body->location, scope); - TypePackId returnType = freshTypePack(scope); - innerScope->returnType = returnType; - - std::vector argTypes; - - for (AstLocal* local : function->func->args) + if (ty.has_value()) { - TypeId t = freshType(innerScope); - argTypes.push_back(t); - innerScope->bindings[local] = t; // TODO annotations + // TODO: This is duplicate definition of a local function. Is this allowed? + functionType = *ty; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[function->name] = functionType; } - for (AstStat* stat : function->func->body->body) - visit(innerScope, stat); + auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + innerScope->bindings[function->name] = actualFunctionType; - FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; - TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + checkFunctionBody(innerScope, function->func); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; + addConstraints(c.get(), innerScope); + + addConstraint(scope, std::move(c)); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) +{ + // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. + // With or without self + + TypeId functionType = nullptr; + + auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + + if (AstExprLocal* localName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(localName->local); + if (existingFunctionTy) + { + // Duplicate definition + functionType = *existingFunctionTy; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[localName->local] = functionType; + } + innerScope->bindings[localName->local] = actualFunctionType; + } + else if (AstExprGlobal* globalName = function->name->as()) + { + std::optional existingFunctionTy = scope->lookup(globalName->name); + if (existingFunctionTy) + { + // Duplicate definition + functionType = *existingFunctionTy; + } + else + { + functionType = arena->addType(BlockedTypeVar{}); + rootScope->bindings[globalName->name] = functionType; + } + innerScope->bindings[globalName->name] = actualFunctionType; + } + else if (AstExprIndexName* indexName = function->name->as()) + { + LUAU_ASSERT(0); // not yet implemented + } + + checkFunctionBody(innerScope, function->func); + + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; addConstraints(c.get(), innerScope); addConstraint(scope, std::move(c)); @@ -190,7 +251,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) LUAU_ASSERT(scope); TypePackId exprTypes = checkPack(scope, ret->list); - addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); + addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}, ret->location); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) @@ -201,6 +262,28 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) visit(scope, stat); } +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) +{ + TypePackId varPackId = checkExprList(scope, assign->vars); + TypePackId valuePack = checkPack(scope, assign->values); + + addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}, assign->location); +} + +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) +{ + check(scope, ifStatement->condition); + + Scope2* thenScope = childScope(ifStatement->thenbody->location, scope); + visit(thenScope, ifStatement->thenbody); + + if (ifStatement->elsebody) + { + Scope2* elseScope = childScope(ifStatement->elsebody->location, scope); + visit(elseScope, ifStatement->elsebody); + } +} + TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) { LUAU_ASSERT(scope); @@ -224,75 +307,256 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray e return arena->addTypePack(TypePack{std::move(types), last}); } +TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray& exprs) +{ + TypePackId result = arena->addTypePack({}); + TypePack* resultPack = getMutable(result); + LUAU_ASSERT(resultPack); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* expr = exprs.data[i]; + if (i < exprs.size - 1) + resultPack->head.push_back(check(scope, expr)); + else + resultPack->tail = checkPack(scope, expr); + } + + if (resultPack->head.empty() && resultPack->tail) + return *resultPack->tail; + else + return result; +} + TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) { LUAU_ASSERT(scope); - // TEMP TEMP TEMP HACK HACK HACK FIXME FIXME - TypeId t = check(scope, expr); - return arena->addTypePack({t}); + TypePackId result = nullptr; + + if (AstExprCall* call = expr->as()) + { + std::vector args; + + for (AstExpr* arg : call->args) + { + args.push_back(check(scope, arg)); + } + + // TODO self + + TypeId fnType = check(scope, call->func); + + astOriginalCallTypes[call->func] = fnType; + + TypeId instantiatedType = freshType(scope); + addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}, expr->location); + + TypePackId rets = freshTypePack(scope); + FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); + TypeId inferredFnType = arena->addType(ftv); + + addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}, expr->location); + result = rets; + } + else + { + TypeId t = check(scope, expr); + result = arena->addTypePack({t}); + } + + LUAU_ASSERT(result); + astTypePacks[expr] = result; + return result; } TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) { LUAU_ASSERT(scope); - if (auto a = expr->as()) - return singletonTypes.stringType; - else if (auto a = expr->as()) - return singletonTypes.numberType; - else if (auto a = expr->as()) - return singletonTypes.booleanType; - else if (auto a = expr->as()) - return singletonTypes.nilType; + TypeId result = nullptr; + + if (auto group = expr->as()) + result = check(scope, group->expr); + else if (expr->is()) + result = singletonTypes.stringType; + else if (expr->is()) + result = singletonTypes.numberType; + else if (expr->is()) + result = singletonTypes.booleanType; + else if (expr->is()) + result = singletonTypes.nilType; else if (auto a = expr->as()) { std::optional ty = scope->lookup(a->local); if (ty) - return *ty; + result = *ty; else - return singletonTypes.errorRecoveryType(singletonTypes.anyType); // FIXME? Record an error at this point? + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? + } + else if (auto g = expr->as()) + { + std::optional ty = scope->lookup(g->name); + if (ty) + result = *ty; + else + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? } else if (auto a = expr->as()) { - std::vector args; - - for (AstExpr* arg : a->args) + TypePackId packResult = checkPack(scope, expr); + if (auto f = first(packResult)) + return *f; + else if (get(packResult)) { - args.push_back(check(scope, arg)); + TypeId typeResult = freshType(scope); + TypePack onePack{{typeResult}, freshTypePack(scope)}; + TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); + + addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}, expr->location); + + return typeResult; } - - TypeId fnType = check(scope, a->func); - TypeId instantiatedType = freshType(scope); - addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); - - TypeId firstRet = freshType(scope); - TypePackId rets = arena->addTypePack(TypePack{{firstRet}, arena->addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); - FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); - TypeId inferredFnType = arena->addType(ftv); - - addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); - return firstRet; + } + else if (auto a = expr->as()) + { + auto [fnType, functionScope] = checkFunctionSignature(scope, a); + checkFunctionBody(functionScope, a); + return fnType; + } + else if (auto indexName = expr->as()) + { + result = check(scope, indexName); + } + else if (auto table = expr->as()) + { + result = checkExprTable(scope, table); } else { LUAU_ASSERT(0); - return freshType(scope); + result = freshType(scope); + } + + LUAU_ASSERT(result); + astTypes[expr] = result; + return result; +} + +TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) +{ + TypeId obj = check(scope, indexName->expr); + TypeId result = freshType(scope); + + TableTypeVar::Props props{{indexName->index.value, Property{result}}}; + const std::optional indexer; + TableTypeVar ttv{std::move(props), indexer, TypeLevel{}, TableState::Free}; + + TypeId expectedTableType = arena->addType(std::move(ttv)); + + addConstraint(scope, SubtypeConstraint{obj, expectedTableType}, indexName->location); + + return result; +} + +TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) +{ + TypeId ty = arena->addType(TableTypeVar{}); + TableTypeVar* ttv = getMutable(ty); + LUAU_ASSERT(ttv); + + auto createIndexer = [this, scope, ttv]( + TypeId currentIndexType, TypeId currentResultType, Location itemLocation, std::optional keyLocation) { + if (!ttv->indexer) + { + TypeId indexType = this->freshType(scope); + TypeId resultType = this->freshType(scope); + ttv->indexer = TableIndexer{indexType, resultType}; + } + + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}, keyLocation ? *keyLocation : itemLocation); + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}, itemLocation); + }; + + for (const AstExprTable::Item& item : expr->items) + { + TypeId itemTy = check(scope, item.value); + + if (item.key) + { + // Even though we don't need to use the type of the item's key if + // it's a string constant, we still want to check it to populate + // astTypes. + TypeId keyTy = check(scope, item.key); + + if (AstExprConstantString* key = item.key->as()) + { + ttv->props[key->value.begin()] = {itemTy}; + } + else + { + createIndexer(keyTy, itemTy, item.value->location, item.key->location); + } + } + else + { + TypeId numberType = singletonTypes.numberType; + createIndexer(numberType, itemTy, item.value->location, std::nullopt); + } + } + + return ty; +} + +std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn) +{ + Scope2* innerScope = childScope(fn->body->location, parent); + TypePackId returnType = freshTypePack(innerScope); + innerScope->returnType = returnType; + + std::vector argTypes; + + for (AstLocal* local : fn->args) + { + TypeId t = freshType(innerScope); + argTypes.push_back(t); + innerScope->bindings[local] = t; // TODO annotations + } + + FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); + LUAU_ASSERT(actualFunctionType); + astTypes[fn] = actualFunctionType; + + return {actualFunctionType, innerScope}; +} + +void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn) +{ + for (AstStat* stat : fn->body->body) + visit(scope, stat); + + // If it is possible for execution to reach the end of the function, the return type must be compatible with () + + if (nullptr != getFallthrough(fn->body)) + { + TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever + addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}, fn->body->location); } } -static void collectConstraints(std::vector& result, Scope2* scope) +void collectConstraints(std::vector>& result, Scope2* scope) { for (const auto& c : scope->constraints) - result.push_back(c.get()); + result.push_back(NotNull{c.get()}); for (Scope2* child : scope->children) collectConstraints(result, child); } -std::vector collectConstraints(Scope2* rootScope) +std::vector> collectConstraints(Scope2* rootScope) { - std::vector result; + std::vector> result; collectConstraints(result, rootScope); return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f40cd4b..41dfd89 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -7,6 +7,7 @@ #include "Luau/Unifier.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { @@ -58,11 +59,11 @@ ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) { - for (const Constraint* c : constraints) + for (NotNull c : constraints) { - unsolvedConstraints.insert(c); + unsolvedConstraints.push_back(c); - for (const Constraint* dep : c->dependencies) + for (NotNull dep : c->dependencies) { block(dep, c); } @@ -74,8 +75,6 @@ void ConstraintSolver::run() if (done()) return; - bool progress = false; - ToStringOptions opts; if (FFlag::DebugLuauLogSolver) @@ -84,44 +83,80 @@ void ConstraintSolver::run() dump(this, opts); } - do + if (FFlag::DebugLuauLogSolverToJson) { - progress = false; + logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + } - auto it = begin(unsolvedConstraints); - auto endIt = end(unsolvedConstraints); + auto runSolverPass = [&](bool force) { + bool progress = false; - while (it != endIt) + size_t i = 0; + while (i < unsolvedConstraints.size()) { - if (isBlocked(*it)) + NotNull c = unsolvedConstraints[i]; + if (!force && isBlocked(c)) { - ++it; + ++i; continue; } - std::string saveMe = FFlag::DebugLuauLogSolver ? toString(**it, opts) : std::string{}; + std::string saveMe = FFlag::DebugLuauLogSolver ? toString(*c, opts) : std::string{}; - bool success = tryDispatch(*it); - progress = progress || success; + if (FFlag::DebugLuauLogSolverToJson) + { + logger.prepareStepSnapshot(rootScope, c, unsolvedConstraints); + } + + bool success = tryDispatch(c, force); + + progress |= success; - auto saveIt = it; - ++it; if (success) { - unsolvedConstraints.erase(saveIt); + unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + + if (FFlag::DebugLuauLogSolverToJson) + { + logger.commitPreparedStepSnapshot(); + } + if (FFlag::DebugLuauLogSolver) { + if (force) + printf("Force "); printf("Dispatched\n\t%s\n", saveMe.c_str()); dump(this, opts); } } + else + ++i; + + if (force && success) + return true; } + + return progress; + }; + + bool progress = false; + do + { + progress = runSolverPass(false); + if (!progress) + progress |= runSolverPass(true); } while (progress); if (FFlag::DebugLuauLogSolver) + { dumpBindings(rootScope, opts); + } - LUAU_ASSERT(done()); + if (FFlag::DebugLuauLogSolverToJson) + { + logger.captureBoundarySnapshot(rootScope, unsolvedConstraints); + printf("Logger output:\n%s\n", logger.compileOutput().c_str()); + } } bool ConstraintSolver::done() @@ -129,21 +164,21 @@ bool ConstraintSolver::done() return unsolvedConstraints.empty(); } -bool ConstraintSolver::tryDispatch(const Constraint* constraint) +bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) { - if (isBlocked(constraint)) + if (!force && isBlocked(constraint)) return false; bool success = false; if (auto sc = get(*constraint)) - success = tryDispatch(*sc); + success = tryDispatch(*sc, constraint, force); else if (auto psc = get(*constraint)) - success = tryDispatch(*psc); + success = tryDispatch(*psc, constraint, force); else if (auto gc = get(*constraint)) - success = tryDispatch(*gc); + success = tryDispatch(*gc, constraint, force); else if (auto ic = get(*constraint)) - success = tryDispatch(*ic, constraint); + success = tryDispatch(*ic, constraint, force); else LUAU_ASSERT(0); @@ -155,65 +190,66 @@ bool ConstraintSolver::tryDispatch(const Constraint* constraint) return success; } -bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c) +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subType, c.superType); + if (isBlocked(c.subType)) + return block(c.subType, constraint); + else if (isBlocked(c.superType)) + return block(c.superType, constraint); + + unify(c.subType, c.superType, constraint->location); + unblock(c.subType); unblock(c.superType); return true; } -bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c) +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subPack, c.superPack); + unify(c.subPack, c.superPack, constraint->location); unblock(c.subPack); unblock(c.superPack); return true; } -bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& constraint) +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force) { - unify(constraint.subType, constraint.superType); + if (isBlocked(c.sourceType)) + return block(c.sourceType, constraint); - quantify(constraint.superType, constraint.scope); - unblock(constraint.subType); - unblock(constraint.superType); + if (isBlocked(c.generalizedType)) + asMutable(c.generalizedType)->ty.emplace(c.sourceType); + else + unify(c.generalizedType, c.sourceType, constraint->location); + + TypeId generalized = quantify(arena, c.sourceType, c.scope); + *asMutable(c.sourceType) = *generalized; + + unblock(c.generalizedType); + unblock(c.sourceType); return true; } -bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, const Constraint* constraint) +bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force) { - TypeId superType = follow(c.superType); - if (const FunctionTypeVar* ftv = get(superType)) - { - if (!ftv->generalized) - { - block(superType, constraint); - return false; - } - } - else if (get(superType)) - { - block(superType, constraint); - return false; - } - // TODO: Error if it's a primitive or something + if (isBlocked(c.superType)) + return block(c.superType, constraint); Instantiation inst(TxnLog::empty(), arena, TypeLevel{}); std::optional instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - unify(c.subType, *instantiated); + unify(c.subType, *instantiated, constraint->location); unblock(c.subType); return true; } -void ConstraintSolver::block_(BlockedConstraintId target, const Constraint* constraint) +void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); @@ -221,19 +257,21 @@ void ConstraintSolver::block_(BlockedConstraintId target, const Constraint* cons count += 1; } -void ConstraintSolver::block(const Constraint* target, const Constraint* constraint) +void ConstraintSolver::block(NotNull target, NotNull constraint) { block_(target, constraint); } -void ConstraintSolver::block(TypeId target, const Constraint* constraint) +bool ConstraintSolver::block(TypeId target, NotNull constraint) { block_(target, constraint); + return false; } -void ConstraintSolver::block(TypePackId target, const Constraint* constraint) +bool ConstraintSolver::block(TypePackId target, NotNull constraint) { block_(target, constraint); + return false; } void ConstraintSolver::unblock_(BlockedConstraintId progressed) @@ -243,7 +281,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) return; // unblocked should contain a value always, because of the above check - for (const Constraint* unblockedConstraint : it->second) + for (NotNull unblockedConstraint : it->second) { auto& count = blockedConstraints[unblockedConstraint]; // This assertion being hit indicates that `blocked` and @@ -257,7 +295,7 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) blocked.erase(it); } -void ConstraintSolver::unblock(const Constraint* progressed) +void ConstraintSolver::unblock(NotNull progressed) { return unblock_(progressed); } @@ -272,35 +310,33 @@ void ConstraintSolver::unblock(TypePackId progressed) return unblock_(progressed); } -bool ConstraintSolver::isBlocked(const Constraint* constraint) +bool ConstraintSolver::isBlocked(TypeId ty) +{ + return nullptr != get(follow(ty)); +} + +bool ConstraintSolver::isBlocked(NotNull constraint) { auto blockedIt = blockedConstraints.find(constraint); return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } -void ConstraintSolver::reportErrors(const std::vector& errors) -{ - this->errors.insert(end(this->errors), begin(errors), end(errors)); -} - -void ConstraintSolver::unify(TypeId subType, TypeId superType) +void ConstraintSolver::unify(TypeId subType, TypeId superType, Location location) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; u.tryUnify(subType, superType); u.log.commit(); - reportErrors(u.errors); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, Location location) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; u.tryUnify(subPack, superPack); u.log.commit(); - reportErrors(u.errors); } } // namespace Luau diff --git a/Analysis/src/ConstraintSolverLogger.cpp b/Analysis/src/ConstraintSolverLogger.cpp new file mode 100644 index 0000000..2f93c28 --- /dev/null +++ b/Analysis/src/ConstraintSolverLogger.cpp @@ -0,0 +1,139 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ConstraintSolverLogger.h" + +namespace Luau +{ + +static std::string dumpScopeAndChildren(const Scope2* scope, ToStringOptions& opts) +{ + std::string output = "{\"bindings\":{"; + + bool comma = false; + for (const auto& [name, type] : scope->bindings) + { + if (comma) + output += ","; + + output += "\""; + output += name.c_str(); + output += "\": \""; + + ToStringResult result = toStringDetailed(type, opts); + opts.nameMap = std::move(result.nameMap); + output += result.name; + output += "\""; + + comma = true; + } + + output += "},\"children\":["; + comma = false; + + for (const Scope2* child : scope->children) + { + if (comma) + output += ","; + + output += dumpScopeAndChildren(child, opts); + comma = true; + } + + output += "]}"; + return output; +} + +static std::string dumpConstraintsToDot(std::vector>& constraints, ToStringOptions& opts) +{ + std::string result = "digraph Constraints {\\n"; + + std::unordered_set> contained; + for (NotNull c : constraints) + { + contained.insert(c); + } + + for (NotNull c : constraints) + { + std::string id = std::to_string(reinterpret_cast(c.get())); + result += id; + result += " [label=\\\""; + result += toString(*c, opts).c_str(); + result += "\\\"];\\n"; + + for (NotNull dep : c->dependencies) + { + if (contained.count(dep) == 0) + continue; + + result += std::to_string(reinterpret_cast(dep.get())); + result += " -> "; + result += id; + result += ";\\n"; + } + } + + result += "}"; + + return result; +} + +std::string ConstraintSolverLogger::compileOutput() +{ + std::string output = "["; + bool comma = false; + + for (const std::string& snapshot : snapshots) + { + if (comma) + output += ","; + output += snapshot; + + comma = true; + } + + output += "]"; + return output; +} + +void ConstraintSolverLogger::captureBoundarySnapshot(const Scope2* rootScope, std::vector>& unsolvedConstraints) +{ + std::string snapshot = "{\"type\":\"boundary\",\"rootScope\":"; + + snapshot += dumpScopeAndChildren(rootScope, opts); + snapshot += ",\"constraintGraph\":\""; + snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); + snapshot += "\"}"; + + snapshots.push_back(std::move(snapshot)); +} + +void ConstraintSolverLogger::prepareStepSnapshot( + const Scope2* rootScope, NotNull current, std::vector>& unsolvedConstraints) +{ + // LUAU_ASSERT(!preparedSnapshot); + + std::string snapshot = "{\"type\":\"step\",\"rootScope\":"; + + snapshot += dumpScopeAndChildren(rootScope, opts); + snapshot += ",\"constraintGraph\":\""; + snapshot += dumpConstraintsToDot(unsolvedConstraints, opts); + snapshot += "\",\"currentId\":\""; + snapshot += std::to_string(reinterpret_cast(current.get())); + snapshot += "\",\"current\":\""; + snapshot += toString(*current, opts); + snapshot += "\"}"; + + preparedSnapshot = std::move(snapshot); +} + +void ConstraintSolverLogger::commitPreparedStepSnapshot() +{ + if (preparedSnapshot) + { + snapshots.push_back(std::move(*preparedSnapshot)); + preparedSnapshot = std::nullopt; + } +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 741a35c..9e02506 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,16 +1,17 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" -#include "Luau/Common.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/Config.h" -#include "Luau/FileResolver.h" #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "Luau/FileResolver.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" @@ -216,7 +217,7 @@ ErrorVec accumulateErrors( continue; const SourceNode& sourceNode = it->second; - queue.insert(queue.end(), sourceNode.requires.begin(), sourceNode.requires.end()); + queue.insert(queue.end(), sourceNode.requireSet.begin(), sourceNode.requireSet.end()); // FIXME: If a module has a syntax error, we won't be able to re-report it here. // The solution is probably to move errors from Module to SourceNode @@ -586,7 +587,7 @@ bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& chec path.push_back(top); // push children - for (const ModuleName& dep : top->requires) + for (const ModuleName& dep : top->requireSet) { auto it = sourceNodes.find(dep); if (it != sourceNodes.end()) @@ -738,7 +739,7 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) { - for (const auto& dep : module.second.requires) + for (const auto& dep : module.second.requireSet) reverseDeps[dep].push_back(module.first); } @@ -797,9 +798,14 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco cs.run(); result->scope2s = std::move(cgb.scopes); + result->astTypes = std::move(cgb.astTypes); + result->astTypePacks = std::move(cgb.astTypePacks); + result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); result->clonePublicInterface(iceHandler); + Luau::check(sourceModule, result.get()); + return result; } @@ -841,8 +847,8 @@ std::pair Frontend::getSourceNode(CheckResult& check SourceModule result = parse(name, source->source, opts); result.type = source->type; - RequireTraceResult& requireTrace = requires[name]; - requireTrace = traceRequires(fileResolver, result.root, name); + RequireTraceResult& require = requireTrace[name]; + require = traceRequires(fileResolver, result.root, name); SourceNode& sourceNode = sourceNodes[name]; SourceModule& sourceModule = sourceModules[name]; @@ -851,7 +857,7 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceModule.environmentName = environmentName; sourceNode.name = name; - sourceNode.requires.clear(); + sourceNode.requireSet.clear(); sourceNode.requireLocations.clear(); sourceNode.dirtySourceModule = false; @@ -861,10 +867,10 @@ std::pair Frontend::getSourceNode(CheckResult& check sourceNode.dirtyModuleForAutocomplete = true; } - for (const auto& [moduleName, location] : requireTrace.requires) - sourceNode.requires.insert(moduleName); + for (const auto& [moduleName, location] : require.requireList) + sourceNode.requireSet.insert(moduleName); - sourceNode.requireLocations = requireTrace.requires; + sourceNode.requireLocations = require.requireList; return {&sourceNode, &sourceModule}; } @@ -925,8 +931,8 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const std::optional FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) { // FIXME I think this can be pushed into the FileResolver. - auto it = frontend->requires.find(currentModuleName); - if (it == frontend->requires.end()) + auto it = frontend->requireTrace.find(currentModuleName); + if (it == frontend->requireTrace.end()) { // CLI-43699 // If we can't find the current module name, that's because we bypassed the frontend's initializer @@ -1025,7 +1031,7 @@ void Frontend::clear() sourceModules.clear(); moduleResolver.modules.clear(); moduleResolverForAutocomplete.modules.clear(); - requires.clear(); + requireTrace.clear(); } } // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index f145a51..77c6242 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -40,7 +40,7 @@ TypeId Instantiation::clean(TypeId ty) const FunctionTypeVar* ftv = log->getMutable(ty); LUAU_ASSERT(ftv); - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.magicFunction = ftv->magicFunction; clone.tags = ftv->tags; clone.argNames = ftv->argNames; diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 200b7d1..50868e5 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -2282,7 +2282,7 @@ private: size_t getReturnCount(TypeId ty) { if (auto ftv = get(ty)) - return size(ftv->retType); + return size(ftv->retTypes); if (auto itv = get(ty)) { @@ -2291,7 +2291,7 @@ private: for (TypeId part : itv->parts) if (auto ftv = get(follow(part))) - result = std::max(result, size(ftv->retType)); + result = std::max(result, size(ftv->retTypes)); return result; } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 11403be..d36665e 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -17,6 +17,7 @@ LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false); +LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau { @@ -273,6 +274,18 @@ bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) return ok; } +bool isSubtype(TypePackId subPack, TypePackId superPack, InternalErrorReporter& ice) +{ + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(subPack, superPack); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; +} + template static bool areNormal_(const T& t, const std::unordered_set& seen, InternalErrorReporter& ice) { @@ -390,6 +403,7 @@ struct Normalize final : TypeVarVisitor bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override { CHECK_ITERATION_LIMIT(false); + LUAU_ASSERT(!ty->normal); ConstrainedTypeVar* ctv = const_cast(&ctvRef); @@ -401,14 +415,21 @@ struct Normalize final : TypeVarVisitor std::vector newParts = normalizeUnion(parts); - const bool normal = areNormal(newParts, seen, ice); - - if (newParts.size() == 1) - *asMutable(ty) = BoundTypeVar{newParts[0]}; + if (FFlag::LuauQuantifyConstrained) + { + ctv->parts = std::move(newParts); + } else - *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + { + const bool normal = areNormal(newParts, seen, ice); - asMutable(ty)->normal = normal; + if (newParts.size() == 1) + *asMutable(ty) = BoundTypeVar{newParts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(newParts)}; + + asMutable(ty)->normal = normal; + } return false; } @@ -421,9 +442,9 @@ struct Normalize final : TypeVarVisitor return false; traverse(ftv.argTypes); - traverse(ftv.retType); + traverse(ftv.retTypes); - asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retType, seen, ice); + asMutable(ty)->normal = areNormal(ftv.argTypes, seen, ice) && areNormal(ftv.retTypes, seen, ice); return false; } @@ -465,7 +486,14 @@ struct Normalize final : TypeVarVisitor checkNormal(ttv.indexer->indexResultType); } - asMutable(ty)->normal = normal; + // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. + if (FFlag::LuauQuantifyConstrained) + { + if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal)) + asMutable(ty)->normal = normal; + } + else + asMutable(ty)->normal = normal; return false; } diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 2177537..2004d15 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -2,15 +2,32 @@ #include "Luau/Quantify.h" +#include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header +#include "Luau/TxnLog.h" +#include "Luau/Substitution.h" #include "Luau/VisitTypeVar.h" #include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header LUAU_FASTFLAG(LuauAlwaysQuantify); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) namespace Luau { +/// @return true if outer encloses inner +static bool subsumes(Scope2* outer, Scope2* inner) +{ + while (inner) + { + if (inner == outer) + return true; + inner = inner->parent; + } + + return false; +} + struct Quantifier final : TypeVarOnceVisitor { TypeLevel level; @@ -62,6 +79,34 @@ struct Quantifier final : TypeVarOnceVisitor return false; } + bool visit(TypeId ty, const ConstrainedTypeVar&) override + { + if (FFlag::LuauQuantifyConstrained) + { + ConstrainedTypeVar* ctv = getMutable(ty); + + seenMutableType = true; + + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level)) + return false; + + std::vector opts = std::move(ctv->parts); + + // We might transmute, so it's not safe to rely on the builtin traversal logic + for (TypeId opt : opts) + traverse(opt); + + if (opts.size() == 1) + *asMutable(ty) = BoundTypeVar{opts[0]}; + else + *asMutable(ty) = UnionTypeVar{std::move(opts)}; + + return false; + } + else + return true; + } + bool visit(TypeId ty, const TableTypeVar&) override { LUAU_ASSERT(getMutable(ty)); @@ -73,8 +118,12 @@ struct Quantifier final : TypeVarOnceVisitor if (ttv.state == TableState::Free) seenMutableType = true; - if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) - return false; + if (!FFlag::LuauQuantifyConstrained) + { + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) + return false; + } + if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level)) { if (ttv.state == TableState::Unsealed) @@ -156,4 +205,104 @@ void quantify(TypeId ty, Scope2* scope) ftv->generalized = true; } +struct PureQuantifier : Substitution +{ + Scope2* scope; + std::vector insertedGenerics; + std::vector insertedGenericPacks; + + PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope) + : Substitution(log, arena) + , scope(scope) + { + } + + bool isDirty(TypeId ty) override + { + LUAU_ASSERT(ty == follow(ty)); + + if (auto ftv = get(ty)) + { + return subsumes(scope, ftv->scope); + } + else if (auto ttv = get(ty)) + { + return ttv->state == TableState::Free && subsumes(scope, ttv->scope); + } + + return false; + } + + bool isDirty(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + return subsumes(scope, ftp->scope); + } + + return false; + } + + TypeId clean(TypeId ty) override + { + if (auto ftv = get(ty)) + { + TypeId result = arena->addType(GenericTypeVar{}); + insertedGenerics.push_back(result); + return result; + } + else if (auto ttv = get(ty)) + { + TypeId result = arena->addType(TableTypeVar{}); + TableTypeVar* resultTable = getMutable(result); + LUAU_ASSERT(resultTable); + + *resultTable = *ttv; + resultTable->scope = nullptr; + resultTable->state = TableState::Generic; + + return result; + } + + return ty; + } + + TypePackId clean(TypePackId tp) override + { + if (auto ftp = get(tp)) + { + TypePackId result = arena->addTypePack(TypePackVar{GenericTypePack{}}); + insertedGenericPacks.push_back(result); + return result; + } + + return tp; + } + + bool ignoreChildren(TypeId ty) override + { + return ty->persistent; + } + bool ignoreChildren(TypePackId ty) override + { + return ty->persistent; + } +}; + +TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) +{ + PureQuantifier quantifier{TxnLog::empty(), arena, scope}; + std::optional result = quantifier.substitute(ty); + LUAU_ASSERT(result); + + FunctionTypeVar* ftv = getMutable(*result); + LUAU_ASSERT(ftv); + ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); + + // TODO: Set hasNoGenerics. + + return *result; +} + } // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 8ed245f..c036a7a 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -28,7 +28,7 @@ struct RequireTracer : AstVisitor AstExprGlobal* global = expr->func->as(); if (global && global->name == "require" && expr->args.size >= 1) - requires.push_back(expr); + requireCalls.push_back(expr); return true; } @@ -84,9 +84,9 @@ struct RequireTracer : AstVisitor ModuleInfo moduleContext{currentModuleName}; // seed worklist with require arguments - work.reserve(requires.size()); + work.reserve(requireCalls.size()); - for (AstExprCall* require : requires) + for (AstExprCall* require : requireCalls) work.push_back(require->args.data[0]); // push all dependent expressions to the work stack; note that the vector is modified during traversal @@ -125,15 +125,15 @@ struct RequireTracer : AstVisitor } // resolve all requires according to their argument - result.requires.reserve(requires.size()); + result.requireList.reserve(requireCalls.size()); - for (AstExprCall* require : requires) + for (AstExprCall* require : requireCalls) { AstExpr* arg = require->args.data[0]; if (const ModuleInfo* info = result.exprs.find(arg)) { - result.requires.push_back({info->name, require->location}); + result.requireList.push_back({info->name, require->location}); ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! result.exprs[require] = std::move(infoCopy); @@ -151,7 +151,7 @@ struct RequireTracer : AstVisitor DenseHashMap locals; std::vector work; - std::vector requires; + std::vector requireCalls; }; RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 5a22dee..9c4ce82 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -27,7 +27,7 @@ void Tarjan::visitChildren(TypeId ty, int index) if (const FunctionTypeVar* ftv = get(ty)) { visitChild(ftv->argTypes); - visitChild(ftv->retType); + visitChild(ftv->retTypes); } else if (const TableTypeVar* ttv = get(ty)) { @@ -442,7 +442,7 @@ void Substitution::replaceChildren(TypeId ty) if (FunctionTypeVar* ftv = getMutable(ty)) { ftv->argTypes = replace(ftv->argTypes); - ftv->retType = replace(ftv->retType); + ftv->retTypes = replace(ftv->retTypes); } else if (TableTypeVar* ttv = getMutable(ty)) { diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 9b396c8..6b677bb 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -154,7 +154,7 @@ void StateDot::visitChildren(TypeId ty, int index) finishNode(); visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retType, index, "ret"); + visitChild(ftv->retTypes, index, "ret"); } else if (const TableTypeVar* ttv = get(ty)) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 8490350..eee0dee 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -18,6 +18,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) * Fair warning: Setting this will break a lot of Luau unit tests. */ LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) +LUAU_FASTFLAGVARIABLE(LuauToStringTableBracesNewlines, false) namespace Luau { @@ -225,6 +226,11 @@ struct StringifierState result.name += s; } + void emit(int i) + { + emit(std::to_string(i).c_str()); + } + void indent() { indentation += 4; @@ -392,6 +398,13 @@ struct TypeVarStringifier state.emit("]]"); } + void operator()(TypeId, const BlockedTypeVar& btv) + { + state.emit("*blocked-"); + state.emit(btv.index); + state.emit("*"); + } + void operator()(TypeId, const PrimitiveTypeVar& ptv) { switch (ptv.type) @@ -478,8 +491,8 @@ struct TypeVarStringifier if (FFlag::LuauLowerBoundsCalculation) { - auto retBegin = begin(ftv.retType); - auto retEnd = end(ftv.retType); + auto retBegin = begin(ftv.retTypes); + auto retEnd = end(ftv.retTypes); if (retBegin != retEnd) { ++retBegin; @@ -489,7 +502,7 @@ struct TypeVarStringifier } else { - if (auto retPack = get(follow(ftv.retType))) + if (auto retPack = get(follow(ftv.retTypes))) { if (retPack->head.size() == 1 && !retPack->tail) plural = false; @@ -499,7 +512,7 @@ struct TypeVarStringifier if (plural) state.emit("("); - stringify(ftv.retType); + stringify(ftv.retTypes); if (plural) state.emit(")"); @@ -557,22 +570,54 @@ struct TypeVarStringifier { case TableState::Sealed: state.result.invalid = true; - openbrace = "{| "; - closedbrace = " |}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{|"; + closedbrace = "|}"; + } + else + { + openbrace = "{| "; + closedbrace = " |}"; + } break; case TableState::Unsealed: - openbrace = "{ "; - closedbrace = " }"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{"; + closedbrace = "}"; + } + else + { + openbrace = "{ "; + closedbrace = " }"; + } break; case TableState::Free: state.result.invalid = true; - openbrace = "{- "; - closedbrace = " -}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{-"; + closedbrace = "-}"; + } + else + { + openbrace = "{- "; + closedbrace = " -}"; + } break; case TableState::Generic: state.result.invalid = true; - openbrace = "{+ "; - closedbrace = " +}"; + if (FFlag::LuauToStringTableBracesNewlines) + { + openbrace = "{+"; + closedbrace = "+}"; + } + else + { + openbrace = "{+ "; + closedbrace = " +}"; + } break; } @@ -591,6 +636,8 @@ struct TypeVarStringifier bool comma = false; if (ttv.indexer) { + if (FFlag::LuauToStringTableBracesNewlines) + state.newline(); state.emit("["); stringify(ttv.indexer->indexType); state.emit("]: "); @@ -607,6 +654,10 @@ struct TypeVarStringifier state.emit(","); state.newline(); } + else if (FFlag::LuauToStringTableBracesNewlines) + { + state.newline(); + } size_t length = state.result.name.length() - oldLength; @@ -633,6 +684,13 @@ struct TypeVarStringifier } state.dedent(); + if (FFlag::LuauToStringTableBracesNewlines) + { + if (comma) + state.newline(); + else + state.emit(" "); + } state.emit(closedbrace); state.unsee(&ttv); @@ -1247,14 +1305,14 @@ std::string toStringNamedFunction(const std::string& funcName, const FunctionTyp state.emit("): "); - size_t retSize = size(ftv.retType); - bool hasTail = !finite(ftv.retType); - bool wrap = get(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1); + size_t retSize = size(ftv.retTypes); + bool hasTail = !finite(ftv.retTypes); + bool wrap = get(follow(ftv.retTypes)) && (hasTail ? retSize != 0 : retSize != 1); if (wrap) state.emit("("); - tvs.stringify(ftv.retType); + tvs.stringify(ftv.retTypes); if (wrap) state.emit(")"); @@ -1329,9 +1387,9 @@ std::string toString(const Constraint& c, ToStringOptions& opts) } else if (const GeneralizationConstraint* gc = Luau::get_if(&c.c)) { - ToStringResult subStr = toStringDetailed(gc->subType, opts); + ToStringResult subStr = toStringDetailed(gc->generalizedType, opts); opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(gc->superType, opts); + ToStringResult superStr = toStringDetailed(gc->sourceType, opts); opts.nameMap = std::move(superStr.nameMap); return subStr.name + " ~ gen " + superStr.name; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 0f4534b..6cca712 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -94,6 +94,11 @@ public: } } + AstType* operator()(const BlockedTypeVar& btv) + { + return allocator->alloc(Location(), std::nullopt, AstName("*blocked*")); + } + AstType* operator()(const ConstrainedTypeVar& ctv) { AstArray types; @@ -271,7 +276,7 @@ public: } AstArray returnTypes; - const auto& [retVector, retTail] = flatten(ftv.retType); + const auto& [retVector, retTail] = flatten(ftv.retTypes); returnTypes.size = retVector.size(); returnTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * returnTypes.size)); for (size_t i = 0; i < returnTypes.size; ++i) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp new file mode 100644 index 0000000..7f5ba68 --- /dev/null +++ b/Analysis/src/TypeChecker2.cpp @@ -0,0 +1,160 @@ + +#include "Luau/TypeChecker2.h" + +#include + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Clone.h" +#include "Luau/Normalize.h" + +namespace Luau +{ + +struct TypeChecker2 : public AstVisitor +{ + const SourceModule* sourceModule; + Module* module; + InternalErrorReporter ice; // FIXME accept a pointer from Frontend + + TypeChecker2(const SourceModule* sourceModule, Module* module) + : sourceModule(sourceModule) + , module(module) + { + } + + using AstVisitor::visit; + + TypePackId lookupPack(AstExpr* expr) + { + TypePackId* tp = module->astTypePacks.find(expr); + LUAU_ASSERT(tp); + return follow(*tp); + } + + TypeId lookupType(AstExpr* expr) + { + TypeId* ty = module->astTypes.find(expr); + LUAU_ASSERT(ty); + return follow(*ty); + } + + bool visit(AstStatAssign* assign) override + { + size_t count = std::min(assign->vars.size, assign->values.size); + + for (size_t i = 0; i < count; ++i) + { + AstExpr* lhs = assign->vars.data[i]; + TypeId* lhsType = module->astTypes.find(lhs); + LUAU_ASSERT(lhsType); + + AstExpr* rhs = assign->values.data[i]; + TypeId* rhsType = module->astTypes.find(rhs); + LUAU_ASSERT(rhsType); + + if (!isSubtype(*rhsType, *lhsType, ice)) + { + reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location); + } + } + + return true; + } + + bool visit(AstExprCall* call) override + { + TypePackId expectedRetType = lookupPack(call); + TypeId functionType = lookupType(call->func); + + TypeArena arena; + TypePack args; + for (const auto& arg : call->args) + { + TypeId argTy = module->astTypes[arg]; + LUAU_ASSERT(argTy); + args.head.push_back(argTy); + } + + TypePackId argsTp = arena.addTypePack(args); + FunctionTypeVar ftv{argsTp, expectedRetType}; + TypeId expectedType = arena.addType(ftv); + if (!isSubtype(expectedType, functionType, ice)) + { + unfreeze(module->interfaceTypes); + CloneState cloneState; + expectedType = clone(expectedType, module->interfaceTypes, cloneState); + freeze(module->interfaceTypes); + reportError(TypeMismatch{expectedType, functionType}, call->location); + } + + return true; + } + + bool visit(AstExprIndexName* indexName) override + { + TypeId leftType = lookupType(indexName->expr); + TypeId resultType = lookupType(indexName); + + // leftType must have a property called indexName->index + + if (auto ttv = get(leftType)) + { + auto it = ttv->props.find(indexName->index.value); + if (it == ttv->props.end()) + { + reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); + } + else if (!isSubtype(resultType, it->second.type, ice)) + { + reportError(TypeMismatch{resultType, it->second.type}, indexName->location); + } + } + else + { + reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); + } + + return true; + } + + bool visit(AstExprConstantNumber* number) override + { + TypeId actualType = lookupType(number); + TypeId numberType = getSingletonTypes().numberType; + + if (!isSubtype(actualType, numberType, ice)) + { + reportError(TypeMismatch{actualType, numberType}, number->location); + } + + return true; + } + + bool visit(AstExprConstantString* string) override + { + TypeId actualType = lookupType(string); + TypeId stringType = getSingletonTypes().stringType; + + if (!isSubtype(actualType, stringType, ice)) + { + reportError(TypeMismatch{actualType, stringType}, string->location); + } + + return true; + } + + void reportError(TypeErrorData&& data, const Location& location) + { + module->errors.emplace_back(location, sourceModule->name, std::move(data)); + } +}; + +void check(const SourceModule& sourceModule, Module* module) +{ + TypeChecker2 typeChecker{&sourceModule, module}; + + sourceModule.root->visit(&typeChecker); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 447cd02..fd1b3b8 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -18,6 +18,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" #include #include @@ -30,7 +31,6 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) -LUAU_FASTFLAGVARIABLE(LuauExpectedPropTypeFromIndexer, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) @@ -42,9 +42,9 @@ LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); -LUAU_FASTFLAGVARIABLE(LuauSuccessTypingForEqualityOperations, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) @@ -260,7 +260,6 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(getSingletonTypes().booleanType) , threadType(getSingletonTypes().threadType) , anyType(getSingletonTypes().anyType) - , optionalNumberType(getSingletonTypes().optionalNumberType) , anyTypePack(getSingletonTypes().anyTypePack) , duplicateTypeAliases{{false, {}}} { @@ -679,7 +678,7 @@ static std::optional tryGetTypeGuardPredicate(const AstExprBinary& ex void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr ifScope = childScope(scope, statement.thenbody->location); resolve(result.predicates, ifScope, true); @@ -712,7 +711,7 @@ ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Locat void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) { - ExprResult result = checkExpr(scope, *statement.condition); + WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); resolve(result.predicates, whileScope, true); @@ -728,16 +727,64 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) checkExpr(repScope, *statement.condition); } -void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, const Location& location) +void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location) { Unifier state = mkUnifier(location); - state.unifyLowerBound(subTy, superTy); + state.unifyLowerBound(subTy, superTy, demotedLevel); state.log.commit(); reportErrors(state.errors); } +struct Demoter : Substitution +{ + Demoter(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + { + } + + bool isDirty(TypeId ty) override + { + return get(ty); + } + + bool isDirty(TypePackId tp) override + { + return get(tp); + } + + TypeId clean(TypeId ty) override + { + auto ftv = get(ty); + LUAU_ASSERT(ftv); + return addType(FreeTypeVar{demotedLevel(ftv->level)}); + } + + TypePackId clean(TypePackId tp) override + { + auto ftp = get(tp); + LUAU_ASSERT(ftp); + return addTypePack(TypePackVar{FreeTypePack{demotedLevel(ftp->level)}}); + } + + TypeLevel demotedLevel(TypeLevel level) + { + return TypeLevel{level.level + 5000, level.subLevel}; + } + + void demote(std::vector>& expectedTypes) + { + if (!FFlag::LuauQuantifyConstrained) + return; + for (std::optional& ty : expectedTypes) + { + if (ty) + ty = substitute(*ty); + } + } +}; + void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; @@ -760,11 +807,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; if (FFlag::LuauReturnTypeInferenceInNonstrict ? FFlag::LuauLowerBoundsCalculation : useConstrainedIntersections()) { - unifyLowerBound(retPack, scope->returnType, return_.location); + unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), return_.location); return; } @@ -1230,7 +1280,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(retPack, varPack, forin.location); } else - unify(iterFunc->retType, varPack, forin.location); + unify(iterFunc->retTypes, varPack, forin.location); check(loopScope, *forin.body); } @@ -1611,7 +1661,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -1620,7 +1670,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return {errorRecoveryType(scope)}; } - ExprResult result; + WithPredicate result; if (auto a = expr.as()) result = checkExpr(scope, *a->expr, expectedType); @@ -1682,7 +1732,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& return result; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprLocal is an LValue. @@ -1696,7 +1746,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) { std::optional lvalue = tryGetLValue(expr); LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprGlobal is an LValue. @@ -1708,7 +1758,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) { TypePackId varargPack = checkExprPack(scope, expr).type; @@ -1738,9 +1788,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) { - ExprResult result = checkExprPack(scope, expr); + WithPredicate result = checkExprPack(scope, expr); TypePackId retPack = follow(result.type); if (auto pack = get(retPack)) @@ -1770,7 +1820,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa ice("Unknown TypePack type!", expr.location); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) { Name name = expr.index.value; @@ -2031,7 +2081,7 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) return ty; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { TypeId ty = checkLValue(scope, expr); @@ -2042,7 +2092,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn return {ty}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) { auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); @@ -2108,8 +2158,7 @@ TypeId TypeChecker::checkExprTable( if (errors.empty()) exprType = expectedProp.type; } - else if (expectedTable->indexer && (FFlag::LuauExpectedPropTypeFromIndexer ? maybeString(expectedTable->indexer->indexType) - : isString(expectedTable->indexer->indexType))) + else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) { ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) @@ -2147,7 +2196,7 @@ TypeId TypeChecker::checkExprTable( return addType(table); } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -2201,7 +2250,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; - else if (FFlag::LuauExpectedPropTypeFromIndexer && expectedIndexType && maybeString(*expectedIndexType)) + else if (expectedIndexType && maybeString(*expectedIndexType)) expectedResultType = expectedIndexResultType; } else if (expectedUnion) @@ -2236,9 +2285,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa return {checkExprTable(scope, expr, fieldTypes, expectedType)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) { - ExprResult result = checkExpr(scope, *expr.expr); + WithPredicate result = checkExpr(scope, *expr.expr); TypeId operandType = follow(result.type); switch (expr.op) @@ -2466,62 +2515,50 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); - if (FFlag::LuauSuccessTypingForEqualityOperations) + if (leftMetatable != rightMetatable) { - if (leftMetatable != rightMetatable) + bool matches = false; + if (isEquality) { - bool matches = false; - if (isEquality) + if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) { - if (const UnionTypeVar* utv = get(leftType); utv && rightMetatable) + for (TypeId leftOption : utv) { - for (TypeId leftOption : utv) + if (getMetatable(follow(leftOption)) == rightMetatable) { - if (getMetatable(follow(leftOption)) == rightMetatable) + matches = true; + break; + } + } + } + + if (!matches) + { + if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) + { + for (TypeId rightOption : utv) + { + if (getMetatable(follow(rightOption)) == leftMetatable) { matches = true; break; } } } - - if (!matches) - { - if (const UnionTypeVar* utv = get(rhsType); utv && leftMetatable) - { - for (TypeId rightOption : utv) - { - if (getMetatable(follow(rightOption)) == leftMetatable) - { - matches = true; - break; - } - } - } - } - } - - - if (!matches) - { - reportError( - expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorRecoveryType(booleanType); } } - } - else - { - if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + + + if (!matches) { reportError( expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } + if (leftMetatable) { std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); @@ -2532,7 +2569,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isEquality) { Unifier state = mkUnifier(expr.location); - state.tryUnify(addTypePack({booleanType}), ftv->retType); + state.tryUnify(addTypePack({booleanType}), ftv->retTypes); if (!state.errors.empty()) { @@ -2721,7 +2758,7 @@ TypeId TypeChecker::checkBinaryOperation( } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) { if (expr.op == AstExprBinary::And) { @@ -2752,8 +2789,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); - ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); + WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); + WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); PredicateVec predicates; @@ -2770,18 +2807,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + WithPredicate lhs = checkExpr(scope, *expr.left); + WithPredicate rhs = checkExpr(scope, *expr.right); // Intentionally discarding predicates with other operators. return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; } } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) { TypeId annotationType = resolveType(scope, *expr.annotation); - ExprResult result = checkExpr(scope, *expr.expr, annotationType); + WithPredicate result = checkExpr(scope, *expr.expr, annotationType); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. if (canUnify(annotationType, result.type, expr.location).empty()) @@ -2794,7 +2831,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy return {errorRecoveryType(annotationType), std::move(result.predicates)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) { const size_t oldSize = currentModule->errors.size(); @@ -2808,17 +2845,17 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { - ExprResult result = checkExpr(scope, *expr.condition); + WithPredicate result = checkExpr(scope, *expr.condition); ScopePtr trueScope = childScope(scope, expr.trueExpr->location); resolve(result.predicates, trueScope, true); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); + WithPredicate trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); + WithPredicate falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); if (falseType.type == trueType.type) return {trueType.type}; @@ -3170,7 +3207,7 @@ std::pair TypeChecker::checkFunctionSignature( retPack = anyTypePack; else if (expectedFunctionType) { - auto [head, tail] = flatten(expectedFunctionType->retType); + auto [head, tail] = flatten(expectedFunctionType->retTypes); // Do not infer 'nil' as function return type if (!tail && head.size() == 1 && isNil(head[0])) @@ -3354,7 +3391,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE if (useConstrainedIntersections()) { - TypePackId retPack = follow(funTy->retType); + TypePackId retPack = follow(funTy->retTypes); // It is possible for a function to have no annotation and no return statement, and yet still have an ascribed return type // if it is expected to conform to some other interface. (eg the function may be a lambda passed as a callback) if (!hasReturn(function.body) && !function.returnAnnotation.has_value() && get(retPack)) @@ -3367,20 +3404,20 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE else { // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (get_if(&funTy->retType->ty)) - *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + if (get_if(&funTy->retTypes->ty)) + *asMutable(funTy->retTypes) = TypePack{{}, std::nullopt}; } bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; - if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retType))) + if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retTypes))) { // If we're in nonstrict mode we want to only report this missing return // statement if there are type annotations on the function. In strict mode // we report it regardless. if (!isNonstrictMode() || function.returnAnnotation) { - reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retType}); + reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retTypes}); } } } @@ -3388,7 +3425,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE ice("Checking non functional type"); } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) return checkExprPack(scope, *a); @@ -3654,7 +3691,7 @@ void TypeChecker::checkArgumentList( } } -ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) { // evaluate type of function // decompose an intersection into its component overloads @@ -3722,7 +3759,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); - ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); TypePackId argPack = argListResult.type; if (get(argPack)) @@ -3766,7 +3803,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A if (!overload && !overloadsThatDont.empty()) overload = get(overloadsThatDont[0]); if (overload) - return {errorRecoveryTypePack(overload->retType)}; + return {errorRecoveryTypePack(overload->retTypes)}; return {errorRecoveryTypePack(retPack)}; } @@ -3775,7 +3812,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st { std::vector> expectedTypes; - auto assignOption = [this, &expectedTypes](size_t index, std::optional ty) { + auto assignOption = [this, &expectedTypes](size_t index, TypeId ty) { if (index == expectedTypes.size()) { expectedTypes.push_back(ty); @@ -3790,7 +3827,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } else { - std::vector result = reduceUnion({*el, *ty}); + std::vector result = reduceUnion({*el, ty}); el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); } } @@ -3810,7 +3847,8 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st if (argsTail) { - if (const VariadicTypePack* vtp = get(follow(*argsTail))) + argsTail = follow(*argsTail); + if (const VariadicTypePack* vtp = get(*argsTail)) { while (index < argumentCount) assignOption(index++, vtp->ty); @@ -3819,11 +3857,14 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } } + Demoter demoter{¤tModule->internalTypes}; + demoter.demote(expectedTypes); + return expectedTypes; } -std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, +std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { LUAU_ASSERT(argLocations); @@ -3918,14 +3959,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (ftv->magicFunction) { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 - if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) return *ret; } Unifier state = mkUnifier(expr.location); // Unify return types - checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); + checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { return {}; @@ -3996,7 +4037,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } } @@ -4027,7 +4068,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // we eagerly assume that that's what you actually meant and we commit to it. // This could be incorrect if the function has an additional overload that // actually works. - // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + // checkArgumentList(scope, editedState, retPack, ftv->retTypes, retLocations, CountMismatch::Return); return true; } } @@ -4085,7 +4126,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast // Unify return types if (const FunctionTypeVar* ftv = get(overload)) { - checkArgumentList(scope, state, retPack, ftv->retType, {}); + checkArgumentList(scope, state, retPack, ftv->retTypes, {}); checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } @@ -4110,7 +4151,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast return; } -ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, +WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) { TypePackId pack = addTypePack(TypePack{}); @@ -4401,10 +4442,24 @@ TypeId Anyification::clean(TypeId ty) } else if (auto ctv = get(ty)) { - auto [t, ok] = normalize(ty, *arena, *iceHandler); - if (!ok) - normalizationTooComplex = true; - return t; + if (FFlag::LuauQuantifyConstrained) + { + std::vector copy = ctv->parts; + for (TypeId& ty : copy) + ty = replace(ty); + TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)}); + auto [t, ok] = normalize(res, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; + } + else + { + auto [t, ok] = normalize(ty, *arena, *iceHandler); + if (!ok) + normalizationTooComplex = true; + return t; + } } else return anyType; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index ba09df5..3d97e6e 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -66,7 +66,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t } else if (const auto& itf = get(index)) { - std::optional r = first(follow(itf->retType)); + std::optional r = first(follow(itf->retTypes)); if (!r) return getSingletonTypes().nilType; else diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 33bfe25..5776293 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -29,8 +29,8 @@ LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); TypeId follow(TypeId t) { @@ -408,41 +408,48 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } -FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) +BlockedTypeVar::BlockedTypeVar() + : index(++nextIndex) +{ +} + +int BlockedTypeVar::nextIndex = 0; + +FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) +FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : level(level) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } -FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, +FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retTypes, std::optional defn, bool hasSelf) : generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { } FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, - TypePackId retType, std::optional defn, bool hasSelf) + TypePackId retTypes, std::optional defn, bool hasSelf) : level(level) , generics(generics) , genericPacks(genericPacks) , argTypes(argTypes) - , retType(retType) + , retTypes(retTypes) , definition(std::move(defn)) , hasSelf(hasSelf) { @@ -488,7 +495,7 @@ bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar& if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) return false; - if (!areEqual(seen, *lhs.retType, *rhs.retType)) + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) return false; return true; @@ -678,7 +685,6 @@ static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent* static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; -static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}, /*persistent*/ true}; static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; static TypePackVar errorTypePack_{Unifiable::Error{}}; @@ -692,7 +698,6 @@ SingletonTypes::SingletonTypes() , trueType(&trueType_) , falseType(&falseType_) , anyType(&anyType_) - , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) , arena(new TypeArena) { @@ -825,7 +830,7 @@ void persist(TypeId ty) else if (auto ftv = get(t)) { persist(ftv->argTypes); - persist(ftv->retType); + persist(ftv->retTypes); } else if (auto ttv = get(t)) { @@ -1100,10 +1105,10 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha return result; } -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { - auto [paramPack, _predicates] = exprResult; + auto [paramPack, _predicates] = withPredicate; TypeArena& arena = typechecker.currentModule->internalTypes; @@ -1142,7 +1147,7 @@ std::optional> magicFunctionFormat( if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); - return ExprResult{arena.addTypePack({typechecker.stringType})}; + return WithPredicate{arena.addTypePack({typechecker.stringType})}; } std::vector filterMap(TypeId type, TypeIdPredicate predicate) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 414b05f..877663d 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) +LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau { @@ -1288,13 +1289,13 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); + innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); if (!reported) { if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); - else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) + else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) reportError( @@ -1312,7 +1313,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - tryUnify_(subFunction->retType, superFunction->retType); + tryUnify_(subFunction->retTypes, superFunction->retTypes); } if (FFlag::LuauTxnLogRefreshFunctionPointers) @@ -2177,7 +2178,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas else if (auto fun = state.log.getMutable(ty)) { queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retTypes, anyTypePack); } else if (auto table = state.log.getMutable(ty)) { @@ -2322,7 +2323,7 @@ void Unifier::tryUnifyWithConstrainedSuperTypeVar(TypeId subTy, TypeId superTy) superC->parts.push_back(subTy); } -void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) +void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel) { // The duplication between this and regular typepack unification is tragic. @@ -2357,7 +2358,7 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy) if (!freeTailPack) return; - TypeLevel level = freeTailPack->level; + TypeLevel level = FFlag::LuauQuantifyConstrained ? demotedLevel : freeTailPack->level; TypePack* tp = getMutable(log.replace(tailPack, TypePack{})); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 597b2f0..a34f760 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1075,6 +1075,8 @@ void BytecodeBuilder::validate() const LUAU_ASSERT(i <= insns.size()); } + std::vector openCaptures; + // second pass: validate the rest of the bytecode for (size_t i = 0; i < insns.size();) { @@ -1121,6 +1123,8 @@ void BytecodeBuilder::validate() const case LOP_CLOSEUPVALS: VREG(LUAU_INSN_A(insn)); + while (openCaptures.size() && openCaptures.back() >= LUAU_INSN_A(insn)) + openCaptures.pop_back(); break; case LOP_GETIMPORT: @@ -1388,8 +1392,12 @@ void BytecodeBuilder::validate() const switch (LUAU_INSN_A(insn)) { case LCT_VAL: + VREG(LUAU_INSN_B(insn)); + break; + case LCT_REF: VREG(LUAU_INSN_B(insn)); + openCaptures.push_back(LUAU_INSN_B(insn)); break; case LCT_UPVAL: @@ -1409,6 +1417,12 @@ void BytecodeBuilder::validate() const LUAU_ASSERT(i <= insns.size()); } + // all CAPTURE REF instructions must have a CLOSEUPVALS instruction after them in the bytecode stream + // this doesn't guarantee safety as it doesn't perform basic block based analysis, but if this fails + // then the bytecode is definitely unsafe to run since the compiler won't generate backwards branches + // except for loop edges + LUAU_ASSERT(openCaptures.empty()); + #undef VREG #undef VREGEND #undef VUPVAL diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7431cde..52dc924 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -246,6 +246,14 @@ struct Compiler f.canInline = true; f.stackSize = stackSize; f.costModel = modelCost(func->body, func->args.data, func->args.size); + + // track functions that only ever return a single value so that we can convert multret calls to fixedret calls + if (allPathsEndWithReturn(func->body)) + { + ReturnVisitor returnVisitor(this); + stat->visit(&returnVisitor); + f.returnsOne = returnVisitor.returnsOne; + } } upvals.clear(); // note: instead of std::move above, we copy & clear to preserve capacity for future pushes @@ -260,6 +268,19 @@ struct Compiler { if (AstExprCall* expr = node->as()) { + // Optimization: convert multret calls to functions that always return one value to fixedret calls; this facilitates inlining + if (options.optimizationLevel >= 2) + { + AstExprFunction* func = getFunctionExpr(expr->func); + Function* fi = func ? functions.find(func) : nullptr; + + if (fi && fi->returnsOne) + { + compileExprTemp(node, target); + return false; + } + } + // We temporarily swap out regTop to have targetTop work correctly... // This is a crude hack but it's necessary for correctness :( RegScope rs(this, target); @@ -447,7 +468,9 @@ struct Compiler return false; } - // TODO: we can compile multret functions if all returns of the function are multret as well + // we can't inline multret functions because the caller expects L->top to be adjusted: + // - inlined return compiles to a JUMP, and we don't have an instruction that adjusts L->top arbitrarily + // - even if we did, right now all L->top adjustments are immediately consumed by the next instruction, and for now we want to preserve that if (multRet) { bytecode.addDebugRemark("inlining failed: can't convert fixed returns to multret"); @@ -492,7 +515,7 @@ struct Compiler size_t oldLocals = localStack.size(); // note that we push the frame early; this is needed to block recursive inline attempts - inlineFrames.push_back({func, target, targetCount}); + inlineFrames.push_back({func, oldLocals, target, targetCount}); // evaluate all arguments; note that we don't emit code for constant arguments (relying on constant folding) for (size_t i = 0; i < func->args.size; ++i) @@ -593,6 +616,8 @@ struct Compiler { for (size_t i = 0; i < targetCount; ++i) bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0); + + closeLocals(oldLocals); } popLocals(oldLocals); @@ -2355,6 +2380,8 @@ struct Compiler compileExprListTemp(stat->list, frame.target, frame.targetCount, /* targetTop= */ false); + closeLocals(frame.localOffset); + if (!fallthrough) { size_t jumpLabel = bytecode.emitLabel(); @@ -3316,6 +3343,48 @@ struct Compiler std::vector upvals; }; + struct ReturnVisitor: AstVisitor + { + Compiler* self; + bool returnsOne = true; + + ReturnVisitor(Compiler* self) + : self(self) + { + } + + bool visit(AstExpr* expr) override + { + return false; + } + + bool visit(AstStatReturn* stat) override + { + if (stat->list.size == 1) + { + AstExpr* value = stat->list.data[0]; + + if (AstExprCall* expr = value->as()) + { + AstExprFunction* func = self->getFunctionExpr(expr->func); + Function* fi = func ? self->functions.find(func) : nullptr; + + returnsOne &= fi && fi->returnsOne; + } + else if (value->is()) + { + returnsOne = false; + } + } + else + { + returnsOne = false; + } + + return false; + } + }; + struct RegScope { RegScope(Compiler* self) @@ -3351,6 +3420,7 @@ struct Compiler uint64_t costModel = 0; unsigned int stackSize = 0; bool canInline = false; + bool returnsOne = false; }; struct Local @@ -3384,6 +3454,8 @@ struct Compiler { AstExprFunction* func; + size_t localOffset; + uint8_t target; uint8_t targetCount; diff --git a/Sources.cmake b/Sources.cmake index 99007e8..f261cba 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -65,12 +65,13 @@ target_sources(Luau.CodeGen PRIVATE target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h - Analysis/include/Luau/NotNull.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Clone.h Analysis/include/Luau/Config.h + Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h + Analysis/include/Luau/ConstraintSolverLogger.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h @@ -97,6 +98,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TxnLog.h Analysis/include/Luau/TypeArena.h Analysis/include/Luau/TypeAttach.h + Analysis/include/Luau/TypeChecker2.h Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypePack.h @@ -113,8 +115,10 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/BuiltinDefinitions.cpp Analysis/src/Clone.cpp Analysis/src/Config.cpp + Analysis/src/Constraint.cpp Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintSolver.cpp + Analysis/src/ConstraintSolverLogger.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp Analysis/src/Instantiation.cpp @@ -136,6 +140,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TxnLog.cpp Analysis/src/TypeArena.cpp Analysis/src/TypeAttach.cpp + Analysis/src/TypeChecker2.cpp Analysis/src/TypedAllocator.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypePack.cpp @@ -245,7 +250,6 @@ if(TARGET Luau.UnitTest) tests/AstQuery.test.cpp tests/AstVisitor.test.cpp tests/Autocomplete.test.cpp - tests/NotNull.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 5e02c2e..bdcb85c 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -418,7 +418,7 @@ typedef struct Table CommonHeader; - uint8_t flags; /* 1<

flags = 0 +#define invalidateTMcache(t) t->tmcache = 0 // empty hash data points to dummynode so that we can always dereference it const LuaNode luaH_dummynode = { @@ -479,7 +479,7 @@ Table* luaH_new(lua_State* L, int narray, int nhash) Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = NULL; - t->flags = cast_byte(~0); + t->tmcache = cast_byte(~0); t->array = NULL; t->sizearray = 0; t->lastfree = 0; @@ -778,7 +778,7 @@ Table* luaH_clone(lua_State* L, Table* tt) Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = tt->metatable; - t->flags = tt->flags; + t->tmcache = tt->tmcache; t->array = NULL; t->sizearray = 0; t->lsizenode = 0; @@ -835,5 +835,5 @@ void luaH_clear(Table* tt) } /* back to empty -> no tag methods present */ - tt->flags = cast_byte(~0); + tt->tmcache = cast_byte(~0); } diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 9b99506..e7df4e5 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -88,8 +88,8 @@ const TValue* luaT_gettm(Table* events, TMS event, TString* ename) const TValue* tm = luaH_getstr(events, ename); LUAU_ASSERT(event <= TM_EQ); if (ttisnil(tm)) - { /* no tag method? */ - events->flags |= cast_byte(1u << event); /* cache this fact */ + { /* no tag method? */ + events->tmcache |= cast_byte(1u << event); /* cache this fact */ return NULL; } else diff --git a/VM/src/ltm.h b/VM/src/ltm.h index e1b95c2..a522394 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -41,10 +41,10 @@ typedef enum } TMS; // clang-format on -#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->flags & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) +#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->tmcache & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) #define fasttm(l, et, e) gfasttm(l->global, et, e) -#define fastnotm(et, e) ((et) == NULL || ((et)->flags & (1u << (e)))) +#define fastnotm(et, e) ((et) == NULL || ((et)->tmcache & (1u << (e)))) LUAI_DATA const char* const luaT_typenames[]; LUAI_DATA const char* const luaT_eventname[]; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index dea1ab1..f3b0bca 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1992,6 +1992,7 @@ local fp: @1= f auto ac = autocomplete('1'); + REQUIRE_EQ("({| x: number, y: number |}) -> number", toString(requireType("f"))); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } @@ -2620,7 +2621,6 @@ a = if temp then even elseif true then temp else e@9 TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_else_regression") { - ScopedFastFlag FFlagLuauIfElseExprFixCompletionIssue("LuauIfElseExprFixCompletionIssue", true); check(R"( local abcdef = 0; local temp = false diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 6eee254..036bf12 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -4992,6 +4992,147 @@ RETURN R1 1 )"); } +TEST_CASE("InlineCapture") +{ + // if the argument is captured by a nested closure, normally we can rely on capture by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +NEWCLOSURE R2 P1 +CAPTURE VAL R1 +RETURN R2 1 +)"); + + // if the argument is a constant, we move it to a register so that capture by value can happen + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local y = foo(42) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R2 42 +NEWCLOSURE R1 P1 +CAPTURE VAL R2 +RETURN R1 1 +)"); + + // if the argument is an externally mutated variable, we copy it to an argument and capture it by value + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return function() return a end +end + +local x x = 42 +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R1 +LOADN R1 42 +MOVE R3 R1 +NEWCLOSURE R2 P1 +CAPTURE VAL R3 +RETURN R2 1 +)"); + + // finally, if the argument is mutated internally, we must capture it by reference and close the upvalue + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + return function() return a end +end + +local y = foo() +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +LOADNIL R2 +ORK R2 R2 K1 +NEWCLOSURE R1 P1 +CAPTURE REF R2 +CLOSEUPVALS R2 +RETURN R1 1 +)"); + + // note that capture might need to be performed during the fallthrough block + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + a = a or 42 + print(function() return a end) +end + +local x = ... +local y = foo(x) +return y +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +MOVE R3 R1 +ORK R3 R3 K1 +GETIMPORT R4 3 +NEWCLOSURE R5 P1 +CAPTURE REF R3 +CALL R4 1 0 +LOADNIL R2 +CLOSEUPVALS R3 +RETURN R2 1 +)"); + + // note that mutation and capture might be inside internal control flow + // TODO: this has an oddly redundant CLOSEUPVALS after JUMP; it's not due to inlining, and is an artifact of how StatBlock/StatReturn interact + // fixing this would reduce the number of redundant CLOSEUPVALS a bit but it only affects bytecode size as these instructions aren't executed + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + if not a then + local b b = 42 + return function() return b end + end +end + +local x = ... +local y = foo(x) +return y, x +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +GETVARARGS R1 1 +JUMPIF R1 L0 +LOADNIL R3 +LOADN R3 42 +NEWCLOSURE R2 P1 +CAPTURE REF R3 +CLOSEUPVALS R3 +JUMP L1 +CLOSEUPVALS R3 +L0: LOADNIL R2 +L1: MOVE R3 R2 +MOVE R4 R1 +RETURN R3 2 +)"); +} + TEST_CASE("InlineFallthrough") { // if the function doesn't return, we still fill the results with nil @@ -5044,27 +5185,6 @@ RETURN R1 -1 )"); } -TEST_CASE("InlineCapture") -{ - // can't inline function with nested functions that capture locals because they might be constants - CHECK_EQ("\n" + compileFunction(R"( -local function foo(a) - local function bar() - return a - end - return bar() -end -)", - 1, 2), - R"( -NEWCLOSURE R1 P0 -CAPTURE VAL R0 -MOVE R2 R1 -CALL R2 0 -1 -RETURN R2 -1 -)"); -} - TEST_CASE("InlineArgMismatch") { // when inlining a function, we must respect all the usual rules @@ -5491,6 +5611,96 @@ RETURN R2 1 )"); } +TEST_CASE("InlineMultret") +{ + // inlining a function in multret context is prohibited since we can't adjust L->top outside of CALL/GETVARARGS + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a() +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // however, if we can deduce statically that a function always returns a single value, the inlining will work + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +LOADN R1 42 +RETURN R1 1 +)"); + + // this analysis will also propagate through other functions + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return a +end + +local function bar(a) + return foo(a) +end + +return bar(42) +)", + 2, 2), + R"( +DUPCLOSURE R0 K0 +DUPCLOSURE R1 K1 +LOADN R2 42 +RETURN R2 1 +)"); + + // we currently don't do this analysis fully for recursive functions since they can't be inlined anyway + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return foo(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE VAL R0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // and unfortunately we can't do this analysis for builtins or method calls due to getfenv + CHECK_EQ("\n" + compileFunction(R"( +local function foo(a) + return math.abs(a) +end + +return foo(42) +)", + 1, 2), + R"( +DUPCLOSURE R0 K0 +MOVE R1 R0 +LOADN R2 42 +CALL R1 1 -1 +RETURN R1 -1 +)"); +} + TEST_CASE("ReturnConsecutive") { // we can return a single local directly diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index ab5af4f..96b2161 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -17,13 +17,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(2 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); - CHECK("b <: a" == toString(*constraints[1], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("a <: b" == toString(*constraints[1], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") @@ -36,15 +36,34 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); - REQUIRE(4 == constraints.size()); + REQUIRE(3 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); - CHECK("b <: number" == toString(*constraints[1], opts)); - CHECK("c <: boolean" == toString(*constraints[2], opts)); - CHECK("d <: nil" == toString(*constraints[3], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); + CHECK("number <: b" == toString(*constraints[1], opts)); + CHECK("boolean <: c" == toString(*constraints[2], opts)); +} + +TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive") +{ + AstStatBlock* block = parse(R"( + local function a() return nil end + local b = a() + )"); + + cgb.visit(block); + auto constraints = collectConstraints(cgb.rootScope); + + ToStringOptions opts; + REQUIRE(5 <= constraints.size()); + + CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts)); + CHECK("b ~ inst *blocked-1*" == toString(*constraints[1], opts)); + CHECK("() -> (c...) <: b" == toString(*constraints[2], opts)); + CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("nil <: a..." == toString(*constraints[4], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") @@ -55,15 +74,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(4 == constraints.size()); ToStringOptions opts; - CHECK("a <: string" == toString(*constraints[0], opts)); + CHECK("string <: a" == toString(*constraints[0], opts)); CHECK("b ~ inst a" == toString(*constraints[1], opts)); - CHECK("(string) -> (c, d...) <: b" == toString(*constraints[2], opts)); - CHECK("e <: c" == toString(*constraints[3], opts)); + CHECK("(string) -> (c...) <: b" == toString(*constraints[2], opts)); + CHECK("c... <: d" == toString(*constraints[3], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") @@ -75,13 +94,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(2 == constraints.size()); ToStringOptions opts; - CHECK("a ~ gen (b) -> (c...)" == toString(*constraints[0], opts)); - CHECK("b <: c..." == toString(*constraints[1], opts)); + CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); + CHECK("a <: b..." == toString(*constraints[1], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") @@ -93,15 +112,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") )"); cgb.visit(block); - std::vector constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(cgb.rootScope); REQUIRE(4 == constraints.size()); ToStringOptions opts; - CHECK("a ~ gen (b) -> (c...)" == toString(*constraints[0], opts)); - CHECK("d ~ inst a" == toString(*constraints[1], opts)); - CHECK("(b) -> (e, f...) <: d" == toString(*constraints[2], opts)); - CHECK("e <: c..." == toString(*constraints[3], opts)); + CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); + CHECK("c ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); + CHECK("(a) -> (d...) <: c" == toString(*constraints[2], opts)); + CHECK("d... <: b..." == toString(*constraints[3], opts)); } TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 232ec2d..ac22f65 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -345,7 +345,7 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) if (error.location.begin.line >= lines.size()) { os << "\tSource not available?" << std::endl; - return; + continue; } std::string_view theLine = lines[error.location.begin.line]; @@ -430,6 +430,7 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { + BlockedTypeVar::nextIndex = 0; } ModuleName fromString(std::string_view name) diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index c055466..b9c2470 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -97,8 +97,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); - CHECK_EQ(res.requires[0].first, "Modules/Foo/Bar"); + CHECK_EQ(1, res.requireList.size()); + CHECK_EQ(res.requireList[0].first, "Modules/Foo/Bar"); } // It could be argued that this should not work. @@ -113,7 +113,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "find_a_require_inside_a_function") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(1, res.requires.size()); + CHECK_EQ(1, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "real_source") @@ -138,7 +138,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "real_source") NaiveFileResolver naiveFileResolver; auto res = traceRequires(&naiveFileResolver, program, ""); - CHECK_EQ(8, res.requires.size()); + CHECK_EQ(8, res.requireList.size()); } TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 89b13ab..d585b73 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -102,7 +102,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") const FunctionTypeVar* ftv = get(methodType); REQUIRE(ftv != nullptr); - std::optional methodReturnType = first(ftv->retType); + std::optional methodReturnType = first(ftv->retTypes); REQUIRE(methodReturnType); CHECK_EQ(methodReturnType, counterCopy); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index c055610..50dcbad 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -13,6 +13,57 @@ using namespace Luau; TEST_SUITE_BEGIN("NonstrictModeTests"); +TEST_CASE_FIXTURE(Fixture, "globals") +{ + CheckResult result = check(R"( + --!nonstrict + foo = true + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals2") +{ + ScopedFastFlag sff[]{ + {"LuauReturnTypeInferenceInNonstrict", true}, + {"LuauLowerBoundsCalculation", true}, + }; + + CheckResult result = check(R"( + --!nonstrict + foo = function() return 1 end + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> number", toString(tm->wantedType)); + CHECK_EQ("string", toString(tm->givenType)); + CHECK_EQ("() -> number", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals_everywhere") +{ + CheckResult result = check(R"( + --!nonstrict + foo = 1 + + if true then + bar = 2 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "function_returns_number_or_string") { ScopedFastFlag sff[]{{"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}}; @@ -51,7 +102,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") REQUIRE_EQ("any", toString(args[0])); REQUIRE_EQ("any", toString(args[1])); - auto rets = flatten(ftv->retType).first; + auto rets = flatten(ftv->retTypes).first; REQUIRE_EQ(0, rets.size()); } diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 2876175..284230c 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -837,6 +837,7 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -867,16 +868,17 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect CHECK("{+ y: number +}" == toString(args[2])); CHECK("{+ z: string +}" == toString(args[3])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("{| x: a & {- w: boolean, y: number, z: string -} |}" == toString(ret[0])); + CHECK("{| x: a & {+ w: boolean, y: number, z: string +} |}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_3") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -906,16 +908,17 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect CHECK("t1 where t1 = {+ y: t1 +}" == toString(args[1])); CHECK("{+ z: string +}" == toString(args[2])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("{| x: {- x: boolean, y: t1, z: string -} |} where t1 = {+ y: t1 +}" == toString(ret[0])); + CHECK("{| x: {+ x: boolean, y: t1, z: string +} |} where t1 = {+ y: t1 +}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersection_4") { ScopedFastFlag flags[] = { {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, }; // We use a function and inferred parameter types to prevent intermediate normalizations from being performed. @@ -944,13 +947,13 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect REQUIRE(3 == args.size()); CHECK("{+ x: boolean +}" == toString(args[0])); - CHECK("{+ y: t1 +} where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(args[1])); + CHECK("{+ y: t1 +} where t1 = {| x: {+ x: boolean, y: t1, z: string +} |}" == toString(args[1])); CHECK("{+ z: string +}" == toString(args[2])); - std::vector ret = flatten(ftv->retType).first; + std::vector ret = flatten(ftv->retTypes).first; REQUIRE(1 == ret.size()); - CHECK("t1 where t1 = {| x: {- x: boolean, y: t1, z: string -} |}" == toString(ret[0])); + CHECK("t1 where t1 = {| x: {+ x: boolean, y: t1, z: string +} |}" == toString(ret[0])); } TEST_CASE_FIXTURE(Fixture, "nested_table_normalization_with_non_table__no_ice") @@ -1062,4 +1065,29 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + --!strict + local function f() + if math.random() > 0.5 then + return true + end + type Ret = typeof(f()) + if math.random() > 0.5 then + return "something" + end + return "something" :: Ret + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("() -> boolean | string", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp index 1a323c8..ed1c25e 100644 --- a/tests/NotNull.test.cpp +++ b/tests/NotNull.test.cpp @@ -75,9 +75,9 @@ TEST_CASE("basic_stuff") t->y = 3.14f; const NotNull u = t; - // u->x = 44; // nope + u->x = 44; int v = u->x; - CHECK(v == 5); + CHECK(v == 44); bar(a); @@ -96,8 +96,11 @@ TEST_CASE("basic_stuff") TEST_CASE("hashable") { std::unordered_map, const char*> map; - NotNull a{new int(8)}; - NotNull b{new int(10)}; + int a_ = 8; + int b_ = 10; + + NotNull a{&a_}; + NotNull b{&b_}; std::string hello = "hello"; std::string world = "world"; @@ -108,9 +111,47 @@ TEST_CASE("hashable") CHECK_EQ(2, map.size()); CHECK_EQ(hello.c_str(), map[a]); CHECK_EQ(world.c_str(), map[b]); +} - delete a; - delete b; +TEST_CASE("const") +{ + int p = 0; + int q = 0; + + NotNull n{&p}; + + *n = 123; + + NotNull m = n; // Conversion from NotNull to NotNull is allowed + + CHECK(123 == *m); // readonly access of m is ok + + // *m = 321; // nope. m points at const data. + + // NotNull o = m; // nope. Conversion from NotNull to NotNull is forbidden + + NotNull n2{&q}; + m = n2; // ok. m points to const data, but is not itself const + + const NotNull m2 = n; + // m2 = n2; // nope. m2 is const. + *m2 = 321; // ok. m2 is const, but points to mutable data + + CHECK(321 == *n); +} + +TEST_CASE("const_compatibility") +{ + int* raw = new int(8); + + NotNull a(raw); + NotNull b(raw); + NotNull c = a; + // NotNull d = c; // nope - no conversion from const to non-const + + CHECK_EQ(*c, 8); + + delete raw; } TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 4d9fad1..4d2e94e 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -60,6 +60,42 @@ TEST_CASE_FIXTURE(Fixture, "named_table") CHECK_EQ("TheTable", toString(&table)); } +TEST_CASE_FIXTURE(Fixture, "empty_table") +{ + ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true); + CheckResult result = check(R"( + local a: {} + )"); + + CHECK_EQ("{| |}", toString(requireType("a"))); + + // Should stay the same with useLineBreaks enabled + ToStringOptions opts; + opts.useLineBreaks = true; + CHECK_EQ("{| |}", toString(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") +{ + ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true); + CheckResult result = check(R"( + local a: { prop: string, anotherProp: number, thirdProp: boolean } + )"); + + ToStringOptions opts; + opts.useLineBreaks = true; + opts.indent = true; + + //clang-format off + CHECK_EQ("{|\n" + " anotherProp: number,\n" + " prop: string,\n" + " thirdProp: boolean\n" + "|}", + toString(requireType("a"), opts)); + //clang-format on +} + TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index b9e1ae9..ccdd2b3 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -70,7 +70,7 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") const FunctionTypeVar* ftv = get(fiftyType); REQUIRE(ftv != nullptr); - TypePackId retPack = ftv->retType; + TypePackId retPack = ftv->retTypes; const TypePack* tp = get(retPack); REQUIRE(tp != nullptr); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index a28ba49..036a667 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -45,7 +45,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type") const FunctionTypeVar* takeFiveType = get(requireType("take_five")); REQUIRE(takeFiveType != nullptr); - std::vector retVec = flatten(takeFiveType->retType).first; + std::vector retVec = flatten(takeFiveType->retTypes).first; REQUIRE(!retVec.empty()); REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); @@ -345,7 +345,7 @@ TEST_CASE_FIXTURE(Fixture, "local_function") const FunctionTypeVar* ftv = get(h); REQUIRE(ftv != nullptr); - std::optional rt = first(ftv->retType); + std::optional rt = first(ftv->retTypes); REQUIRE(bool(rt)); TypeId retType = follow(*rt); @@ -361,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") LUAU_REQUIRE_NO_ERRORS(result); const Luau::FunctionTypeVar* fn = get(requireType("p")); REQUIRE(fn); - auto ret = first(fn->retType); + auto ret = first(fn->retTypes); REQUIRE(ret); REQUIRE(get(follow(*ret))); } @@ -460,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotat const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); - std::optional retType = first(functionType->retType); + std::optional retType = first(functionType->retTypes); REQUIRE(retType); CHECK(get(*retType)); } @@ -1619,4 +1619,56 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "quantify_constrained_types") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + --!strict + local function foo(f) + f(5) + f("hi") + local function g() + return f + end + local h = g() + h(true) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((boolean | number | string) -> (a...)) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantified") +{ + ScopedFastFlag sff[]{ + {"LuauLowerBoundsCalculation", true}, + {"LuauQuantifyConstrained", true}, + }; + + CheckResult result = check(R"( + local function f(o) + local t = {} + t[o] = true + + local function foo(o) + o.m1(5) + t[o] = nil + end + + o.m1("hi") + + return t + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // TODO: check the normalized type of f +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index fbda8be..edb5adc 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -224,7 +224,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); @@ -247,7 +247,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") const FunctionTypeVar* idFun = get(idType); REQUIRE(idFun); auto [args, varargs] = flatten(idFun->argTypes); - auto [rets, varrets] = flatten(idFun->retType); + auto [rets, varrets] = flatten(idFun->retTypes); CHECK_EQ(idFun->generics.size(), 1); CHECK_EQ(idFun->genericPacks.size(), 0); @@ -882,7 +882,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") const FunctionTypeVar* foo = get(follow(fooProp->type)); REQUIRE(bool(foo)); - std::optional ret_ = first(foo->retType); + std::optional ret_ = first(foo->retTypes); REQUIRE(bool(ret_)); TypeId ret = follow(*ret_); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 0361493..fd9b1dd 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -90,7 +90,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") const FunctionTypeVar* functionType = get(requireType("add")); - std::optional retType = first(functionType->retType); + std::optional retType = first(functionType->retTypes); REQUIRE(retType.has_value()); CHECK_EQ(typeChecker.numberType, follow(*retType)); CHECK_EQ(requireType("n"), typeChecker.numberType); @@ -777,8 +777,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_branch_succeeds") { - ScopedFastFlag sff("LuauSuccessTypingForEqualityOperations", true); - CheckResult result = check(R"( local mm = {} type Foo = typeof(setmetatable({}, mm)) diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 22fb3b6..487e597 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -472,6 +472,7 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") ScopedFastFlag sff[]{ {"LuauLowerBoundsCalculation", true}, {"LuauNormalizeFlagIsConservative", true}, + {"LuauQuantifyConstrained", true}, }; CheckResult result = check(R"( @@ -494,8 +495,8 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") )"); LUAU_REQUIRE_NO_ERRORS(result); - // TODO: We're missing generics a... and b... - CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); + // TODO: We're missing generics b... + CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 207b3cf..cefba4b 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -13,8 +13,8 @@ using namespace Luau; namespace { -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { if (expr.args.size != 1) return std::nullopt; @@ -32,7 +32,7 @@ std::optional> magicFunctionInstanceIsA( unfreeze(typeChecker.globalTypes); TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); freeze(typeChecker.globalTypes); - return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; + return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } struct RefinementClassFixture : Fixture diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index d622d4a..87d4965 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -642,7 +642,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") const TableTypeVar* argType = get(follow(argVec[0])); REQUIRE(argType != nullptr); - std::vector retVec = flatten(ftv->retType).first; + std::vector retVec = flatten(ftv->retTypes).first; const TableTypeVar* retType = get(follow(retVec[0])); REQUIRE(retType != nullptr); @@ -691,7 +691,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") const FunctionTypeVar* fType = get(requireType("f")); REQUIRE(fType != nullptr); - auto retType_ = first(fType->retType); + auto retType_ = first(fType->retTypes); REQUIRE(bool(retType_)); auto retType = get(follow(*retType_)); @@ -1881,7 +1881,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") REQUIRE(prop.type); const FunctionTypeVar* ftv = get(follow(prop.type)); REQUIRE(ftv); - const TypePack* res = get(follow(ftv->retType)); + const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); REQUIRE(res->head.size() == 1); const MetatableTypeVar* mtv = get(follow(res->head[0])); @@ -2584,7 +2584,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); REQUIRE(newType); - std::optional newRetType = *first(newType->retType); + std::optional newRetType = *first(newType->retTypes); REQUIRE(newRetType); const MetatableTypeVar* newRet = get(follow(*newRetType)); @@ -2977,7 +2977,6 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") { - ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; ScopedFastFlag luauSubtypingAddOptPropsToUnsealedTables{"LuauSubtypingAddOptPropsToUnsealedTables", true}; CheckResult result = check(R"( @@ -2992,8 +2991,6 @@ TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra_2") { - ScopedFastFlag luauExpectedPropTypeFromIndexer{"LuauExpectedPropTypeFromIndexer", true}; - CheckResult result = check(R"( type X = {[any]: string | boolean} diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index cf0c988..6257cda 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -13,8 +13,9 @@ #include -LUAU_FASTFLAG(LuauLowerBoundsCalculation) -LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) +LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); using namespace Luau; @@ -43,10 +44,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_error") CheckResult result = check("local a = 7 local b = 'hi' a = b"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{ - requireType("a"), - requireType("b"), - }})); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); } TEST_CASE_FIXTURE(Fixture, "tc_error_2") @@ -86,6 +84,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { ScopedFastFlag sff[]{ + {"DebugLuauDeferredConstraintResolution", false}, {"LuauReturnTypeInferenceInNonstrict", true}, {"LuauLowerBoundsCalculation", true}, }; @@ -236,10 +235,14 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK_EQ("*unknown*", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); - CHECK_EQ("*unknown*", toString(requireType("e"))); - CHECK_EQ("*unknown*", toString(requireType("f"))); + // TODO: Should we assert anything about these tests when DCR is being used? + if (!FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); + } } TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") @@ -352,40 +355,6 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") CHECK(nullptr != get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "globals") -{ - CheckResult result = check(R"( - --!nonstrict - foo = true - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals2") -{ - ScopedFastFlag sff[]{ - {"LuauReturnTypeInferenceInNonstrict", true}, - {"LuauLowerBoundsCalculation", true}, - }; - - CheckResult result = check(R"( - --!nonstrict - foo = function() return 1 end - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> number", toString(tm->wantedType)); - CHECK_EQ("string", toString(tm->givenType)); - CHECK_EQ("() -> number", toString(requireType("foo"))); -} - TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") { CheckResult result = check(R"( @@ -400,23 +369,6 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(Fixture, "globals_everywhere") -{ - CheckResult result = check(R"( - --!nonstrict - foo = 1 - - if true then - bar = 2 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") { CheckResult result = check(R"( @@ -447,21 +399,6 @@ TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") CHECK_EQ("any", toString(requireType("value"))); } -// TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") -// { -// CheckResult result = check(R"( -// function f(a) -// if a.cond then -// return a.method() -// end -// end -// )"); - -// LUAU_REQUIRE_NO_ERRORS(result); - -// CHECK_EQ("A", toString(requireType("f"))); -// } - TEST_CASE_FIXTURE(Fixture, "cyclic_follow") { check(R"( diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 118863f..bcd3049 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -26,7 +26,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const FunctionTypeVar* takeTwoType = get(requireType("take_two")); REQUIRE(takeTwoType != nullptr); - const auto& [returns, tail] = flatten(takeTwoType->retType); + const auto& [returns, tail] = flatten(takeTwoType->retTypes); CHECK_EQ(2, returns.size()); CHECK_EQ(typeChecker.numberType, follow(returns[0])); @@ -73,7 +73,7 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const FunctionTypeVar* takeOneMoreType = get(requireType("take_three")); REQUIRE(takeOneMoreType != nullptr); - const auto& [rets, tail] = flatten(takeOneMoreType->retType); + const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); REQUIRE_EQ(3, rets.size()); CHECK_EQ(typeChecker.numberType, follow(rets[0])); @@ -105,10 +105,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") LUAU_REQUIRE_NO_ERRORS(result); const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(0, size(fTy->retType)); + CHECK_EQ(0, size(fTy->retTypes)); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") @@ -125,15 +125,15 @@ TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") const FunctionTypeVar* fTy = get(requireType("f")); REQUIRE(fTy != nullptr); - CHECK_EQ(1, size(follow(fTy->retType))); + CHECK_EQ(1, size(follow(fTy->retTypes))); const FunctionTypeVar* gTy = get(requireType("g")); REQUIRE(gTy != nullptr); - CHECK_EQ(0, size(gTy->retType)); + CHECK_EQ(0, size(gTy->retTypes)); const FunctionTypeVar* hTy = get(requireType("h")); REQUIRE(hTy != nullptr); - CHECK_EQ(0, size(hTy->retType)); + CHECK_EQ(0, size(hTy->retTypes)); } TEST_CASE_FIXTURE(Fixture, "varargs_inference_through_multiple_scopes") diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis index 5de0140..b9ea314 100644 --- a/tools/natvis/Analysis.natvis +++ b/tools/natvis/Analysis.natvis @@ -6,40 +6,40 @@ - {{ index=0, value={*($T1*)storage} }} - {{ index=1, value={*($T2*)storage} }} - {{ index=2, value={*($T3*)storage} }} - {{ index=3, value={*($T4*)storage} }} - {{ index=4, value={*($T5*)storage} }} - {{ index=5, value={*($T6*)storage} }} - {{ index=6, value={*($T7*)storage} }} - {{ index=7, value={*($T8*)storage} }} - {{ index=8, value={*($T9*)storage} }} - {{ index=9, value={*($T10*)storage} }} - {{ index=10, value={*($T11*)storage} }} - {{ index=11, value={*($T12*)storage} }} - {{ index=12, value={*($T13*)storage} }} - {{ index=13, value={*($T14*)storage} }} - {{ index=14, value={*($T15*)storage} }} - {{ index=15, value={*($T16*)storage} }} - {{ index=16, value={*($T17*)storage} }} - {{ index=17, value={*($T18*)storage} }} - {{ index=18, value={*($T19*)storage} }} - {{ index=19, value={*($T20*)storage} }} - {{ index=20, value={*($T21*)storage} }} - {{ index=21, value={*($T22*)storage} }} - {{ index=22, value={*($T23*)storage} }} - {{ index=23, value={*($T24*)storage} }} - {{ index=24, value={*($T25*)storage} }} - {{ index=25, value={*($T26*)storage} }} - {{ index=26, value={*($T27*)storage} }} - {{ index=27, value={*($T28*)storage} }} - {{ index=28, value={*($T29*)storage} }} - {{ index=29, value={*($T30*)storage} }} - {{ index=30, value={*($T31*)storage} }} - {{ index=31, value={*($T32*)storage} }} + {{ typeId=0, value={*($T1*)storage} }} + {{ typeId=1, value={*($T2*)storage} }} + {{ typeId=2, value={*($T3*)storage} }} + {{ typeId=3, value={*($T4*)storage} }} + {{ typeId=4, value={*($T5*)storage} }} + {{ typeId=5, value={*($T6*)storage} }} + {{ typeId=6, value={*($T7*)storage} }} + {{ typeId=7, value={*($T8*)storage} }} + {{ typeId=8, value={*($T9*)storage} }} + {{ typeId=9, value={*($T10*)storage} }} + {{ typeId=10, value={*($T11*)storage} }} + {{ typeId=11, value={*($T12*)storage} }} + {{ typeId=12, value={*($T13*)storage} }} + {{ typeId=13, value={*($T14*)storage} }} + {{ typeId=14, value={*($T15*)storage} }} + {{ typeId=15, value={*($T16*)storage} }} + {{ typeId=16, value={*($T17*)storage} }} + {{ typeId=17, value={*($T18*)storage} }} + {{ typeId=18, value={*($T19*)storage} }} + {{ typeId=19, value={*($T20*)storage} }} + {{ typeId=20, value={*($T21*)storage} }} + {{ typeId=21, value={*($T22*)storage} }} + {{ typeId=22, value={*($T23*)storage} }} + {{ typeId=23, value={*($T24*)storage} }} + {{ typeId=24, value={*($T25*)storage} }} + {{ typeId=25, value={*($T26*)storage} }} + {{ typeId=26, value={*($T27*)storage} }} + {{ typeId=27, value={*($T28*)storage} }} + {{ typeId=28, value={*($T29*)storage} }} + {{ typeId=29, value={*($T30*)storage} }} + {{ typeId=30, value={*($T31*)storage} }} + {{ typeId=31, value={*($T32*)storage} }} - typeId + typeId *($T1*)storage *($T2*)storage *($T3*)storage From 6d14bdadf45c78b4aae695f636757d4cb4d65464 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 23 Jun 2022 18:44:07 -0700 Subject: [PATCH 17/19] Sync to upstream/release/533 --- Analysis/include/Luau/Constraint.h | 14 +- .../include/Luau/ConstraintGraphBuilder.h | 41 ++- Analysis/include/Luau/ConstraintSolver.h | 5 +- .../include/Luau/ConstraintSolverLogger.h | 4 +- Analysis/include/Luau/Error.h | 45 ++- Analysis/include/Luau/IostreamHelpers.h | 1 + Analysis/include/Luau/Module.h | 9 +- Analysis/include/Luau/Normalize.h | 4 +- Analysis/include/Luau/RecursionCounter.h | 15 +- Analysis/include/Luau/Scope.h | 18 + Analysis/include/Luau/TypeVar.h | 1 - Analysis/include/Luau/Unifiable.h | 1 + Analysis/include/Luau/Unifier.h | 4 - Analysis/include/Luau/VisitTypeVar.h | 2 +- Analysis/src/Clone.cpp | 4 +- Analysis/src/Constraint.cpp | 3 +- Analysis/src/ConstraintGraphBuilder.cpp | 285 +++++++++++++-- Analysis/src/ConstraintSolver.cpp | 35 +- Analysis/src/Error.cpp | 93 ++++- Analysis/src/Frontend.cpp | 2 + Analysis/src/IostreamHelpers.cpp | 2 + Analysis/src/Module.cpp | 1 - Analysis/src/Quantify.cpp | 9 +- Analysis/src/Scope.cpp | 32 ++ Analysis/src/ToString.cpp | 6 + Analysis/src/TypeChecker2.cpp | 173 +++++++++ Analysis/src/TypeInfer.cpp | 80 ++--- Analysis/src/TypeVar.cpp | 28 +- Analysis/src/Unifiable.cpp | 8 + Analysis/src/Unifier.cpp | 331 +----------------- CLI/Analyze.cpp | 4 + CMakeLists.txt | 9 + Common/include/Luau/Bytecode.h | 19 +- Compiler/include/Luau/BytecodeBuilder.h | 2 + Compiler/src/BytecodeBuilder.cpp | 16 +- Compiler/src/Compiler.cpp | 6 +- VM/src/ludata.cpp | 2 + VM/src/lvmload.cpp | 4 +- tests/Compiler.test.cpp | 34 +- tests/Fixture.h | 2 +- tests/Module.test.cpp | 1 - tests/Normalize.test.cpp | 1 - tests/RuntimeLimits.test.cpp | 7 +- tests/ToString.test.cpp | 2 - tests/TypeInfer.aliases.test.cpp | 87 ++++- tests/TypeInfer.annotations.test.cpp | 85 ++++- tests/TypeInfer.generics.test.cpp | 13 - tests/TypeInfer.modules.test.cpp | 20 +- tests/TypeInfer.refinements.test.cpp | 2 - tests/TypeInfer.singletons.test.cpp | 4 - tests/TypeInfer.tables.test.cpp | 73 ---- tests/TypeInfer.test.cpp | 3 - tests/TypeInfer.unionTypes.test.cpp | 6 - tests/VisitTypeVar.test.cpp | 9 +- 54 files changed, 972 insertions(+), 695 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index c62166e..8a41c9e 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -1,10 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Location.h" #include "Luau/NotNull.h" #include "Luau/Variant.h" +#include #include #include @@ -47,18 +47,24 @@ struct InstantiationConstraint TypeId superType; }; -using ConstraintV = Variant; +// name(namedType) = name +struct NameConstraint +{ + TypeId namedType; + std::string name; +}; + +using ConstraintV = Variant; using ConstraintPtr = std::unique_ptr; struct Constraint { - Constraint(ConstraintV&& c, Location location); + explicit Constraint(ConstraintV&& c); Constraint(const Constraint&) = delete; Constraint& operator=(const Constraint&) = delete; ConstraintV c; - Location location; std::vector> dependencies; }; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index da774a2..9b11869 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -17,20 +17,7 @@ namespace Luau { -struct Scope2 -{ - // The parent scope of this scope. Null if there is no parent (i.e. this - // is the module-level scope). - Scope2* parent = nullptr; - // All the children of this scope. - std::vector children; - std::unordered_map bindings; // TODO: I think this can be a DenseHashMap - TypePackId returnType; - // All constraints belonging to this scope. - std::vector constraints; - - std::optional lookup(Symbol sym); -}; +struct Scope2; struct ConstraintGraphBuilder { @@ -47,6 +34,10 @@ struct ConstraintGraphBuilder // A mapping of AST node to TypePackId. DenseHashMap astTypePacks{nullptr}; DenseHashMap astOriginalCallTypes{nullptr}; + // Types resolved from type annotations. Analogous to astTypes. + DenseHashMap astResolvedTypes{nullptr}; + // Type packs resolved from type annotations. Analogous to astTypePacks. + DenseHashMap astResolvedTypePacks{nullptr}; explicit ConstraintGraphBuilder(TypeArena* arena); @@ -73,9 +64,8 @@ struct ConstraintGraphBuilder * Adds a new constraint with no dependencies to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param cv the constraint variant to add. - * @param location the location to attribute to the constraint. */ - void addConstraint(Scope2* scope, ConstraintV cv, Location location); + void addConstraint(Scope2* scope, ConstraintV cv); /** * Adds a constraint to a given scope. @@ -99,6 +89,7 @@ struct ConstraintGraphBuilder void visit(Scope2* scope, AstStatReturn* ret); void visit(Scope2* scope, AstStatAssign* assign); void visit(Scope2* scope, AstStatIf* ifStatement); + void visit(Scope2* scope, AstStatTypeAlias* alias); TypePackId checkExprList(Scope2* scope, const AstArray& exprs); @@ -124,6 +115,24 @@ struct ConstraintGraphBuilder * @param fn the function expression to check. */ void checkFunctionBody(Scope2* scope, AstExprFunction* fn); + + /** + * Resolves a type from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param ty the AST annotation to resolve. + * @return the type of the AST annotation. + **/ + TypeId resolveType(Scope2* scope, AstType* ty); + + /** + * Resolves a type pack from its AST annotation. + * @param scope the scope that the type annotation appears within. + * @param tp the AST annotation to resolve. + * @return the type pack of the AST annotation. + **/ + TypePackId resolveTypePack(Scope2* scope, AstTypePack* tp); + + TypePackId resolveTypePack(Scope2* scope, const AstTypeList& list); }; /** diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 7e6d446..4870157 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -55,6 +55,7 @@ struct ConstraintSolver bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const NameConstraint& c, NotNull constraint); void block(NotNull target, NotNull constraint); /** @@ -85,7 +86,7 @@ struct ConstraintSolver * @param subType the sub-type to unify. * @param superType the super-type to unify. */ - void unify(TypeId subType, TypeId superType, Location location); + void unify(TypeId subType, TypeId superType); /** * Creates a new Unifier and performs a single unification operation. Commits @@ -93,7 +94,7 @@ struct ConstraintSolver * @param subPack the sub-type pack to unify. * @param superPack the super-type pack to unify. */ - void unify(TypePackId subPack, TypePackId superPack, Location location); + void unify(TypePackId subPack, TypePackId superPack); private: /** diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h index 2b195d7..55336a2 100644 --- a/Analysis/include/Luau/ConstraintSolverLogger.h +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Constraint.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index b453067..a132396 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -169,6 +169,13 @@ struct GenericError bool operator==(const GenericError& rhs) const; }; +struct InternalError +{ + std::string message; + + bool operator==(const InternalError& rhs) const; +}; + struct CannotCallNonFunction { TypeId ty; @@ -293,12 +300,12 @@ struct NormalizationTooComplex } }; -using TypeErrorData = - Variant; +using TypeErrorData = Variant; struct TypeError { @@ -339,7 +346,13 @@ T* get(TypeError& e) using ErrorVec = std::vector; +struct TypeErrorToStringOptions +{ + FileResolver* fileResolver = nullptr; +}; + std::string toString(const TypeError& error); +std::string toString(const TypeError& error, TypeErrorToStringOptions options); bool containsParseErrorName(const TypeError& error); @@ -356,4 +369,24 @@ struct InternalErrorReporter [[noreturn]] void ice(const std::string& message); }; +class InternalCompilerError : public std::exception { +public: + explicit InternalCompilerError(const std::string& message, const std::string& moduleName) + : message(message) + , moduleName(moduleName) + { + } + explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) + : message(message) + , moduleName(moduleName) + , location(location) + { + } + virtual const char* what() const throw(); + + const std::string message; + const std::string moduleName; + const std::optional location; +}; + } // namespace Luau diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index ee99429..05b9451 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -30,6 +30,7 @@ std::ostream& operator<<(std::ostream& lhs, const OccursCheckFailed& error); std::ostream& operator<<(std::ostream& lhs, const UnknownRequire& error); std::ostream& operator<<(std::ostream& lhs, const UnknownPropButFoundLikeProp& e); std::ostream& operator<<(std::ostream& lhs, const GenericError& error); +std::ostream& operator<<(std::ostream& lhs, const InternalError& error); std::ostream& operator<<(std::ostream& lhs, const FunctionExitsWithoutReturning& error); std::ostream& operator<<(std::ostream& lhs, const MissingProperties& error); std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index e979b3f..39f8dfb 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -1,10 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Error.h" #include "Luau/FileResolver.h" #include "Luau/ParseOptions.h" -#include "Luau/Error.h" #include "Luau/ParseResult.h" +#include "Luau/Scope.h" #include "Luau/TypeArena.h" #include @@ -19,7 +20,9 @@ struct Module; using ScopePtr = std::shared_ptr; using ModulePtr = std::shared_ptr; -struct Scope2; + +class AstType; +class AstTypePack; /// Root of the AST of a parsed source file struct SourceModule @@ -73,6 +76,8 @@ struct Module DenseHashMap astExpectedTypes{nullptr}; DenseHashMap astOriginalCallTypes{nullptr}; DenseHashMap astOverloadResolvedTypes{nullptr}; + DenseHashMap astResolvedTypes{nullptr}; + DenseHashMap astResolvedTypePacks{nullptr}; std::unordered_map declaredGlobals; ErrorVec errors; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index d4c7698..f5fd988 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -9,8 +9,8 @@ namespace Luau struct InternalErrorReporter; -bool isSubtype(TypeId superTy, TypeId subTy, InternalErrorReporter& ice); -bool isSubtype(TypePackId superTy, TypePackId subTy, InternalErrorReporter& ice); +bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice); +bool isSubtype(TypePackId subTy, TypePackId superTy, InternalErrorReporter& ice); std::pair normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); std::pair normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h index 03ae2c8..f964dbf 100644 --- a/Analysis/include/Luau/RecursionCounter.h +++ b/Analysis/include/Luau/RecursionCounter.h @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauRecursionLimitException); - namespace Luau { @@ -39,21 +37,12 @@ private: struct RecursionLimiter : RecursionCounter { - // TODO: remove ctx after LuauRecursionLimitException is removed - RecursionLimiter(int* count, int limit, const char* ctx) + RecursionLimiter(int* count, int limit) : RecursionCounter(count) { - LUAU_ASSERT(ctx); if (limit > 0 && *count > limit) { - if (FFlag::LuauRecursionLimitException) - throw RecursionLimitException(); - else - { - std::string m = "Internal recursion counter limit exceeded: "; - m += ctx; - throw std::runtime_error(m); - } + throw RecursionLimitException(); } } }; diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 4533840..cef4b94 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Constraint.h" #include "Luau/Location.h" #include "Luau/TypeVar.h" @@ -64,4 +65,21 @@ struct Scope std::unordered_map typeAliasTypePackParameters; }; +struct Scope2 +{ + // The parent scope of this scope. Null if there is no parent (i.e. this + // is the module-level scope). + Scope2* parent = nullptr; + // All the children of this scope. + std::vector children; + std::unordered_map bindings; // TODO: I think this can be a DenseHashMap + std::unordered_map typeBindings; + TypePackId returnType; + // All constraints belonging to this scope. + std::vector constraints; + + std::optional lookup(Symbol sym); + std::optional lookupTypeBinding(const Name& name); +}; + } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index ff7708d..20f4107 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -287,7 +287,6 @@ struct FunctionTypeVar bool hasSelf; Tags tags; bool hasNoGenerics = false; - bool generalized = false; }; enum class TableState diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index fdc3948..4ff9171 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -117,6 +117,7 @@ struct Generic explicit Generic(const Name& name); explicit Generic(Scope2* scope); Generic(TypeLevel level, const Name& name); + Generic(Scope2* scope, const Name& name); int index; TypeLevel level; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index b51a485..4af324c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -79,12 +79,8 @@ private: void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); - void DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); - void tryUnifyFreeTable(TypeId subTy, TypeId superTy); - void tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); - void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); TypeId widen(TypeId ty); TypePackId widen(TypePackId tp); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 642522c..5fd43f0 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -169,7 +169,7 @@ struct GenericTypeVarVisitor void traverse(TypeId ty) { - RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit, "TypeVarVisitor"}; + RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit}; if (visit_detail::hasSeen(seen, ty)) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 248262c..df4e0a6 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -317,7 +317,7 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) if (tp->persistent) return tp; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypePackId"); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); TypePackId& res = cloneState.seenTypePacks[tp]; @@ -335,7 +335,7 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit, "cloning TypeId"); + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); TypeId& res = cloneState.seenTypes[typeId]; diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index 6cb0e4e..64e3a66 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -5,9 +5,8 @@ namespace Luau { -Constraint::Constraint(ConstraintV&& c, Location location) +Constraint::Constraint(ConstraintV&& c) : c(std::move(c)) - , location(location) { } diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index fa627e7..d9e8d23 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -2,28 +2,13 @@ #include "Luau/ConstraintGraphBuilder.h" +#include "Luau/Scope.h" + namespace Luau { const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp -std::optional Scope2::lookup(Symbol sym) -{ - Scope2* s = this; - - while (true) - { - auto it = s->bindings.find(sym); - if (it != s->bindings.end()) - return it->second; - - if (s->parent) - s = s->parent; - else - return std::nullopt; - } -} - ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena) : singletonTypes(getSingletonTypes()) , arena(arena) @@ -59,10 +44,10 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) return borrow; } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv, Location location) +void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) { LUAU_ASSERT(scope); - scope->constraints.emplace_back(new Constraint{std::move(cv), location}); + scope->constraints.emplace_back(new Constraint{std::move(cv)}); } void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) @@ -79,6 +64,13 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) rootScope = scopes.back().second.get(); rootScope->returnType = freshTypePack(rootScope); + // TODO: We should share the global scope. + rootScope->typeBindings["nil"] = singletonTypes.nilType; + rootScope->typeBindings["number"] = singletonTypes.numberType; + rootScope->typeBindings["string"] = singletonTypes.stringType; + rootScope->typeBindings["boolean"] = singletonTypes.booleanType; + rootScope->typeBindings["thread"] = singletonTypes.threadType; + visit(rootScope, block); } @@ -102,6 +94,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) checkPack(scope, e->expr); else if (auto i = stat->as()) visit(scope, i); + else if (auto a = stat->as()) + visit(scope, a); else LUAU_ASSERT(0); } @@ -114,8 +108,14 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) for (AstLocal* local : local->vars) { - // TODO annotations TypeId ty = freshType(scope); + + if (local->annotation) + { + TypeId annotation = resolveType(scope, local->annotation); + addConstraint(scope, SubtypeConstraint{ty, annotation}); + } + varTypes.push_back(ty); scope->bindings[local] = ty; } @@ -136,14 +136,14 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) { std::vector tailValues{varTypes.begin() + i, varTypes.end()}; TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}, local->location); + addConstraint(scope, PackSubtypeConstraint{exprPack, tailPack}); } } else { TypeId exprType = check(scope, local->values.data[i]); if (i < varTypes.size()) - addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}, local->vars.data[i]->location); + addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); } } } @@ -188,7 +188,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function checkFunctionBody(innerScope, function->func); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; addConstraints(c.get(), innerScope); addConstraint(scope, std::move(c)); @@ -240,7 +240,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) checkFunctionBody(innerScope, function->func); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}, function->location}}; + std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; addConstraints(c.get(), innerScope); addConstraint(scope, std::move(c)); @@ -251,13 +251,26 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) LUAU_ASSERT(scope); TypePackId exprTypes = checkPack(scope, ret->list); - addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}, ret->location); + addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) { LUAU_ASSERT(scope); + // In order to enable mutually-recursive type aliases, we need to + // populate the type bindings before we actually check any of the + // alias statements. Since we're not ready to actually resolve + // any of the annotations, we just use a fresh type for now. + for (AstStat* stat : block->body) + { + if (auto alias = stat->as()) + { + TypeId initialType = freshType(scope); + scope->typeBindings[alias->name.value] = initialType; + } + } + for (AstStat* stat : block->body) visit(scope, stat); } @@ -267,7 +280,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) TypePackId varPackId = checkExprList(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); - addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}, assign->location); + addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) @@ -284,6 +297,28 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) } } +void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) +{ + // TODO: Exported type aliases + // TODO: Generic type aliases + + auto it = scope->typeBindings.find(alias->name.value); + // This should always be here since we do a separate pass over the + // AST to set up typeBindings. If it's not, we've somehow skipped + // this alias in that first pass. + LUAU_ASSERT(it != scope->typeBindings.end()); + + TypeId ty = resolveType(scope, alias->type); + + // Rather than using a subtype constraint, we instead directly bind + // the free type we generated in the first pass to the resolved type. + // This prevents a case where you could cause another constraint to + // bind the free alias type to an unrelated type, causing havoc. + asMutable(it->second)->ty.emplace(ty); + + addConstraint(scope, NameConstraint{ty, alias->name.value}); +} + TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) { LUAU_ASSERT(scope); @@ -350,13 +385,13 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) astOriginalCallTypes[call->func] = fnType; TypeId instantiatedType = freshType(scope); - addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}, expr->location); + addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); TypePackId rets = freshTypePack(scope); FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); TypeId inferredFnType = arena->addType(ftv); - addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}, expr->location); + addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); result = rets; } else @@ -413,7 +448,7 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) TypePack onePack{{typeResult}, freshTypePack(scope)}; TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}, expr->location); + addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}); return typeResult; } @@ -454,7 +489,7 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) TypeId expectedTableType = arena->addType(std::move(ttv)); - addConstraint(scope, SubtypeConstraint{obj, expectedTableType}, indexName->location); + addConstraint(scope, SubtypeConstraint{obj, expectedTableType}); return result; } @@ -465,8 +500,7 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) TableTypeVar* ttv = getMutable(ty); LUAU_ASSERT(ttv); - auto createIndexer = [this, scope, ttv]( - TypeId currentIndexType, TypeId currentResultType, Location itemLocation, std::optional keyLocation) { + auto createIndexer = [this, scope, ttv](TypeId currentIndexType, TypeId currentResultType) { if (!ttv->indexer) { TypeId indexType = this->freshType(scope); @@ -474,8 +508,8 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) ttv->indexer = TableIndexer{indexType, resultType}; } - addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}, keyLocation ? *keyLocation : itemLocation); - addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}, itemLocation); + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexType, currentIndexType}); + addConstraint(scope, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); }; for (const AstExprTable::Item& item : expr->items) @@ -495,13 +529,13 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) } else { - createIndexer(keyTy, itemTy, item.value->location, item.key->location); + createIndexer(keyTy, itemTy); } } else { TypeId numberType = singletonTypes.numberType; - createIndexer(numberType, itemTy, item.value->location, std::nullopt); + createIndexer(numberType, itemTy); } } @@ -514,15 +548,29 @@ std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2 TypePackId returnType = freshTypePack(innerScope); innerScope->returnType = returnType; + if (fn->returnAnnotation) + { + TypePackId annotatedRetType = resolveTypePack(innerScope, *fn->returnAnnotation); + addConstraint(innerScope, PackSubtypeConstraint{returnType, annotatedRetType}); + } + std::vector argTypes; for (AstLocal* local : fn->args) { TypeId t = freshType(innerScope); argTypes.push_back(t); - innerScope->bindings[local] = t; // TODO annotations + innerScope->bindings[local] = t; + + if (local->annotation) + { + TypeId argAnnotation = resolveType(innerScope, local->annotation); + addConstraint(innerScope, SubtypeConstraint{t, argAnnotation}); + } } + // TODO: Vararg annotation. + FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); @@ -541,10 +589,171 @@ void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* f if (nullptr != getFallthrough(fn->body)) { TypePackId empty = arena->addTypePack({}); // TODO we could have CSG retain one of these forever - addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}, fn->body->location); + addConstraint(scope, PackSubtypeConstraint{scope->returnType, empty}); } } +TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) +{ + TypeId result = nullptr; + + if (auto ref = ty->as()) + { + // TODO: Support imported types w/ require tracing. + // TODO: Support generic type references. + LUAU_ASSERT(!ref->prefix); + LUAU_ASSERT(!ref->hasParameterList); + + // TODO: If it doesn't exist, should we introduce a free binding? + // This is probably important for handling type aliases. + result = scope->lookupTypeBinding(ref->name.value).value_or(singletonTypes.errorRecoveryType()); + } + else if (auto tab = ty->as()) + { + TableTypeVar::Props props; + std::optional indexer; + + for (const AstTableProp& prop : tab->props) + { + std::string name = prop.name.value; + // TODO: Recursion limit. + TypeId propTy = resolveType(scope, prop.type); + // TODO: Fill in location. + props[name] = {propTy}; + } + + if (tab->indexer) + { + // TODO: Recursion limit. + indexer = TableIndexer{ + resolveType(scope, tab->indexer->indexType), + resolveType(scope, tab->indexer->resultType), + }; + } + + // TODO: Remove TypeLevel{} here, we don't need it. + result = arena->addType(TableTypeVar{props, indexer, TypeLevel{}, TableState::Sealed}); + } + else if (auto fn = ty->as()) + { + // TODO: Generic functions. + // TODO: Scope (though it may not be needed). + // TODO: Recursion limit. + TypePackId argTypes = resolveTypePack(scope, fn->argTypes); + TypePackId returnTypes = resolveTypePack(scope, fn->returnTypes); + + // TODO: Is this the right constructor to use? + result = arena->addType(FunctionTypeVar{argTypes, returnTypes}); + + FunctionTypeVar* ftv = getMutable(result); + ftv->argNames.reserve(fn->argNames.size); + for (const auto& el : fn->argNames) + { + if (el) + { + const auto& [name, location] = *el; + ftv->argNames.push_back(FunctionArgument{name.value, location}); + } + else + { + ftv->argNames.push_back(std::nullopt); + } + } + } + else if (auto tof = ty->as()) + { + // TODO: Recursion limit. + TypeId exprType = check(scope, tof->expr); + result = exprType; + } + else if (auto unionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : unionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part)); + } + + result = arena->addType(UnionTypeVar{parts}); + } + else if (auto intersectionAnnotation = ty->as()) + { + std::vector parts; + for (AstType* part : intersectionAnnotation->types) + { + // TODO: Recursion limit. + parts.push_back(resolveType(scope, part)); + } + + result = arena->addType(IntersectionTypeVar{parts}); + } + else if (auto boolAnnotation = ty->as()) + { + result = arena->addType(SingletonTypeVar(BooleanSingleton{boolAnnotation->value})); + } + else if (auto stringAnnotation = ty->as()) + { + result = arena->addType(SingletonTypeVar(StringSingleton{std::string(stringAnnotation->value.data, stringAnnotation->value.size)})); + } + else if (ty->is()) + { + result = singletonTypes.errorRecoveryType(); + } + else + { + LUAU_ASSERT(0); + result = singletonTypes.errorRecoveryType(); + } + + astResolvedTypes[ty] = result; + return result; +} + +TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* tp) +{ + TypePackId result; + if (auto expl = tp->as()) + { + result = resolveTypePack(scope, expl->typeList); + } + else if (auto var = tp->as()) + { + TypeId ty = resolveType(scope, var->variadicType); + result = arena->addTypePack(TypePackVar{VariadicTypePack{ty}}); + } + else if (auto gen = tp->as()) + { + result = arena->addTypePack(TypePackVar{GenericTypePack{scope, gen->genericName.value}}); + } + else + { + LUAU_ASSERT(0); + result = singletonTypes.errorRecoveryTypePack(); + } + + astResolvedTypePacks[tp] = result; + return result; +} + +TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeList& list) +{ + std::vector head; + + for (AstType* headTy : list.types) + { + head.push_back(resolveType(scope, headTy)); + } + + std::optional tail = std::nullopt; + if (list.tailType) + { + tail = resolveTypePack(scope, list.tailType); + } + + return arena->addTypePack(TypePack{head, tail}); +} + void collectConstraints(std::vector>& result, Scope2* scope) { for (const auto& c : scope->constraints) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 41dfd89..9e35523 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -2,6 +2,7 @@ #include "Luau/ConstraintSolver.h" #include "Luau/Instantiation.h" +#include "Luau/Location.h" #include "Luau/Quantify.h" #include "Luau/ToString.h" #include "Luau/Unifier.h" @@ -179,6 +180,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*gc, constraint, force); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); + else if (auto nc = get(*constraint)) + success = tryDispatch(*nc, constraint); else LUAU_ASSERT(0); @@ -197,7 +200,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNulllocation); + unify(c.subType, c.superType); unblock(c.subType); unblock(c.superType); @@ -207,7 +210,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - unify(c.subPack, c.superPack, constraint->location); + unify(c.subPack, c.superPack); unblock(c.subPack); unblock(c.superPack); @@ -222,7 +225,7 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullty.emplace(c.sourceType); else - unify(c.generalizedType, c.sourceType, constraint->location); + unify(c.generalizedType, c.sourceType); TypeId generalized = quantify(arena, c.sourceType, c.scope); *asMutable(c.sourceType) = *generalized; @@ -243,12 +246,28 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - unify(c.subType, *instantiated, constraint->location); + unify(c.subType, *instantiated); unblock(c.subType); return true; } +bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) +{ + if (isBlocked(c.namedType)) + return block(c.namedType, constraint); + + TypeId target = follow(c.namedType); + if (TableTypeVar* ttv = getMutable(target)) + ttv->name = c.name; + else if (MetatableTypeVar* mtv = getMutable(target)) + mtv->syntheticName = c.name; + else + return block(c.namedType, constraint); + + return true; +} + void ConstraintSolver::block_(BlockedConstraintId target, NotNull constraint) { blocked[target].push_back(constraint); @@ -321,19 +340,19 @@ bool ConstraintSolver::isBlocked(NotNull constraint) return blockedIt != blockedConstraints.end() && blockedIt->second > 0; } -void ConstraintSolver::unify(TypeId subType, TypeId superType, Location location) +void ConstraintSolver::unify(TypeId subType, TypeId superType) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; u.tryUnify(subType, superType); u.log.commit(); } -void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, Location location) +void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, Mode::Strict, location, Covariant, sharedState}; + Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; u.tryUnify(subPack, superPack); u.log.commit(); diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index f443a3c..93cb65b 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,6 +7,9 @@ #include +LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) +LUAU_FASTFLAGVARIABLE(LuauUseInternalCompilerErrorException, false) + static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { std::string s = "expects "; @@ -49,6 +52,8 @@ namespace Luau struct ErrorConverter { + FileResolver* fileResolver = nullptr; + std::string operator()(const Luau::TypeMismatch& tm) const { std::string givenTypeName = Luau::toString(tm.givenType); @@ -62,8 +67,18 @@ struct ErrorConverter { if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) { - result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + - "' from '" + *wantedDefinitionModule + "'"; + if (FFlag::LuauTypeMismatchModuleNameResolution && fileResolver != nullptr) + { + std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); + std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); + result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName + + "' from '" + wantedModuleName + "'"; + } + else + { + result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + + "' from '" + *wantedDefinitionModule + "'"; + } } } } @@ -78,7 +93,14 @@ struct ErrorConverter if (!tm.reason.empty()) result += tm.reason + " "; - result += Luau::toString(*tm.error); + if (FFlag::LuauTypeMismatchModuleNameResolution) + { + result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); + } + else + { + result += Luau::toString(*tm.error); + } } else if (!tm.reason.empty()) { @@ -280,6 +302,11 @@ struct ErrorConverter return e.message; } + std::string operator()(const Luau::InternalError& e) const + { + return e.message; + } + std::string operator()(const Luau::CannotCallNonFunction& e) const { return "Cannot call non-function " + toString(e.ty); @@ -598,6 +625,11 @@ bool GenericError::operator==(const GenericError& rhs) const return message == rhs.message; } +bool InternalError::operator==(const InternalError& rhs) const +{ + return message == rhs.message; +} + bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const { return ty == rhs.ty; @@ -685,7 +717,12 @@ bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const std::string toString(const TypeError& error) { - ErrorConverter converter; + return toString(error, TypeErrorToStringOptions{}); +} + +std::string toString(const TypeError& error, TypeErrorToStringOptions options) +{ + ErrorConverter converter{options.fileResolver}; return Luau::visit(converter, error.data); } @@ -773,6 +810,9 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) else if constexpr (std::is_same_v) { } + else if constexpr (std::is_same_v) + { + } else if constexpr (std::is_same_v) { e.ty = clone(e.ty); @@ -847,22 +887,51 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) void InternalErrorReporter::ice(const std::string& message, const Location& location) { - std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); + if (FFlag::LuauUseInternalCompilerErrorException) + { + InternalCompilerError error(message, moduleName, location); - if (onInternalError) - onInternalError(error.what()); + if (onInternalError) + onInternalError(error.what()); - throw error; + throw error; + } + else + { + std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); + + if (onInternalError) + onInternalError(error.what()); + + throw error; + } } void InternalErrorReporter::ice(const std::string& message) { - std::runtime_error error("Internal error in " + moduleName + ": " + message); + if (FFlag::LuauUseInternalCompilerErrorException) + { + InternalCompilerError error(message, moduleName); - if (onInternalError) - onInternalError(error.what()); + if (onInternalError) + onInternalError(error.what()); - throw error; + throw error; + } + else + { + std::runtime_error error("Internal error in " + moduleName + ": " + message); + + if (onInternalError) + onInternalError(error.what()); + + throw error; + } +} + +const char* InternalCompilerError::what() const throw() +{ + return this->message.data(); } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 9e02506..85c5dbc 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -801,6 +801,8 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco result->astTypes = std::move(cgb.astTypes); result->astTypePacks = std::move(cgb.astTypePacks); result->astOriginalCallTypes = std::move(cgb.astOriginalCallTypes); + result->astResolvedTypes = std::move(cgb.astResolvedTypes); + result->astResolvedTypePacks = std::move(cgb.astResolvedTypePacks); result->clonePublicInterface(iceHandler); diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 048167a..e4fac45 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -111,6 +111,8 @@ static void errorToString(std::ostream& stream, const T& err) } else if constexpr (std::is_same_v) stream << "GenericError { " << err.message << " }"; + else if constexpr (std::is_same_v) + stream << "InternalError { " << err.message << " }"; else if constexpr (std::is_same_v) stream << "CannotCallNonFunction { " << toString(err.ty) << " }"; else if constexpr (std::is_same_v) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 4d157e6..95eb125 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -11,7 +11,6 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" -#include "Luau/ConstraintGraphBuilder.h" // FIXME: For Scope2 TODO pull out into its own header #include diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 2004d15..40e14c6 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -2,11 +2,10 @@ #include "Luau/Quantify.h" -#include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header -#include "Luau/TxnLog.h" +#include "Luau/Scope.h" #include "Luau/Substitution.h" +#include "Luau/TxnLog.h" #include "Luau/VisitTypeVar.h" -#include "Luau/ConstraintGraphBuilder.h" // TODO for Scope2; move to separate header LUAU_FASTFLAG(LuauAlwaysQuantify); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); @@ -177,8 +176,6 @@ void quantify(TypeId ty, TypeLevel level) if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) ftv->hasNoGenerics = true; - - ftv->generalized = true; } void quantify(TypeId ty, Scope2* scope) @@ -201,8 +198,6 @@ void quantify(TypeId ty, Scope2* scope) if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) ftv->hasNoGenerics = true; - - ftv->generalized = true; } struct PureQuantifier : Substitution diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 011e28d..66aaee1 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -121,4 +121,36 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } +std::optional Scope2::lookup(Symbol sym) +{ + Scope2* s = this; + + while (true) + { + auto it = s->bindings.find(sym); + if (it != s->bindings.end()) + return it->second; + + if (s->parent) + s = s->parent; + else + return std::nullopt; + } +} + +std::optional Scope2::lookupTypeBinding(const Name& name) +{ + Scope2* s = this; + while (s) + { + auto it = s->typeBindings.find(name); + if (it != s->typeBindings.end()) + return it->second; + + s = s->parent; + } + + return std::nullopt; +} + } // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index eee0dee..eb7b9cd 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1401,6 +1401,12 @@ std::string toString(const Constraint& c, ToStringOptions& opts) opts.nameMap = std::move(superStr.nameMap); return subStr.name + " ~ inst " + superStr.name; } + else if (const NameConstraint* nc = Luau::get(c)) + { + ToStringResult namedStr = toStringDetailed(nc->namedType, opts); + opts.nameMap = std::move(namedStr.nameMap); + return "@name(" + namedStr.name + ") = " + nc->name; + } else { LUAU_ASSERT(false); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 7f5ba68..63e5800 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -7,6 +7,9 @@ #include "Luau/AstQuery.h" #include "Luau/Clone.h" #include "Luau/Normalize.h" +#include "Luau/ConstraintGraphBuilder.h" // FIXME move Scope2 into its own header +#include "Luau/Unifier.h" +#include "Luau/ToString.h" namespace Luau { @@ -39,6 +42,104 @@ struct TypeChecker2 : public AstVisitor return follow(*ty); } + TypeId lookupAnnotation(AstType* annotation) + { + TypeId* ty = module->astResolvedTypes.find(annotation); + LUAU_ASSERT(ty); + return follow(*ty); + } + + TypePackId reconstructPack(AstArray exprs, TypeArena& arena) + { + std::vector head; + + for (size_t i = 0; i < exprs.size - 1; ++i) + { + head.push_back(lookupType(exprs.data[i])); + } + + TypePackId tail = lookupPack(exprs.data[exprs.size - 1]); + return arena.addTypePack(TypePack{head, tail}); + } + + Scope2* findInnermostScope(Location location) + { + Scope2* bestScope = module->getModuleScope2(); + Location bestLocation = module->scope2s[0].first; + + for (size_t i = 0; i < module->scope2s.size(); ++i) + { + auto& [scopeBounds, scope] = module->scope2s[i]; + if (scopeBounds.encloses(location)) + { + if (scopeBounds.begin > bestLocation.begin || scopeBounds.end < bestLocation.end) + { + bestScope = scope.get(); + bestLocation = scopeBounds; + } + } + else + { + // TODO: Is this sound? This relies on the fact that scopes are inserted + // into the scope list in the order that they appear in the AST. + break; + } + } + + return bestScope; + } + + bool visit(AstStatLocal* local) override + { + for (size_t i = 0; i < local->values.size; ++i) + { + AstExpr* value = local->values.data[i]; + if (i == local->values.size - 1) + { + if (i < local->values.size) + { + TypePackId valueTypes = lookupPack(value); + auto it = begin(valueTypes); + for (size_t j = i; j < local->vars.size; ++j) + { + if (it == end(valueTypes)) + { + break; + } + + AstLocal* var = local->vars.data[i]; + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + if (!isSubtype(*it, varType, ice)) + { + reportError(TypeMismatch{varType, *it}, value->location); + } + } + + ++it; + } + } + } + else + { + TypeId valueType = lookupType(value); + AstLocal* var = local->vars.data[i]; + + if (var->annotation) + { + TypeId varType = lookupAnnotation(var->annotation); + if (!isSubtype(varType, valueType, ice)) + { + reportError(TypeMismatch{varType, valueType}, value->location); + } + } + } + } + + return true; + } + bool visit(AstStatAssign* assign) override { size_t count = std::min(assign->vars.size, assign->values.size); @@ -62,6 +163,30 @@ struct TypeChecker2 : public AstVisitor return true; } + bool visit(AstStatReturn* ret) override + { + Scope2* scope = findInnermostScope(ret->location); + TypePackId expectedRetType = scope->returnType; + + TypeArena arena; + TypePackId actualRetType = reconstructPack(ret->list, arena); + + UnifierSharedState sharedState{&ice}; + Unifier u{&arena, Mode::Strict, ret->location, Covariant, sharedState}; + u.anyIsTop = true; + + u.tryUnify(actualRetType, expectedRetType); + const bool ok = u.errors.empty() && u.log.empty(); + + if (!ok) + { + for (const TypeError& e : u.errors) + module->errors.push_back(e); + } + + return true; + } + bool visit(AstExprCall* call) override { TypePackId expectedRetType = lookupPack(call); @@ -91,6 +216,35 @@ struct TypeChecker2 : public AstVisitor return true; } + bool visit(AstExprFunction* fn) override + { + TypeId inferredFnTy = lookupType(fn); + const FunctionTypeVar* inferredFtv = get(inferredFnTy); + LUAU_ASSERT(inferredFtv); + + auto argIt = begin(inferredFtv->argTypes); + for (const auto& arg : fn->args) + { + if (argIt == end(inferredFtv->argTypes)) + break; + + if (arg->annotation) + { + TypeId inferredArgTy = *argIt; + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + + if (!isSubtype(annotatedArgTy, inferredArgTy, ice)) + { + reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); + } + } + + ++argIt; + } + + return true; + } + bool visit(AstExprIndexName* indexName) override { TypeId leftType = lookupType(indexName->expr); @@ -144,6 +298,25 @@ struct TypeChecker2 : public AstVisitor return true; } + bool visit(AstType* ty) override + { + return true; + } + + bool visit(AstTypeReference* ty) override + { + Scope2* scope = findInnermostScope(ty->location); + + // TODO: Imported types + // TODO: Generic types + if (!scope->lookupTypeBinding(ty->name.value)) + { + reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); + } + + return true; + } + void reportError(TypeErrorData&& data, const Location& location) { module->errors.emplace_back(location, sourceModule->name, std::move(data)); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index fd1b3b8..44635e8 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,13 +35,9 @@ LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) -LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) -LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) -LUAU_FASTFLAGVARIABLE(LuauRecursionLimitException, false); -LUAU_FASTFLAGVARIABLE(LuauApplyTypeFunctionFix, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) LUAU_FASTFLAG(LuauQuantifyConstrained) @@ -275,22 +271,15 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) { - if (FFlag::LuauRecursionLimitException) - { - try - { - return checkWithoutRecursionCheck(module, mode, environmentScope); - } - catch (const RecursionLimitException&) - { - reportErrorCodeTooComplex(module.root->location); - return std::move(currentModule); - } - } - else + try { return checkWithoutRecursionCheck(module, mode, environmentScope); } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(module.root->location); + return std::move(currentModule); + } } ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -445,22 +434,15 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) reportErrorCodeTooComplex(block.location); return; } - if (FFlag::LuauRecursionLimitException) - { - try - { - checkBlockWithoutRecursionCheck(scope, block); - } - catch (const RecursionLimitException&) - { - reportErrorCodeTooComplex(block.location); - return; - } - } - else + try { checkBlockWithoutRecursionCheck(scope, block); } + catch (const RecursionLimitException&) + { + reportErrorCodeTooComplex(block.location); + return; + } } void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) @@ -1917,7 +1899,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : utv) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeForType unions"); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); // Not needed when we normalize types. if (get(follow(t))) @@ -1967,7 +1949,7 @@ std::optional TypeChecker::getIndexTypeFromType( for (TypeId t : itv->parts) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit, "getIndexTypeFromType intersections"); + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) parts.push_back(*ty); @@ -2190,7 +2172,7 @@ TypeId TypeChecker::checkExprTable( } } - TableState state = (expr.items.size == 0 || isNonstrictMode() || FFlag::LuauUnsealedTableLiteral) ? TableState::Unsealed : TableState::Sealed; + TableState state = TableState::Unsealed; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); @@ -5175,9 +5157,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack bool ApplyTypeFunction::isDirty(TypeId ty) { - if (FFlag::LuauApplyTypeFunctionFix && typeArguments.count(ty)) - return true; - else if (!FFlag::LuauApplyTypeFunctionFix && get(ty)) + if (typeArguments.count(ty)) return true; else if (const FreeTypeVar* ftv = get(ty)) { @@ -5191,9 +5171,7 @@ bool ApplyTypeFunction::isDirty(TypeId ty) bool ApplyTypeFunction::isDirty(TypePackId tp) { - if (FFlag::LuauApplyTypeFunctionFix && typePackArguments.count(tp)) - return true; - else if (!FFlag::LuauApplyTypeFunctionFix && get(tp)) + if (typePackArguments.count(tp)) return true; else return false; @@ -5218,29 +5196,15 @@ bool ApplyTypeFunction::ignoreChildren(TypePackId tp) TypeId ApplyTypeFunction::clean(TypeId ty) { TypeId& arg = typeArguments[ty]; - if (FFlag::LuauApplyTypeFunctionFix) - { - LUAU_ASSERT(arg); - return arg; - } - else if (arg) - return arg; - else - return addType(FreeTypeVar{level}); + LUAU_ASSERT(arg); + return arg; } TypePackId ApplyTypeFunction::clean(TypePackId tp) { TypePackId& arg = typePackArguments[tp]; - if (FFlag::LuauApplyTypeFunctionFix) - { - LUAU_ASSERT(arg); - return arg; - } - else if (arg) - return arg; - else - return addTypePack(FreeTypePack{level}); + LUAU_ASSERT(arg); + return arg; } TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, @@ -5273,7 +5237,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; - bool shouldMutate = (!FFlag::LuauOnlyMutateInstantiatedTables || getTableType(tf.type)); + bool shouldMutate = getTableType(tf.type); TableTypeVar* ttv = getMutableTableType(target); if (shouldMutate && ttv && needsClone) diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 5776293..ade70d7 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,7 +23,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) namespace Luau @@ -172,22 +171,15 @@ bool isString(TypeId ty) // Returns true when ty is a supertype of string bool maybeString(TypeId ty) { - if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables) - { - ty = follow(ty); + ty = follow(ty); - if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) - return true; + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + return true; - if (auto utv = get(ty)) - return std::any_of(begin(utv), end(utv), maybeString); + if (auto utv = get(ty)) + return std::any_of(begin(utv), end(utv), maybeString); - return false; - } - else - { - return isString(ty); - } + return false; } bool isThread(TypeId ty) @@ -369,7 +361,7 @@ bool maybeSingleton(TypeId ty) bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) { - RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit, "hasLength"); + RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); ty = follow(ty); @@ -750,13 +742,15 @@ TypeId SingletonTypes::makeStringMetatable() TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, {"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber, optionalBoolean}, {}, {optionalNumber, optionalNumber})}}, + {"find", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})})}}, {"format", {formatFn}}, // FIXME {"gmatch", {gmatchFunc}}, {"gsub", {gsubFunc}}, {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, {"lower", {stringToStringType}}, - {"match", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber}, {}, {optionalString})}}, + {"match", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), + arena->addTypePack(TypePackVar{VariadicTypePack{optionalString}})})}}, {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, {"reverse", {stringToStringType}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index fe87835..8d23aa4 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -53,6 +53,14 @@ Generic::Generic(TypeLevel level, const Name& name) { } +Generic::Generic(Scope2* scope, const Name& name) + : index(++nextIndex) + , scope(scope) + , name(name) + , explicitName(true) +{ +} + int Generic::nextIndex = 0; Error::Error() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 877663d..6147e11 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,11 +17,8 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); -LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -354,7 +351,7 @@ void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool i void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypeId tryUnify_"); + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); ++sharedState.counters.iterationCount; @@ -983,7 +980,7 @@ void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "TypePackId tryUnify_"); + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); ++sharedState.counters.iterationCount; @@ -1316,12 +1313,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(subFunction->retTypes, superFunction->retTypes); } - if (FFlag::LuauTxnLogRefreshFunctionPointers) - { - // Updating the log may have invalidated the function pointers - superFunction = log.getMutable(superTy); - subFunction = log.getMutable(subTy); - } + // Updating the log may have invalidated the function pointers + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); ctx = context; @@ -1360,9 +1354,6 @@ struct Resetter void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { - if (!FFlag::LuauTableSubtypingVariance2) - return DEPRECATED_tryUnifyTables(subTy, superTy, isIntersection); - TableTypeVar* superTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); @@ -1379,8 +1370,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && - !isOptional(superProp.type)) + if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type)) missingProperties.push_back(propName); } @@ -1398,7 +1388,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto superIter = superTable->props.find(propName); - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) + if (superIter == superTable->props.end()) extraProperties.push_back(propName); } @@ -1443,7 +1433,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } - else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) @@ -1512,9 +1502,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } - else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) - { - } else if (superTable->state == TableState::Free) { PendingType* pendingSuper = log.queue(superTy); @@ -1639,296 +1626,6 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } -void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - Resetter resetter{&variance}; - variance = Invariant; - - TableTypeVar* superTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!superTable || !subTable) - ice("passed non-table types to unifyTables"); - - if (superTable->state == TableState::Sealed && subTable->state == TableState::Sealed) - return tryUnifySealedTables(subTy, superTy, isIntersection); - else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Unsealed) || - (superTable->state == TableState::Unsealed && subTable->state == TableState::Sealed)) - return tryUnifySealedTables(subTy, superTy, isIntersection); - else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || - (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not - { - TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; - TypeId otherTypeId = subTable->state == TableState::Free ? superTy : subTy; - - return tryUnifyFreeTable(otherTypeId, freeTypeId); - } - else if (superTable->state == TableState::Free && subTable->state == TableState::Free) - { - tryUnifyFreeTable(subTy, superTy); - - // avoid creating a cycle when the types are already pointing at each other - if (follow(superTy) != follow(subTy)) - { - log.bindTable(superTy, subTy); - } - return; - } - else if (superTable->state != TableState::Sealed && subTable->state != TableState::Sealed) - { - // All free tables are checked in one of the branches above - LUAU_ASSERT(superTable->state != TableState::Free); - LUAU_ASSERT(subTable->state != TableState::Free); - - // Tables must have exactly the same props and their types must all unify - // I honestly have no idea if this is remotely close to reasonable. - for (const auto& [name, prop] : superTable->props) - { - const auto& r = subTable->props.find(name); - if (r == subTable->props.end()) - reportError(TypeError{location, UnknownProperty{subTy, name}}); - else - tryUnify_(r->second.type, prop.type); - } - - if (superTable->indexer && subTable->indexer) - tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (superTable->indexer) - { - // passing/assigning a table without an indexer to something that has one - // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. - if (subTable->state == TableState::Unsealed) - { - log.changeIndexer(subTy, superTable->indexer); - } - else - reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); - } - } - else if (superTable->state == TableState::Sealed) - { - // lt is sealed and so it must be possible for rt to have precisely the same shape - // Verify that this is the case, then bind rt to lt. - ice("unsealed tables are not working yet", location); - } - else if (subTable->state == TableState::Sealed) - return tryUnifyTables(superTy, subTy, isIntersection); - else - ice("tryUnifyTables"); -} - -void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - TableTypeVar* freeTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!freeTable || !subTable) - ice("passed non-table types to tryUnifyFreeTable"); - - // Any properties in freeTable must unify with those in otherTable. - // Then bind freeTable to otherTable. - for (const auto& [freeName, freeProp] : freeTable->props) - { - if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) - { - tryUnify_(*subProp, freeProp.type); - - /* - * TypeVars are commonly cyclic, so it is entirely possible - * for unifying a property of a table to change the table itself! - * We need to check for this and start over if we notice this occurring. - * - * I believe this is guaranteed to terminate eventually because this will - * only happen when a free table is bound to another table. - */ - if (!log.getMutable(superTy) || !log.getMutable(subTy)) - return tryUnify_(subTy, superTy); - - if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) - return tryUnify_(subTy, superTy); - } - else - { - // If the other table is also free, then we are learning that it has more - // properties than we previously thought. Else, it is an error. - if (subTable->state == TableState::Free) - { - PendingType* pendingSub = log.queue(subTy); - TableTypeVar* pendingSubTtv = getMutable(pendingSub); - LUAU_ASSERT(pendingSubTtv); - pendingSubTtv->props.insert({freeName, freeProp}); - } - else - reportError(TypeError{location, UnknownProperty{subTy, freeName}}); - } - } - - if (freeTable->indexer && subTable->indexer) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnifyIndexer(*subTable->indexer, *freeTable->indexer); - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - - log.concat(std::move(innerState.log)); - } - else if (subTable->state == TableState::Free && freeTable->indexer) - { - log.changeIndexer(superTy, subTable->indexer); - } - - if (!freeTable->boundTo && subTable->state != TableState::Free) - { - log.bindTable(superTy, subTy); - } -} - -void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - TableTypeVar* superTable = log.getMutable(superTy); - TableTypeVar* subTable = log.getMutable(subTy); - - if (!superTable || !subTable) - ice("passed non-table types to unifySealedTables"); - - std::vector missingPropertiesInSuper; - bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; - bool errorReported = false; - - // Optimization: First test that the property sets are compatible without doing any recursive unification - if (!subTable->indexer) - { - for (const auto& [propName, superProp] : superTable->props) - { - auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && !isOptional(superProp.type)) - missingPropertiesInSuper.push_back(propName); - } - - if (!missingPropertiesInSuper.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); - return; - } - } - - Unifier innerState = makeChildUnifier(); - - // Tables must have exactly the same props and their types must all unify - for (const auto& it : superTable->props) - { - const auto& r = subTable->props.find(it.first); - if (r == subTable->props.end()) - { - if (isOptional(it.second.type)) - continue; - - missingPropertiesInSuper.push_back(it.first); - - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } - else - { - if (isUnnamedTable && r->second.location) - { - size_t oldErrorSize = innerState.errors.size(); - Location old = innerState.location; - innerState.location = *r->second.location; - innerState.tryUnify_(r->second.type, it.second.type); - innerState.location = old; - - if (oldErrorSize != innerState.errors.size() && !errorReported) - { - errorReported = true; - reportError(innerState.errors.back()); - } - } - else - { - innerState.tryUnify_(r->second.type, it.second.type); - } - } - } - - if (superTable->indexer || subTable->indexer) - { - if (superTable->indexer && subTable->indexer) - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (subTable->state == TableState::Unsealed) - { - if (superTable->indexer && !subTable->indexer) - { - log.changeIndexer(subTy, superTable->indexer); - } - } - else if (superTable->state == TableState::Unsealed) - { - if (subTable->indexer && !superTable->indexer) - { - log.changeIndexer(superTy, subTable->indexer); - } - } - else if (superTable->indexer) - { - innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); - for (const auto& [name, type] : subTable->props) - { - const auto& it = superTable->props.find(name); - if (it == superTable->props.end()) - innerState.tryUnify_(type.type, superTable->indexer->indexResultType); - } - } - else - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } - - if (!errorReported) - log.concat(std::move(innerState.log)); - else - return; - - if (!missingPropertiesInSuper.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); - return; - } - - // If the superTy is an immediate part of an intersection type, do not do extra-property check. - // Otherwise, we would falsely generate an extra-property-error for 's' in this code: - // local a: {n: number} & {s: string} = {n=1, s=""} - // When checking against the table '{n: number}'. - if (!isIntersection && superTable->state != TableState::Unsealed && !superTable->indexer) - { - // Check for extra properties in the subTy - std::vector extraPropertiesInSub; - - for (const auto& [subKey, subProp] : subTable->props) - { - const auto& superIt = superTable->props.find(subKey); - if (superIt == superTable->props.end()) - { - if (isOptional(subProp.type)) - continue; - - extraPropertiesInSub.push_back(subKey); - } - } - - if (!extraPropertiesInSub.empty()) - { - reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); - return; - } - } - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); -} - void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { const MetatableTypeVar* superMetatable = get(superTy); @@ -2068,14 +1765,6 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) return fail(); } -void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) -{ - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); - - tryUnify_(subIndexer.indexType, superIndexer.indexType); - tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); -} - static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { while (true) @@ -2435,7 +2124,7 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypeId"); + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); auto check = [&](TypeId tv) { occursCheck(seen, needle, tv); @@ -2506,7 +2195,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, - FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit, "occursCheck for TypePackId"); + FFlag::LuauAutocompleteDynamicLimits ? sharedState.counters.recursionLimit : FInt::LuauTypeInferRecursionLimit); while (!log.getMutable(haystack)) { diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 10cf17d..4bc8cab 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -9,6 +9,7 @@ #include "FileUtils.h" LUAU_FASTFLAG(DebugLuauTimeTracing) +LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) enum class ReportFormat { @@ -49,6 +50,9 @@ static void reportError(const Luau::Frontend& frontend, ReportFormat format, con if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); + else if (FFlag::LuauTypeMismatchModuleNameResolution) + report(format, humanReadableName.c_str(), error.location, "TypeError", + Luau::toString(error, Luau::TypeErrorToStringOptions{frontend.fileResolver}).c_str()); else report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } diff --git a/CMakeLists.txt b/CMakeLists.txt index c624a13..e256e23 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,7 @@ option(LUAU_BUILD_TESTS "Build tests" ON) option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) +option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) if(LUAU_STATIC_CRT) cmake_minimum_required(VERSION 3.15) @@ -115,6 +116,14 @@ target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) +if(LUAU_EXTERN_C) + # enable extern "C" for VM (lua.h, lualib.h) and Compiler (luacode.h) to make Luau friendlier to use from non-C++ languages + # note that we enable LUA_USE_LONGJMP=1 as well; otherwise functions like luaL_error will throw C++ exceptions, which can't be done from extern "C" functions + target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1) + target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\") + target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\") +endif() + if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index f71d893..218bb5d 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -7,7 +7,7 @@ // Creating the bytecode is outside the scope of this file and is handled by bytecode builder (BytecodeBuilder.h) and bytecode compiler (Compiler.h) // Note that ALL enums declared in this file are order-sensitive since the values are baked into bytecode that needs to be processed by legacy clients. -// Bytecode definitions +// # Bytecode definitions // Bytecode instructions are using "word code" - each instruction is one or many 32-bit words. // The first word in the instruction is always the instruction header, and *must* contain the opcode (enum below) in the least significant byte. // @@ -19,7 +19,7 @@ // Instruction word is sometimes followed by one extra word, indicated as AUX - this is just a 32-bit word and is decoded according to the specification for each opcode. // For each opcode the encoding is *static* - that is, based on the opcode you know a-priory how large the instruction is, with the exception of NEWCLOSURE -// Bytecode indices +// # Bytecode indices // Bytecode instructions commonly refer to integer values that define offsets or indices for various entities. For each type, there's a maximum encodable value. // Note that in some cases, the compiler will set a lower limit than the maximum encodable value is to prevent fragile code into bumping against the limits whenever we change the compilation details. // Additionally, in some specific instructions such as ANDK, the limit on the encoded value is smaller; this means that if a value is larger, a different instruction must be selected. @@ -29,6 +29,15 @@ // Constants: 0-2^23-1. Constants are stored in a table allocated with each proto; to allow for future bytecode tweaks the encodable value is limited to 23 bits. // Closures: 0-2^15-1. Closures are created from child protos via a child index; the limit is for the number of closures immediately referenced in each function. // Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more. + +// # Bytecode versions +// Bytecode serialized format embeds a version number, that dictates both the serialized form as well as the allowed instructions. As long as the bytecode version falls into supported +// range (indicated by LBC_BYTECODE_MIN / LBC_BYTECODE_MAX) and was produced by Luau compiler, it should load and execute correctly. +// +// Note that Luau runtime doesn't provide indefinite bytecode compatibility: support for older versions gets removed over time. As such, bytecode isn't a durable storage format and it's expected +// that Luau users can recompile bytecode from source on Luau version upgrades if necessary. + +// Bytecode opcode, part of the instruction header enum LuauOpcode { // NOP: noop @@ -380,8 +389,10 @@ enum LuauOpcode // Bytecode tags, used internally for bytecode encoded as a string enum LuauBytecodeTag { - // Bytecode version - LBC_VERSION = 2, + // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled + LBC_VERSION_MIN = 2, + LBC_VERSION_MAX = 2, + LBC_VERSION_TARGET = 2, // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index dbe5429..6ec10b5 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -119,6 +119,8 @@ public: static std::string getError(const std::string& message); + static uint8_t getVersion(); + private: struct Constant { diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index a34f760..301cf25 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -9,6 +9,9 @@ namespace Luau { +static_assert(LBC_VERSION_TARGET >= LBC_VERSION_MIN && LBC_VERSION_TARGET <= LBC_VERSION_MAX, "Invalid bytecode version setup"); +static_assert(LBC_VERSION_MAX <= 127, "Bytecode version should be 7-bit so that we can extend the serialization to use varint transparently"); + static const uint32_t kMaxConstantCount = 1 << 23; static const uint32_t kMaxClosureCount = 1 << 15; @@ -572,7 +575,10 @@ void BytecodeBuilder::finalize() bytecode.reserve(capacity); // assemble final bytecode blob - bytecode = char(LBC_VERSION); + uint8_t version = getVersion(); + LUAU_ASSERT(version >= LBC_VERSION_MIN && version <= LBC_VERSION_MAX); + + bytecode = char(version); writeStringTable(bytecode); @@ -1040,7 +1046,7 @@ void BytecodeBuilder::expandJumps() std::string BytecodeBuilder::getError(const std::string& message) { - // 0 acts as a special marker for error bytecode (it's equal to LBC_VERSION for valid bytecode blobs) + // 0 acts as a special marker for error bytecode (it's equal to LBC_VERSION_TARGET for valid bytecode blobs) std::string result; result += char(0); result += message; @@ -1048,6 +1054,12 @@ std::string BytecodeBuilder::getError(const std::string& message) return result; } +uint8_t BytecodeBuilder::getVersion() +{ + // This function usually returns LBC_VERSION_TARGET but may sometimes return a higher number (within LBC_VERSION_MIN/MAX) under fast flags + return LBC_VERSION_TARGET; +} + #ifdef LUAU_ASSERTENABLED void BytecodeBuilder::validate() const { diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 52dc924..e732256 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -16,8 +16,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false) - LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThresholdMaxBoost, 300) @@ -2672,7 +2670,7 @@ struct Compiler else if (builtin.isGlobal("pairs")) // for .. in pairs(t) { skipOp = LOP_FORGPREP_NEXT; - loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT; + loopOp = LOP_FORGLOOP; } } else if (stat->values.size == 2) @@ -2682,7 +2680,7 @@ struct Compiler if (builtin.isGlobal("next")) // for .. in next,t { skipOp = LOP_FORGPREP_NEXT; - loopOp = FFlag::LuauCompileIterNoPairs ? LOP_FORGLOOP : LOP_FORGLOOP_NEXT; + loopOp = LOP_FORGLOOP; } } } diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 2815268..c2110cb 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -26,6 +26,8 @@ void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { void (*dtor)(lua_State*, void*) = nullptr; dtor = L->global->udatagc[u->tag]; + // TODO: access to L here is highly unsafe since this is called during internal GC traversal + // certain operations such as lua_getthreaddata are okay, but by and large this risks crashes on improper use if (dtor) dtor(L, u->data); } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 8b742f1..86afddd 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -154,11 +154,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size return 1; } - if (version != LBC_VERSION) + if (version < LBC_VERSION_MIN || version > LBC_VERSION_MAX) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); - lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, LBC_VERSION, version); + lua_pushfstring(L, "%s: bytecode version mismatch (expected [%d..%d], got %d)", chunkid, LBC_VERSION_MIN, LBC_VERSION_MAX, version); return 1; } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 036bf12..655e48c 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,8 +261,6 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { - ScopedFastFlag sff2("LuauCompileIterNoPairs", false); - // basic for loop: variable directly refers to internal iteration index (R2) CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( LOADN R2 1 @@ -329,7 +327,7 @@ L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -L1: FORGLOOP_NEXT R0 L0 +L1: FORGLOOP R0 L0 2 RETURN R0 0 )"); @@ -342,7 +340,7 @@ L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -L1: FORGLOOP_NEXT R0 L0 +L1: FORGLOOP R0 L0 2 RETURN R0 0 )"); } @@ -2262,8 +2260,6 @@ TEST_CASE("TypeAliasing") TEST_CASE("DebugLineInfo") { - ScopedFastFlag sff("LuauCompileIterNoPairs", false); - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( @@ -2313,7 +2309,7 @@ return result 15: L0: MOVE R7 R1 15: MOVE R8 R5 15: CONCAT R1 R7 R8 -14: L1: FORGLOOP_NEXT R2 L0 +14: L1: FORGLOOP R2 L0 1 17: RETURN R1 1 )"); } @@ -2545,8 +2541,6 @@ a TEST_CASE("DebugSource") { - ScopedFastFlag sff("LuauCompileIterNoPairs", false); - const char* source = R"( local kSelectedBiomes = { ['Mountains'] = true, @@ -2614,7 +2608,7 @@ L0: MOVE R7 R1 MOVE R8 R5 CONCAT R1 R7 R8 14: for k in pairs(kSelectedBiomes) do -L1: FORGLOOP_NEXT R2 L0 +L1: FORGLOOP R2 L0 1 17: return result RETURN R1 1 )"); @@ -2622,8 +2616,6 @@ RETURN R1 1 TEST_CASE("DebugLocals") { - ScopedFastFlag sff("LuauCompileIterNoPairs", false); - const char* source = R"( function foo(e, f) local a = 1 @@ -2661,12 +2653,12 @@ end local 0: reg 5, start pc 5 line 5, end pc 8 line 5 local 1: reg 6, start pc 14 line 8, end pc 18 line 8 local 2: reg 7, start pc 14 line 8, end pc 18 line 8 -local 3: reg 3, start pc 21 line 12, end pc 24 line 12 -local 4: reg 3, start pc 26 line 16, end pc 30 line 16 -local 5: reg 0, start pc 0 line 3, end pc 34 line 21 -local 6: reg 1, start pc 0 line 3, end pc 34 line 21 -local 7: reg 2, start pc 1 line 4, end pc 34 line 21 -local 8: reg 3, start pc 34 line 21, end pc 34 line 21 +local 3: reg 3, start pc 22 line 12, end pc 25 line 12 +local 4: reg 3, start pc 27 line 16, end pc 31 line 16 +local 5: reg 0, start pc 0 line 3, end pc 35 line 21 +local 6: reg 1, start pc 0 line 3, end pc 35 line 21 +local 7: reg 2, start pc 1 line 4, end pc 35 line 21 +local 8: reg 3, start pc 35 line 21, end pc 35 line 21 3: LOADN R2 1 4: LOADN R5 1 4: LOADN R3 3 @@ -2683,7 +2675,7 @@ local 8: reg 3, start pc 34 line 21, end pc 34 line 21 8: MOVE R9 R6 8: MOVE R10 R7 8: CALL R8 2 0 -7: L3: FORGLOOP_NEXT R3 L2 +7: L3: FORGLOOP R3 L2 2 11: LOADN R3 2 12: GETIMPORT R4 1 12: LOADN R5 2 @@ -3795,8 +3787,6 @@ RETURN R0 1 TEST_CASE("SharedClosure") { - ScopedFastFlag sff("LuauCompileIterNoPairs", false); - // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... @@ -3939,7 +3929,7 @@ L2: GETIMPORT R5 1 NEWCLOSURE R6 P1 CAPTURE VAL R3 CALL R5 1 0 -L3: FORGLOOP_NEXT R0 L2 +L3: FORGLOOP R0 L2 2 LOADN R2 1 LOADN R0 10 LOADN R1 1 diff --git a/tests/Fixture.h b/tests/Fixture.h index ffcd4b9..0e3735f 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -2,13 +2,13 @@ #pragma once #include "Luau/Config.h" -#include "Luau/ConstraintGraphBuilder.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index d585b73..7c2f4d1 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -279,7 +279,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; - ScopedFastFlag sff{"LuauRecursionLimitException", true}; TypeArena src; diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 284230c..a474b6e 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -12,7 +12,6 @@ using namespace Luau; struct NormalizeFixture : Fixture { ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true}; - ScopedFastFlag sff2{"LuauTableSubtypingVariance2", true}; }; void createSomeClasses(TypeChecker& typeChecker) diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index bef38fc..6619147 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -264,10 +264,13 @@ TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") } )LUA"; + CheckResult result = check(src); + CodeTooComplex ctc; + if (FFlag::LuauLowerBoundsCalculation) - (void)check(src); + LUAU_REQUIRE_ERRORS(result); else - CHECK_THROWS_AS(check(src), std::exception); + CHECK(hasError(result, &ctc)); } TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 4d2e94e..e03069a 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -409,8 +409,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local base = {} function base:one() return 1 end diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 86cc970..d6f0a0c 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -7,8 +7,21 @@ using namespace Luau; +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) + TEST_SUITE_BEGIN("TypeAliases"); +TEST_CASE_FIXTURE(Fixture, "basic_alias") +{ + CheckResult result = check(R"( + type T = number + local x: T = 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("x"))); +} + TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") { CheckResult result = check(R"( @@ -24,6 +37,63 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); } +TEST_CASE_FIXTURE(Fixture, "names_are_ascribed") +{ + CheckResult result = check(R"( + type T = { x: number } + local x: T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("T", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") +{ + // This is a tricky case. In order to support recursive type aliases, + // we first walk the block and generate free types as placeholders. + // We then walk the AST as normal. If we declare a type alias as below, + // we generate a free type. We then begin our normal walk, examining + // local x: T = "foo", which establishes two constraints: + // a <: b + // string <: a + // We then visit the type alias, and establish that + // b <: number + // Then, when solving these constraints, we dispatch them in the order + // they appear above. This means that a ~ b, and a ~ string, thus + // b ~ string. This means the b <: number constraint has no effect. + // Essentially we've "stolen" the alias's type out from under it. + // This test ensures that we don't actually do this. + CheckResult result = check(R"( + local x: T = "foo" + type T = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK(result.errors[0] == TypeError{ + Location{{1, 21}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); + } + else + { + CHECK(result.errors[0] == TypeError{ + Location{{1, 8}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); + } +} + TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") { CheckResult result = check(R"( @@ -41,7 +111,22 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_whe CHECK_EQ(typeChecker.numberType, tm->givenType); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") +{ + CheckResult result = check(R"( + --!strict + type T = { f: number, g: U } + type U = { h: number, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = 3, g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") { CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index ccdd2b3..3e2ad6d 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -30,11 +30,21 @@ TEST_CASE_FIXTURE(Fixture, "successful_check") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "variable_type_is_supertype") +{ + CheckResult result = check(R"( + local x: number = 1 + local y: number? = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "function_parameters_can_have_annotations") { CheckResult result = check(R"( function double(x: number) - return x * 2 + return 2 end local four = double(2) @@ -47,7 +57,7 @@ TEST_CASE_FIXTURE(Fixture, "function_parameter_annotations_are_checked") { CheckResult result = check(R"( function double(x: number) - return x * 2 + return 2 end local four = double("two") @@ -70,13 +80,13 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") const FunctionTypeVar* ftv = get(fiftyType); REQUIRE(ftv != nullptr); - TypePackId retPack = ftv->retTypes; + TypePackId retPack = follow(ftv->retTypes); const TypePack* tp = get(retPack); REQUIRE(tp != nullptr); REQUIRE_EQ(1, tp->head.size()); - REQUIRE_EQ(typeChecker.anyType, tp->head[0]); + REQUIRE_EQ(typeChecker.anyType, follow(tp->head[0])); } TEST_CASE_FIXTURE(Fixture, "function_return_multret_annotations_are_checked") @@ -116,6 +126,23 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotation_should_continuously_parse LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "unknown_type_reference_generates_error") +{ + CheckResult result = check(R"( + local x: IDoNotExist + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{ + Location{{1, 17}, {1, 28}}, + getMainSourceModule()->name, + UnknownSymbol{ + "IDoNotExist", + UnknownSymbol::Context::Type, + }, + }); +} + TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") { CheckResult result = check(R"( @@ -632,7 +659,10 @@ int AssertionCatcher::tripped; TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", false}, + }; AssertionCatcher ac; @@ -646,9 +676,10 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") { - ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; - - AssertionCatcher ac; + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", false}, + }; bool caught = false; @@ -662,8 +693,44 @@ TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") std::runtime_error); CHECK_EQ(true, caught); +} - frontend.iceHandler.onInternalError = {}; +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag") +{ + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", true}, + }; + + AssertionCatcher ac; + + CHECK_THROWS_AS(check(R"( + local a: _luau_ice = 55 + )"), + InternalCompilerError); + + LUAU_ASSERT(1 == AssertionCatcher::tripped); +} + +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_exception_with_flag_handler") +{ + ScopedFastFlag sffs[] = { + {"DebugLuauMagicTypes", true}, + {"LuauUseInternalCompilerErrorException", true}, + }; + + bool caught = false; + + frontend.iceHandler.onInternalError = [&](const char*) { + caught = true; + }; + + CHECK_THROWS_AS(check(R"( + local a: _luau_ice = 55 + )"), + InternalCompilerError); + + CHECK_EQ(true, caught); } TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index edb5adc..97ba080 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -700,11 +700,6 @@ end TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") { - ScopedFastFlag sffs[] = { - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( --!strict -- At one point this produced a UAF @@ -979,8 +974,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - // Mutability in type function application right now can create strange recursive types CheckResult result = check(R"( type Table = { a: number } @@ -1015,8 +1008,6 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) @@ -1123,8 +1114,6 @@ TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics1") { - ScopedFastFlag sff{"LuauApplyTypeFunctionFix", true}; - // https://github.com/Roblox/luau/issues/484 CheckResult result = check(R"( --!strict @@ -1153,8 +1142,6 @@ local complex: ComplexObject = { TEST_CASE_FIXTURE(Fixture, "apply_type_function_nested_generics2") { - ScopedFastFlag sff{"LuauApplyTypeFunctionFix", true}; - // https://github.com/Roblox/luau/issues/484 CheckResult result = check(R"( --!strict diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index afec20b..a0f670f 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -12,8 +12,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauTableSubtypingVariance2) - TEST_SUITE_BEGIN("TypeInferModules"); TEST_CASE_FIXTURE(BuiltinsFixture, "require") @@ -326,16 +324,9 @@ local b: B.T = a CheckResult result = frontend.check("game/C"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTableSubtypingVariance2) - { - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); - } - else - { - CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'"); - } } TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") @@ -367,16 +358,9 @@ local b: B.T = a CheckResult result = frontend.check("game/D"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauTableSubtypingVariance2) - { - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); - } - else - { - CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'"); - } } TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index cefba4b..3f5dad3 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -353,8 +353,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_non_binary_expressions_actually_resol TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local t: {x: number?} = {x = nil} diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index a90f434..4a88abe 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -260,10 +260,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { - ScopedFastFlag sffs[]{ - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( --!strict local x: { ["<>"] : number } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 87d4965..77a2928 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -276,8 +276,6 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local a = {} a.x = 99 @@ -347,8 +345,6 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict function foo(o) @@ -370,8 +366,6 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local T = {} T.bar = 'hello' @@ -477,8 +471,6 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict local t = { u = {} } @@ -512,8 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ TEST_CASE_FIXTURE(Fixture, "width_subtyping") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict function f(x : { q : number }) @@ -772,8 +762,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } )"); @@ -783,8 +771,6 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( function empty() return {} end local array: {string} = empty() @@ -1175,8 +1161,6 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local t = {x = 1} function t.m() end @@ -1187,8 +1171,6 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok") { - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - CheckResult result = check(R"( local t = {x = 1} function t:m() end @@ -1468,11 +1450,6 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( type StringToStringMap = { [string]: string } local rt: StringToStringMap = { ["foo"] = 1 } @@ -1518,11 +1495,6 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end foo({ a = 1 }) @@ -1609,8 +1581,6 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( local vec3 = {x = 1, y = 2, z = 3} local vec1 = {x = 1} @@ -1998,8 +1968,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_prope TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_properties_in_strict") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict local buttons = {} @@ -2013,8 +1981,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_should_cope_with_optional_prope TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type A = { x: number, y: number } type B = { x: number, y: string } @@ -2031,8 +1997,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type AS = { x: number, y: number } type BS = { x: number, y: string } @@ -2054,11 +2018,6 @@ caused by: TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") { - ScopedFastFlag sff[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); local b1 = setmetatable({ x = 2, y = "hello" }, { __call = function(s) end }); @@ -2085,8 +2044,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") { - ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type A = { [number]: string } type B = { [string]: string } @@ -2103,8 +2060,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") { - ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - CheckResult result = check(R"( type A = { [number]: number } type B = { [number]: string } @@ -2121,10 +2076,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { - ScopedFastFlag sffs[]{ - {"LuauTableSubtypingVariance2", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2140,11 +2091,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") { - ScopedFastFlag sffs[]{ - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2166,10 +2112,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { - ScopedFastFlag sffs[]{ - {"LuauTableSubtypingVariance2", true}, - }; - CheckResult result = check(R"( --!strict type Super = { x : number } @@ -2185,10 +2127,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_type_call") { - ScopedFastFlag sff[]{ - {"LuauUnsealedTableLiteral", true}, - }; - CheckResult result = check(R"( local b b = setmetatable({}, {__call = b}) @@ -2201,11 +2139,6 @@ b() TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") { - ScopedFastFlag sffs[] = { - {"LuauTableSubtypingVariance2", true}, - {"LuauSubtypingAddOptPropsToUnsealedTables", true}, - }; - CheckResult result = check(R"( --!strict local function setNumber(t: { p: number? }, x:number) t.p = x end @@ -2706,8 +2639,6 @@ type t0 = any TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning_2") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - CheckResult result = check(R"( type X = T type K = X @@ -2725,8 +2656,6 @@ type K = X TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - CheckResult result = check(R"( type X = T local a = {} @@ -2977,8 +2906,6 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") { - ScopedFastFlag luauSubtypingAddOptPropsToUnsealedTables{"LuauSubtypingAddOptPropsToUnsealedTables", true}; - CheckResult result = check(R"( type X = { { x: boolean?, y: boolean? } } diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6257cda..6a048b2 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -887,8 +887,6 @@ end TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") { - ScopedFastFlag subtypingVariance{"LuauTableSubtypingVariance2", true}; - CheckResult result = check(R"( --!strict --!nolint @@ -928,7 +926,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") { ScopedFastInt sfi("LuauTypeInferRecursionLimit", 2); - ScopedFastFlag sff{"LuauRecursionLimitException", true}; CheckResult result = check(R"( function complex() diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index d19d80c..2b48133 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -428,12 +428,6 @@ y = x TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") { - ScopedFastFlag sffs[] = { - {"LuauTableSubtypingVariance2", true}, - {"LuauUnsealedTableLiteral", true}, - {"LuauSubtypingAddOptPropsToUnsealedTables", true}, - }; - CheckResult result = check(R"( -- the difference between this and unify_unsealed_table_union_check is the type annotation on x local t = { x = 3, y = true } diff --git a/tests/VisitTypeVar.test.cpp b/tests/VisitTypeVar.test.cpp index 01960fb..4fba694 100644 --- a/tests/VisitTypeVar.test.cpp +++ b/tests/VisitTypeVar.test.cpp @@ -10,14 +10,9 @@ using namespace Luau; LUAU_FASTINT(LuauVisitRecursionLimit) -struct VisitTypeVarFixture : Fixture -{ - ScopedFastFlag flag2 = {"LuauRecursionLimitException", true}; -}; - TEST_SUITE_BEGIN("VisitTypeVar"); -TEST_CASE_FIXTURE(VisitTypeVarFixture, "throw_when_limit_is_exceeded") +TEST_CASE_FIXTURE(Fixture, "throw_when_limit_is_exceeded") { ScopedFastInt sfi{"LuauVisitRecursionLimit", 3}; @@ -30,7 +25,7 @@ TEST_CASE_FIXTURE(VisitTypeVarFixture, "throw_when_limit_is_exceeded") CHECK_THROWS_AS(toString(tType), RecursionLimitException); } -TEST_CASE_FIXTURE(VisitTypeVarFixture, "dont_throw_when_limit_is_high_enough") +TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") { ScopedFastInt sfi{"LuauVisitRecursionLimit", 8}; From 8f040862b11e3758af97ac9a712d6c579a7e01f5 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 30 Jun 2022 16:29:02 -0700 Subject: [PATCH 18/19] Sync to upstream/release/534 --- Analysis/include/Luau/Constraint.h | 19 +- .../include/Luau/ConstraintGraphBuilder.h | 109 +++- Analysis/include/Luau/ConstraintSolver.h | 12 +- Analysis/include/Luau/Error.h | 33 +- Analysis/include/Luau/Frontend.h | 5 + Analysis/include/Luau/Module.h | 2 +- Analysis/include/Luau/NotNull.h | 14 +- Analysis/include/Luau/Scope.h | 6 +- Analysis/include/Luau/TypeArena.h | 6 + Analysis/include/Luau/TypeInfer.h | 4 +- Analysis/include/Luau/TypeVar.h | 3 + Analysis/src/ConstraintGraphBuilder.cpp | 557 ++++++++++++++---- Analysis/src/ConstraintSolver.cpp | 71 ++- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 10 +- Analysis/src/Frontend.cpp | 22 +- Analysis/src/Normalize.cpp | 235 +------- Analysis/src/Quantify.cpp | 73 ++- Analysis/src/Scope.cpp | 15 + Analysis/src/ToString.cpp | 111 ++-- Analysis/src/TypeChecker2.cpp | 106 +++- Analysis/src/TypeInfer.cpp | 144 ++++- Analysis/src/Unifier.cpp | 4 +- Ast/src/Parser.cpp | 122 +++- Common/include/Luau/Bytecode.h | 5 +- Compiler/src/Builtins.cpp | 4 + Compiler/src/BytecodeBuilder.cpp | 15 +- Compiler/src/Compiler.cpp | 16 +- VM/src/lbaselib.cpp | 15 + VM/src/lbuiltins.cpp | 23 + VM/src/ldebug.h | 1 + VM/src/ltm.cpp | 2 +- VM/src/ltm.h | 2 +- VM/src/lvmexecute.cpp | 54 +- VM/src/lvmutils.cpp | 54 +- fuzz/protoprint.cpp | 11 + tests/Compiler.test.cpp | 16 +- tests/Conformance.test.cpp | 6 + tests/ConstraintGraphBuilder.test.cpp | 30 +- tests/ConstraintSolver.test.cpp | 17 +- tests/Fixture.cpp | 23 +- tests/Fixture.h | 3 +- tests/NotNull.test.cpp | 12 +- tests/Parser.test.cpp | 32 +- tests/ToString.test.cpp | 18 +- tests/TypeInfer.aliases.test.cpp | 35 +- tests/TypeInfer.annotations.test.cpp | 14 +- tests/TypeInfer.functions.test.cpp | 21 + tests/TypeInfer.generics.test.cpp | 5 +- tests/TypeInfer.operators.test.cpp | 55 ++ tests/TypeInfer.provisional.test.cpp | 22 + tests/TypeInfer.tables.test.cpp | 121 +++- tests/TypeInfer.test.cpp | 4 +- tests/conformance/basic.lua | 10 +- tests/conformance/events.lua | 38 ++ 54 files changed, 1714 insertions(+), 653 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 8a41c9e..dcfb14b 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/NotNull.h" #include "Luau/Variant.h" @@ -47,6 +48,21 @@ struct InstantiationConstraint TypeId superType; }; +struct UnaryConstraint +{ + AstExprUnary::Op op; + TypeId operandType; + TypeId resultType; +}; + +struct BinaryConstraint +{ + AstExprBinary::Op op; + TypeId leftType; + TypeId rightType; + TypeId resultType; +}; + // name(namedType) = name struct NameConstraint { @@ -54,7 +70,8 @@ struct NameConstraint std::string name; }; -using ConstraintV = Variant; +using ConstraintV = Variant; using ConstraintPtr = std::unique_ptr; struct Constraint diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 9b11869..a49e859 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -25,9 +25,12 @@ struct ConstraintGraphBuilder // scope pointers; the scopes themselves borrow pointers to other scopes to // define the scope hierarchy. std::vector>> scopes; + + ModuleName moduleName; SingletonTypes& singletonTypes; - TypeArena* const arena; + const NotNull arena; // The root scope of the module we're generating constraints for. + // This is null when the CGB is initially constructed. Scope2* rootScope; // A mapping of AST node to TypeId. DenseHashMap astTypes{nullptr}; @@ -39,40 +42,50 @@ struct ConstraintGraphBuilder // Type packs resolved from type annotations. Analogous to astTypePacks. DenseHashMap astResolvedTypePacks{nullptr}; - explicit ConstraintGraphBuilder(TypeArena* arena); + int recursionCount = 0; + + // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. + std::vector errors; + + // Occasionally constraint generation needs to produce an ICE. + const NotNull ice; + + NotNull globalScope; + + ConstraintGraphBuilder(const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope); /** * Fabricates a new free type belonging to a given scope. - * @param scope the scope the free type belongs to. Must not be null. + * @param scope the scope the free type belongs to. */ - TypeId freshType(Scope2* scope); + TypeId freshType(NotNull scope); /** * Fabricates a new free type pack belonging to a given scope. - * @param scope the scope the free type pack belongs to. Must not be null. + * @param scope the scope the free type pack belongs to. */ - TypePackId freshTypePack(Scope2* scope); + TypePackId freshTypePack(NotNull scope); /** * Fabricates a scope that is a child of another scope. * @param location the lexical extent of the scope in the source code. * @param parent the parent scope of the new scope. Must not be null. */ - Scope2* childScope(Location location, Scope2* parent); + NotNull childScope(Location location, NotNull parent); /** * Adds a new constraint with no dependencies to a given scope. - * @param scope the scope to add the constraint to. Must not be null. + * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. */ - void addConstraint(Scope2* scope, ConstraintV cv); + void addConstraint(NotNull scope, ConstraintV cv); /** * Adds a constraint to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param c the constraint to add. */ - void addConstraint(Scope2* scope, std::unique_ptr c); + void addConstraint(NotNull scope, std::unique_ptr c); /** * The entry point to the ConstraintGraphBuilder. This will construct a set @@ -81,20 +94,22 @@ struct ConstraintGraphBuilder */ void visit(AstStatBlock* block); - void visit(Scope2* scope, AstStat* stat); - void visit(Scope2* scope, AstStatBlock* block); - void visit(Scope2* scope, AstStatLocal* local); - void visit(Scope2* scope, AstStatLocalFunction* function); - void visit(Scope2* scope, AstStatFunction* function); - void visit(Scope2* scope, AstStatReturn* ret); - void visit(Scope2* scope, AstStatAssign* assign); - void visit(Scope2* scope, AstStatIf* ifStatement); - void visit(Scope2* scope, AstStatTypeAlias* alias); + void visitBlockWithoutChildScope(NotNull scope, AstStatBlock* block); - TypePackId checkExprList(Scope2* scope, const AstArray& exprs); + void visit(NotNull scope, AstStat* stat); + void visit(NotNull scope, AstStatBlock* block); + void visit(NotNull scope, AstStatLocal* local); + void visit(NotNull scope, AstStatLocalFunction* function); + void visit(NotNull scope, AstStatFunction* function); + void visit(NotNull scope, AstStatReturn* ret); + void visit(NotNull scope, AstStatAssign* assign); + void visit(NotNull scope, AstStatIf* ifStatement); + void visit(NotNull scope, AstStatTypeAlias* alias); - TypePackId checkPack(Scope2* scope, AstArray exprs); - TypePackId checkPack(Scope2* scope, AstExpr* expr); + TypePackId checkExprList(NotNull scope, const AstArray& exprs); + + TypePackId checkPack(NotNull scope, AstArray exprs); + TypePackId checkPack(NotNull scope, AstExpr* expr); /** * Checks an expression that is expected to evaluate to one type. @@ -102,19 +117,35 @@ struct ConstraintGraphBuilder * @param expr the expression to check. * @return the type of the expression. */ - TypeId check(Scope2* scope, AstExpr* expr); + TypeId check(NotNull scope, AstExpr* expr); - TypeId checkExprTable(Scope2* scope, AstExprTable* expr); - TypeId check(Scope2* scope, AstExprIndexName* indexName); + TypeId checkExprTable(NotNull scope, AstExprTable* expr); + TypeId check(NotNull scope, AstExprIndexName* indexName); + TypeId check(NotNull scope, AstExprIndexExpr* indexExpr); + TypeId check(NotNull scope, AstExprUnary* unary); + TypeId check(NotNull scope, AstExprBinary* binary); - std::pair checkFunctionSignature(Scope2* parent, AstExprFunction* fn); + struct FunctionSignature + { + // The type of the function. + TypeId signature; + // The scope that encompasses the function's signature. May be nullptr + // if there was no need for a signature scope (the function has no + // generics). + Scope2* signatureScope; + // The scope that encompasses the function's body. Is a child scope of + // signatureScope, if present. + NotNull bodyScope; + }; + + FunctionSignature checkFunctionSignature(NotNull parent, AstExprFunction* fn); /** * Checks the body of a function expression. * @param scope the interior scope of the body of the function. * @param fn the function expression to check. */ - void checkFunctionBody(Scope2* scope, AstExprFunction* fn); + void checkFunctionBody(NotNull scope, AstExprFunction* fn); /** * Resolves a type from its AST annotation. @@ -122,7 +153,7 @@ struct ConstraintGraphBuilder * @param ty the AST annotation to resolve. * @return the type of the AST annotation. **/ - TypeId resolveType(Scope2* scope, AstType* ty); + TypeId resolveType(NotNull scope, AstType* ty); /** * Resolves a type pack from its AST annotation. @@ -130,9 +161,25 @@ struct ConstraintGraphBuilder * @param tp the AST annotation to resolve. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(Scope2* scope, AstTypePack* tp); + TypePackId resolveTypePack(NotNull scope, AstTypePack* tp); - TypePackId resolveTypePack(Scope2* scope, const AstTypeList& list); + TypePackId resolveTypePack(NotNull scope, const AstTypeList& list); + + std::vector> createGenerics(NotNull scope, AstArray generics); + std::vector> createGenericPacks(NotNull scope, AstArray packs); + + TypeId flattenPack(NotNull scope, Location location, TypePackId tp); + + void reportError(Location location, TypeErrorData err); + void reportCodeTooComplex(Location location); + + /** Scan the program for global definitions. + * + * ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for + * real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an + * initial scan of the AST and note what globals are defined. + */ + void prepopulateGlobalScope(NotNull globalScope, AstStatBlock* program); }; /** @@ -145,6 +192,6 @@ struct ConstraintGraphBuilder * @return a list of pointers to constraints contained within the scope graph. * None of these pointers should be null. */ -std::vector> collectConstraints(Scope2* rootScope); +std::vector> collectConstraints(NotNull rootScope); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 4870157..cf88efb 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -25,7 +25,7 @@ struct ConstraintSolver // is important to not add elements to this vector, lest the underlying // storage that we retain pointers to be mutated underneath us. const std::vector> constraints; - Scope2* rootScope; + NotNull rootScope; // This includes every constraint that has not been fully solved. // A constraint can be both blocked and unsolved, for instance. @@ -40,7 +40,7 @@ struct ConstraintSolver ConstraintSolverLogger logger; - explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope); + explicit ConstraintSolver(TypeArena* arena, NotNull rootScope); /** * Attempts to dispatch all pending constraints and reach a type solution @@ -50,11 +50,17 @@ struct ConstraintSolver bool done(); + /** Attempt to dispatch a constraint. Returns true if it was successful. + * If tryDispatch() returns false, the constraint remains in the unsolved set and will be retried later. + */ bool tryDispatch(NotNull c, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); void block(NotNull target, NotNull constraint); @@ -115,6 +121,6 @@ private: void unblock_(BlockedConstraintId progressed); }; -void dump(Scope2* rootScope, struct ToStringOptions& opts); +void dump(NotNull rootScope, struct ToStringOptions& opts); } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index a132396..4c81d33 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -369,24 +369,25 @@ struct InternalErrorReporter [[noreturn]] void ice(const std::string& message); }; -class InternalCompilerError : public std::exception { +class InternalCompilerError : public std::exception +{ public: - explicit InternalCompilerError(const std::string& message, const std::string& moduleName) - : message(message) - , moduleName(moduleName) - { - } - explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) - : message(message) - , moduleName(moduleName) - , location(location) - { - } - virtual const char* what() const throw(); + explicit InternalCompilerError(const std::string& message, const std::string& moduleName) + : message(message) + , moduleName(moduleName) + { + } + explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) + : message(message) + , moduleName(moduleName) + , location(location) + { + } + virtual const char* what() const throw(); - const std::string message; - const std::string moduleName; - const std::optional location; + const std::string message; + const std::string moduleName; + const std::optional location; }; } // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index f4226cc..f0d4309 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -5,6 +5,7 @@ #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" @@ -158,6 +159,8 @@ struct Frontend void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); + NotNull getGlobalScope2(); + private: ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope); @@ -173,6 +176,8 @@ private: std::unordered_map environments; std::unordered_map> builtinDefinitions; + std::unique_ptr globalScope2; + public: FileResolver* fileResolver; FrontendModuleResolver moduleResolver; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 39f8dfb..b3105b7 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -68,7 +68,7 @@ struct Module std::shared_ptr allocator; std::shared_ptr names; - std::vector> scopes; // never empty + std::vector> scopes; // never empty std::vector>> scope2s; // never empty DenseHashMap astTypes{nullptr}; diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h index f6043e9..714fa14 100644 --- a/Analysis/include/Luau/NotNull.h +++ b/Analysis/include/Luau/NotNull.h @@ -26,7 +26,7 @@ namespace Luau * The explicit delete statement is permitted (but not recommended) on a * NotNull through this implicit conversion. */ -template +template struct NotNull { explicit NotNull(T* t) @@ -38,10 +38,11 @@ struct NotNull explicit NotNull(std::nullptr_t) = delete; void operator=(std::nullptr_t) = delete; - template + template NotNull(NotNull other) : ptr(other.get()) - {} + { + } operator T*() const noexcept { @@ -72,12 +73,13 @@ private: T* ptr; }; -} +} // namespace Luau namespace std { -template struct hash> +template +struct hash> { size_t operator()(const Luau::NotNull& p) const { @@ -85,4 +87,4 @@ template struct hash> } }; -} +} // namespace std diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index cef4b94..0eaecf1 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -3,6 +3,7 @@ #include "Luau/Constraint.h" #include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/TypeVar.h" #include @@ -71,15 +72,18 @@ struct Scope2 // is the module-level scope). Scope2* parent = nullptr; // All the children of this scope. - std::vector children; + std::vector> children; std::unordered_map bindings; // TODO: I think this can be a DenseHashMap std::unordered_map typeBindings; + std::unordered_map typePackBindings; TypePackId returnType; + std::optional varargPack; // All constraints belonging to this scope. std::vector constraints; std::optional lookup(Symbol sym); std::optional lookupTypeBinding(const Name& name); + std::optional lookupTypePackBinding(const Name& name); }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 559c55c..be36f19 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -34,6 +34,12 @@ struct TypeArena TypePackId addTypePack(std::vector types); TypePackId addTypePack(TypePack pack); TypePackId addTypePack(TypePackVar pack); + + template + TypePackId addTypePack(T tp) + { + return addTypePack(TypePackVar(std::move(tp))); + } }; void freeze(TypeArena& arena); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 28adc9d..455654d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -173,7 +173,7 @@ struct TypeChecker TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, - std::optional originalNameLoc, std::optional expectedType); + std::optional originalNameLoc, std::optional selfType, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); void checkArgumentList( @@ -424,6 +424,8 @@ private: * (exported, name) to properly deal with the case where the two duplicates do not have the same export status. */ DenseHashSet, HashBoolNamePair> duplicateTypeAliases; + + std::vector> deferredQuantification; }; // Unit test hook diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 20f4107..6ad6b92 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -357,6 +357,9 @@ struct TableTypeVar std::optional boundTo; Tags tags; + + // Methods of this table that have an untyped self will use the same shared self type. + std::optional selfTy; }; // Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar. diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index d9e8d23..3b9000c 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -1,6 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ConstraintGraphBuilder.h" +#include "Luau/RecursionCounter.h" +#include "Luau/ToString.h" + +LUAU_FASTINT(LuauCheckRecursionLimit); #include "Luau/Scope.h" @@ -9,32 +13,33 @@ namespace Luau const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp -ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena) - : singletonTypes(getSingletonTypes()) +ConstraintGraphBuilder::ConstraintGraphBuilder( + const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope) + : moduleName(moduleName) + , singletonTypes(getSingletonTypes()) , arena(arena) , rootScope(nullptr) + , ice(ice) + , globalScope(globalScope) { LUAU_ASSERT(arena); } -TypeId ConstraintGraphBuilder::freshType(Scope2* scope) +TypeId ConstraintGraphBuilder::freshType(NotNull scope) { - LUAU_ASSERT(scope); return arena->addType(FreeTypeVar{scope}); } -TypePackId ConstraintGraphBuilder::freshTypePack(Scope2* scope) +TypePackId ConstraintGraphBuilder::freshTypePack(NotNull scope) { - LUAU_ASSERT(scope); FreeTypePack f{scope}; return arena->addTypePack(TypePackVar{std::move(f)}); } -Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) +NotNull ConstraintGraphBuilder::childScope(Location location, NotNull parent) { - LUAU_ASSERT(parent); auto scope = std::make_unique(); - Scope2* borrow = scope.get(); + NotNull borrow = NotNull(scope.get()); scopes.emplace_back(location, std::move(scope)); borrow->parent = parent; @@ -44,15 +49,13 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) return borrow; } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) +void ConstraintGraphBuilder::addConstraint(NotNull scope, ConstraintV cv) { - LUAU_ASSERT(scope); scope->constraints.emplace_back(new Constraint{std::move(cv)}); } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) +void ConstraintGraphBuilder::addConstraint(NotNull scope, std::unique_ptr c) { - LUAU_ASSERT(scope); scope->constraints.emplace_back(std::move(c)); } @@ -62,7 +65,11 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) LUAU_ASSERT(rootScope == nullptr); scopes.emplace_back(block->location, std::make_unique()); rootScope = scopes.back().second.get(); - rootScope->returnType = freshTypePack(rootScope); + NotNull borrow = NotNull(rootScope); + + rootScope->returnType = freshTypePack(borrow); + + prepopulateGlobalScope(borrow, block); // TODO: We should share the global scope. rootScope->typeBindings["nil"] = singletonTypes.nilType; @@ -71,12 +78,26 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) rootScope->typeBindings["boolean"] = singletonTypes.booleanType; rootScope->typeBindings["thread"] = singletonTypes.threadType; - visit(rootScope, block); + visitBlockWithoutChildScope(borrow, block); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) +void ConstraintGraphBuilder::visitBlockWithoutChildScope(NotNull scope, AstStatBlock* block) { - LUAU_ASSERT(scope); + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(block->location); + return; + } + + for (AstStat* stat : block->body) + visit(scope, stat); +} + +void ConstraintGraphBuilder::visit(NotNull scope, AstStat* stat) +{ + RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; if (auto s = stat->as()) visit(scope, s); @@ -100,10 +121,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) LUAU_ASSERT(0); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocal* local) { - LUAU_ASSERT(scope); - std::vector varTypes; for (AstLocal* local : local->vars) @@ -148,23 +167,19 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) } } -void addConstraints(Constraint* constraint, Scope2* scope) +void addConstraints(Constraint* constraint, NotNull scope) { - LUAU_ASSERT(scope); - scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); for (const auto& c : scope->constraints) constraint->dependencies.push_back(NotNull{c.get()}); - for (Scope2* childScope : scope->children) + for (NotNull childScope : scope->children) addConstraints(constraint, childScope); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocalFunction* function) { - LUAU_ASSERT(scope); - // Local // Global // Dotted path @@ -172,36 +187,31 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function TypeId functionType = nullptr; auto ty = scope->lookup(function->name); - if (ty.has_value()) - { - // TODO: This is duplicate definition of a local function. Is this allowed? - functionType = *ty; - } - else - { - functionType = arena->addType(BlockedTypeVar{}); - scope->bindings[function->name] = functionType; - } + LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. - auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); - innerScope->bindings[function->name] = actualFunctionType; + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[function->name] = functionType; - checkFunctionBody(innerScope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func); + sig.bodyScope->bindings[function->name] = sig.signature; - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; - addConstraints(c.get(), innerScope); + checkFunctionBody(sig.bodyScope, function->func); + + std::unique_ptr c{ + new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}}; + addConstraints(c.get(), sig.bodyScope); addConstraint(scope, std::move(c)); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatFunction* function) { // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self TypeId functionType = nullptr; - auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func); if (AstExprLocal* localName = function->name->as()) { @@ -216,7 +226,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) functionType = arena->addType(BlockedTypeVar{}); scope->bindings[localName->local] = functionType; } - innerScope->bindings[localName->local] = actualFunctionType; + sig.bodyScope->bindings[localName->local] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { @@ -231,32 +241,48 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) functionType = arena->addType(BlockedTypeVar{}); rootScope->bindings[globalName->name] = functionType; } - innerScope->bindings[globalName->name] = actualFunctionType; + sig.bodyScope->bindings[globalName->name] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { - LUAU_ASSERT(0); // not yet implemented + TypeId containingTableType = check(scope, indexName->expr); + + functionType = arena->addType(BlockedTypeVar{}); + TypeId prospectiveTableType = + arena->addType(TableTypeVar{}); // TODO look into stack utilization. This is probably ok because it scales with AST depth. + NotNull prospectiveTable{getMutable(prospectiveTableType)}; + + Property& prop = prospectiveTable->props[indexName->index.value]; + prop.type = functionType; + prop.location = function->name->location; + + addConstraint(scope, SubtypeConstraint{containingTableType, prospectiveTableType}); + } + else if (AstExprError* err = function->name->as()) + { + functionType = singletonTypes.errorRecoveryType(); } - checkFunctionBody(innerScope, function->func); + LUAU_ASSERT(functionType != nullptr); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; - addConstraints(c.get(), innerScope); + checkFunctionBody(sig.bodyScope, function->func); + + std::unique_ptr c{ + new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}}; + addConstraints(c.get(), sig.bodyScope); addConstraint(scope, std::move(c)); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatReturn* ret) { - LUAU_ASSERT(scope); - TypePackId exprTypes = checkPack(scope, ret->list); addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatBlock* block) { - LUAU_ASSERT(scope); + NotNull innerScope = childScope(block->location, scope); // In order to enable mutually-recursive type aliases, we need to // populate the type bindings before we actually check any of the @@ -271,11 +297,10 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) } } - for (AstStat* stat : block->body) - visit(scope, stat); + visitBlockWithoutChildScope(innerScope, block); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatAssign* assign) { TypePackId varPackId = checkExprList(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); @@ -283,21 +308,21 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatIf* ifStatement) { check(scope, ifStatement->condition); - Scope2* thenScope = childScope(ifStatement->thenbody->location, scope); + NotNull thenScope = childScope(ifStatement->thenbody->location, scope); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { - Scope2* elseScope = childScope(ifStatement->elsebody->location, scope); + NotNull elseScope = childScope(ifStatement->elsebody->location, scope); visit(elseScope, ifStatement->elsebody); } } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatTypeAlias* alias) { // TODO: Exported type aliases // TODO: Generic type aliases @@ -307,6 +332,10 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) // AST to set up typeBindings. If it's not, we've somehow skipped // this alias in that first pass. LUAU_ASSERT(it != scope->typeBindings.end()); + if (it == scope->typeBindings.end()) + { + ice->ice("Type alias does not have a pre-populated binding", alias->location); + } TypeId ty = resolveType(scope, alias->type); @@ -319,10 +348,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) addConstraint(scope, NameConstraint{ty, alias->name.value}); } -TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) +TypePackId ConstraintGraphBuilder::checkPack(NotNull scope, AstArray exprs) { - LUAU_ASSERT(scope); - if (exprs.size == 0) return arena->addTypePack({}); @@ -342,7 +369,7 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray e return arena->addTypePack(TypePack{std::move(types), last}); } -TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray& exprs) +TypePackId ConstraintGraphBuilder::checkExprList(NotNull scope, const AstArray& exprs) { TypePackId result = arena->addTypePack({}); TypePack* resultPack = getMutable(result); @@ -363,9 +390,15 @@ TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray scope, AstExpr* expr) { - LUAU_ASSERT(scope); + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return singletonTypes.errorRecoveryTypePack(); + } TypePackId result = nullptr; @@ -384,7 +417,7 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) astOriginalCallTypes[call->func] = fnType; - TypeId instantiatedType = freshType(scope); + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); TypePackId rets = freshTypePack(scope); @@ -394,6 +427,13 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); result = rets; } + else if (AstExprVarargs* varargs = expr->as()) + { + if (scope->varargPack) + result = *scope->varargPack; + else + result = singletonTypes.errorRecoveryTypePack(); + } else { TypeId t = check(scope, expr); @@ -405,9 +445,15 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) return result; } -TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExpr* expr) { - LUAU_ASSERT(scope); + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return singletonTypes.errorRecoveryType(); + } TypeId result = nullptr; @@ -435,37 +481,38 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) if (ty) result = *ty; else - result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? - } - else if (auto a = expr->as()) - { - TypePackId packResult = checkPack(scope, expr); - if (auto f = first(packResult)) - return *f; - else if (get(packResult)) { - TypeId typeResult = freshType(scope); - TypePack onePack{{typeResult}, freshTypePack(scope)}; - TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - - addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}); - - return typeResult; + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + reportError(g->location, UnknownSymbol{g->name.value}); + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? } } + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); else if (auto a = expr->as()) { - auto [fnType, functionScope] = checkFunctionSignature(scope, a); - checkFunctionBody(functionScope, a); - return fnType; + FunctionSignature sig = checkFunctionSignature(scope, a); + checkFunctionBody(sig.bodyScope, a); + return sig.signature; } else if (auto indexName = expr->as()) - { result = check(scope, indexName); - } + else if (auto indexExpr = expr->as()) + result = check(scope, indexExpr); else if (auto table = expr->as()) - { result = checkExprTable(scope, table); + else if (auto unary = expr->as()) + result = check(scope, unary); + else if (auto binary = expr->as()) + result = check(scope, binary); + else if (auto err = expr->as()) + { + // Open question: Should we traverse into this? + result = singletonTypes.errorRecoveryType(); } else { @@ -478,7 +525,7 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) return result; } -TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr); TypeId result = freshType(scope); @@ -494,7 +541,67 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) return result; } -TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexExpr* indexExpr) +{ + TypeId obj = check(scope, indexExpr->expr); + TypeId indexType = check(scope, indexExpr->index); + + TypeId result = freshType(scope); + + TableIndexer indexer{indexType, result}; + TypeId tableType = arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, TableState::Free}); + + addConstraint(scope, SubtypeConstraint{obj, tableType}); + + return result; +} + +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprUnary* unary) +{ + TypeId operandType = check(scope, unary->expr); + + switch (unary->op) + { + case AstExprUnary::Minus: + { + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, UnaryConstraint{AstExprUnary::Minus, operandType, resultType}); + return resultType; + } + default: + LUAU_ASSERT(0); + } + + LUAU_UNREACHABLE(); + return singletonTypes.errorRecoveryType(); +} + +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprBinary* binary) +{ + TypeId leftType = check(scope, binary->left); + TypeId rightType = check(scope, binary->right); + switch (binary->op) + { + case AstExprBinary::Or: + { + addConstraint(scope, SubtypeConstraint{leftType, rightType}); + return leftType; + } + case AstExprBinary::Sub: + { + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, BinaryConstraint{AstExprBinary::Sub, leftType, rightType, resultType}); + return resultType; + } + default: + LUAU_ASSERT(0); + } + + LUAU_ASSERT(0); + return nullptr; +} + +TypeId ConstraintGraphBuilder::checkExprTable(NotNull scope, AstExprTable* expr) { TypeId ty = arena->addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(ty); @@ -515,6 +622,8 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) for (const AstExprTable::Item& item : expr->items) { TypeId itemTy = check(scope, item.value); + if (get(follow(itemTy))) + return ty; if (item.key) { @@ -542,47 +651,111 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) return ty; } -std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn) +ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(NotNull parent, AstExprFunction* fn) { - Scope2* innerScope = childScope(fn->body->location, parent); - TypePackId returnType = freshTypePack(innerScope); - innerScope->returnType = returnType; + Scope2* signatureScope = nullptr; + Scope2* bodyScope = nullptr; + TypePackId returnType = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + + // If we don't have any generics, we can save some memory and compute by not + // creating the signatureScope, which is only used to scope the declared + // generics properly. + if (hasGenerics) + { + NotNull signatureBorrow = childScope(fn->location, parent); + signatureScope = signatureBorrow.get(); + + // We need to assign returnType before creating bodyScope so that the + // return type gets propogated to bodyScope. + returnType = freshTypePack(signatureBorrow); + signatureScope->returnType = returnType; + + bodyScope = childScope(fn->body->location, signatureBorrow).get(); + + std::vector> genericDefinitions = createGenerics(signatureBorrow, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks); + + // We do not support default values on function generics, so we only + // care about the types involved. + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + signatureScope->typeBindings[name] = g.ty; + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + signatureScope->typePackBindings[name] = g.tp; + } + } + else + { + NotNull bodyBorrow = childScope(fn->body->location, parent); + bodyScope = bodyBorrow.get(); + + returnType = freshTypePack(bodyBorrow); + bodyBorrow->returnType = returnType; + + // To eliminate the need to branch on hasGenerics below, we say that the + // signature scope is the body scope when there is no real signature + // scope. + signatureScope = bodyScope; + } + + NotNull bodyBorrow = NotNull(bodyScope); + NotNull signatureBorrow = NotNull(signatureScope); if (fn->returnAnnotation) { - TypePackId annotatedRetType = resolveTypePack(innerScope, *fn->returnAnnotation); - addConstraint(innerScope, PackSubtypeConstraint{returnType, annotatedRetType}); + TypePackId annotatedRetType = resolveTypePack(signatureBorrow, *fn->returnAnnotation); + addConstraint(signatureBorrow, PackSubtypeConstraint{returnType, annotatedRetType}); } std::vector argTypes; for (AstLocal* local : fn->args) { - TypeId t = freshType(innerScope); + TypeId t = freshType(signatureBorrow); argTypes.push_back(t); - innerScope->bindings[local] = t; + signatureScope->bindings[local] = t; if (local->annotation) { - TypeId argAnnotation = resolveType(innerScope, local->annotation); - addConstraint(innerScope, SubtypeConstraint{t, argAnnotation}); + TypeId argAnnotation = resolveType(signatureBorrow, local->annotation); + addConstraint(signatureBorrow, SubtypeConstraint{t, argAnnotation}); } } // TODO: Vararg annotation. + // TODO: Preserve argument names in the function's type. FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + actualFunction.hasNoGenerics = !hasGenerics; + actualFunction.generics = std::move(genericTypes); + actualFunction.genericPacks = std::move(genericTypePacks); + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); astTypes[fn] = actualFunctionType; - return {actualFunctionType, innerScope}; + return { + /* signature */ actualFunctionType, + // Undo the workaround we made above: if there's no signature scope, + // don't report it. + /* signatureScope */ hasGenerics ? signatureScope : nullptr, + /* bodyScope */ bodyBorrow, + }; } -void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn) +void ConstraintGraphBuilder::checkFunctionBody(NotNull scope, AstExprFunction* fn) { - for (AstStat* stat : fn->body->body) - visit(scope, stat); + visitBlockWithoutChildScope(scope, fn->body); // If it is possible for execution to reach the end of the function, the return type must be compatible with () @@ -593,7 +766,7 @@ void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* f } } -TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) +TypeId ConstraintGraphBuilder::resolveType(NotNull scope, AstType* ty) { TypeId result = nullptr; @@ -636,29 +809,73 @@ TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) } else if (auto fn = ty->as()) { - // TODO: Generic functions. - // TODO: Scope (though it may not be needed). // TODO: Recursion limit. - TypePackId argTypes = resolveTypePack(scope, fn->argTypes); - TypePackId returnTypes = resolveTypePack(scope, fn->returnTypes); + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + Scope2* signatureScope = nullptr; - // TODO: Is this the right constructor to use? - result = arena->addType(FunctionTypeVar{argTypes, returnTypes}); + std::vector genericTypes; + std::vector genericTypePacks; - FunctionTypeVar* ftv = getMutable(result); - ftv->argNames.reserve(fn->argNames.size); + // If we don't have generics, we do not need to generate a child scope + // for the generic bindings to live on. + if (hasGenerics) + { + NotNull signatureBorrow = childScope(fn->location, scope); + signatureScope = signatureBorrow.get(); + + std::vector> genericDefinitions = createGenerics(signatureBorrow, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks); + + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + signatureBorrow->typeBindings[name] = g.ty; + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + signatureBorrow->typePackBindings[name] = g.tp; + } + } + else + { + // To eliminate the need to branch on hasGenerics below, we say that + // the signature scope is the parent scope if we don't have + // generics. + signatureScope = scope.get(); + } + + NotNull signatureBorrow(signatureScope); + + TypePackId argTypes = resolveTypePack(signatureBorrow, fn->argTypes); + TypePackId returnTypes = resolveTypePack(signatureBorrow, fn->returnTypes); + + // TODO: FunctionTypeVar needs a pointer to the scope so that we know + // how to quantify/instantiate it. + FunctionTypeVar ftv{argTypes, returnTypes}; + + // This replicates the behavior of the appropriate FunctionTypeVar + // constructors. + ftv.hasNoGenerics = !hasGenerics; + ftv.generics = std::move(genericTypes); + ftv.genericPacks = std::move(genericTypePacks); + + ftv.argNames.reserve(fn->argNames.size); for (const auto& el : fn->argNames) { if (el) { const auto& [name, location] = *el; - ftv->argNames.push_back(FunctionArgument{name.value, location}); + ftv.argNames.push_back(FunctionArgument{name.value, location}); } else { - ftv->argNames.push_back(std::nullopt); + ftv.argNames.push_back(std::nullopt); } } + + result = arena->addType(std::move(ftv)); } else if (auto tof = ty->as()) { @@ -710,7 +927,7 @@ TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* tp) +TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, AstTypePack* tp) { TypePackId result; if (auto expl = tp->as()) @@ -736,7 +953,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* t return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeList& list) +TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, const AstTypeList& list) { std::vector head; @@ -754,16 +971,108 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeL return arena->addTypePack(TypePack{head, tail}); } -void collectConstraints(std::vector>& result, Scope2* scope) +std::vector> ConstraintGraphBuilder::createGenerics(NotNull scope, AstArray generics) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypeId genericTy = arena->addType(GenericTypeVar{scope, generic.name.value}); + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveType(scope, generic.defaultValue); + + result.push_back({generic.name.value, GenericTypeDefinition{ + genericTy, + defaultTy, + }}); + } + + return result; +} + +std::vector> ConstraintGraphBuilder::createGenericPacks( + NotNull scope, AstArray generics) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope, generic.name.value}}); + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveTypePack(scope, generic.defaultValue); + + result.push_back({generic.name.value, GenericTypePackDefinition{ + genericTy, + defaultTy, + }}); + } + + return result; +} + +TypeId ConstraintGraphBuilder::flattenPack(NotNull scope, Location location, TypePackId tp) +{ + if (auto f = first(tp)) + return *f; + + TypeId typeResult = freshType(scope); + TypePack onePack{{typeResult}, freshTypePack(scope)}; + TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); + + addConstraint(scope, PackSubtypeConstraint{tp, oneTypePack}); + + return typeResult; +} + +void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) +{ + errors.push_back(TypeError{location, moduleName, std::move(err)}); +} + +void ConstraintGraphBuilder::reportCodeTooComplex(Location location) +{ + errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); +} + +struct GlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull arena; + + GlobalPrepopulator(NotNull globalScope, NotNull arena) + : globalScope(globalScope) + , arena(arena) + { + } + + bool visit(AstStatFunction* function) override + { + if (AstExprGlobal* g = function->name->as()) + globalScope->bindings[g->name] = arena->addType(BlockedTypeVar{}); + + return true; + } +}; + +void ConstraintGraphBuilder::prepopulateGlobalScope(NotNull globalScope, AstStatBlock* program) +{ + GlobalPrepopulator gp{NotNull{globalScope}, arena}; + + program->visit(&gp); +} + +void collectConstraints(std::vector>& result, NotNull scope) { for (const auto& c : scope->constraints) result.push_back(NotNull{c.get()}); - for (Scope2* child : scope->children) + for (NotNull child : scope->children) collectConstraints(result, child); } -std::vector> collectConstraints(Scope2* rootScope) +std::vector> collectConstraints(NotNull rootScope) { std::vector> result; collectConstraints(result, rootScope); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 9e35523..077a4e2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -13,7 +13,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { -[[maybe_unused]] static void dumpBindings(Scope2* scope, ToStringOptions& opts) +[[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) { for (const auto& [k, v] : scope->bindings) { @@ -22,22 +22,22 @@ namespace Luau printf("\t%s : %s\n", k.c_str(), d.name.c_str()); } - for (Scope2* child : scope->children) + for (NotNull child : scope->children) dumpBindings(child, opts); } -static void dumpConstraints(Scope2* scope, ToStringOptions& opts) +static void dumpConstraints(NotNull scope, ToStringOptions& opts) { for (const ConstraintPtr& c : scope->constraints) { printf("\t%s\n", toString(*c, opts).c_str()); } - for (Scope2* child : scope->children) + for (NotNull child : scope->children) dumpConstraints(child, opts); } -void dump(Scope2* rootScope, ToStringOptions& opts) +void dump(NotNull rootScope, ToStringOptions& opts) { printf("constraints:\n"); dumpConstraints(rootScope, opts); @@ -55,7 +55,7 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } -ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope) +ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope) : arena(arena) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) @@ -180,6 +180,10 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*gc, constraint, force); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint, force); + else if (auto bc = get(*constraint)) + success = tryDispatch(*bc, constraint, force); else if (auto nc = get(*constraint)) success = tryDispatch(*nc, constraint); else @@ -246,12 +250,65 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - unify(c.subType, *instantiated); + if (isBlocked(c.subType)) + asMutable(c.subType)->ty.emplace(*instantiated); + else + unify(c.subType, *instantiated); + unblock(c.subType); return true; } +bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force) +{ + TypeId operandType = follow(c.operandType); + + if (isBlocked(operandType)) + return block(operandType, constraint); + + if (get(operandType)) + return block(operandType, constraint); + + LUAU_ASSERT(get(c.resultType)); + + if (isNumber(operandType) || get(operandType) || get(operandType)) + { + asMutable(c.resultType)->ty.emplace(c.operandType); + return true; + } + + LUAU_ASSERT(0); // TODO metatable handling + return false; +} + +bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force) +{ + TypeId leftType = follow(c.leftType); + TypeId rightType = follow(c.rightType); + + if (isBlocked(leftType) || isBlocked(rightType)) + { + block(leftType, constraint); + block(rightType, constraint); + return false; + } + + if (isNumber(leftType)) + { + unify(leftType, rightType); + asMutable(c.resultType)->ty.emplace(leftType); + return true; + } + + if (get(leftType) && !force) + return block(leftType, constraint); + + // TODO metatables, classes + + return true; +} + bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) { if (isBlocked(c.namedType)) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 2407e3e..1b5275f 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(LuauCheckLenMT) + namespace Luau { @@ -202,7 +204,13 @@ declare function unpack(tab: {V}, i: number?, j: number?): ...V std::string getBuiltinDefinitionSource() { - return kBuiltinDefinitionLuaSrc; + std::string result = kBuiltinDefinitionLuaSrc; + + // TODO: move this into kBuiltinDefinitionLuaSrc + if (FFlag::LuauCheckLenMT) + result += "declare function rawlen(obj: {[K]: V} | string): number\n"; + + return result; } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 85c5dbc..4cfaa11 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -787,14 +787,32 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } +NotNull Frontend::getGlobalScope2() +{ + if (!globalScope2) + { + const SingletonTypes& singletonTypes = getSingletonTypes(); + + globalScope2 = std::make_unique(); + globalScope2->typeBindings["nil"] = singletonTypes.nilType; + globalScope2->typeBindings["number"] = singletonTypes.numberType; + globalScope2->typeBindings["string"] = singletonTypes.stringType; + globalScope2->typeBindings["boolean"] = singletonTypes.booleanType; + globalScope2->typeBindings["thread"] = singletonTypes.threadType; + } + + return NotNull(globalScope2.get()); +} + ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope) { ModulePtr result = std::make_shared(); - ConstraintGraphBuilder cgb{&result->internalTypes}; + ConstraintGraphBuilder cgb{sourceModule.name, &result->internalTypes, NotNull(&iceHandler), getGlobalScope2()}; cgb.visit(sourceModule.root); + result->errors = std::move(cgb.errors); - ConstraintSolver cs{&result->internalTypes, cgb.rootScope}; + ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)}; cs.run(); result->scope2s = std::move(cgb.scopes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index d36665e..8ce7f74 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -5,7 +5,6 @@ #include #include "Luau/Clone.h" -#include "Luau/Substitution.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" @@ -16,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); -LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false); LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -25,238 +23,33 @@ namespace Luau namespace { -struct Replacer : Substitution +struct Replacer { + TypeArena* arena; TypeId sourceType; TypeId replacedType; - DenseHashMap replacedTypes{nullptr}; - DenseHashMap replacedPacks{nullptr}; + DenseHashMap newTypes; Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) - : Substitution(TxnLog::empty(), arena) + : arena(arena) , sourceType(sourceType) , replacedType(replacedType) + , newTypes(nullptr) { } - bool isDirty(TypeId ty) override - { - if (!sourceType) - return false; - - auto vecHasSourceType = [sourceType = sourceType](const auto& vec) { - return end(vec) != std::find(begin(vec), end(vec), sourceType); - }; - - // Walk every kind of TypeVar and find pointers to sourceType - if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return vecHasSourceType(t->parts); - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - { - if (vecHasSourceType(t->generics)) - return true; - - return false; - } - else if (auto t = get(ty)) - { - if (t->boundTo) - return *t->boundTo == sourceType; - - for (const auto& [_name, prop] : t->props) - { - if (prop.type == sourceType) - return true; - } - - if (auto indexer = t->indexer) - { - if (indexer->indexType == sourceType || indexer->indexResultType == sourceType) - return true; - } - - if (vecHasSourceType(t->instantiatedTypeParams)) - return true; - - return false; - } - else if (auto t = get(ty)) - return t->table == sourceType || t->metatable == sourceType; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return vecHasSourceType(t->options); - else if (auto t = get(ty)) - return vecHasSourceType(t->parts); - else if (auto t = get(ty)) - return false; - - LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type"); - LUAU_UNREACHABLE(); - } - - bool isDirty(TypePackId tp) override - { - if (auto it = replacedPacks.find(tp)) - return false; - - if (auto pack = get(tp)) - { - for (TypeId ty : pack->head) - { - if (ty == sourceType) - return true; - } - return false; - } - else if (auto vtp = get(tp)) - return vtp->ty == sourceType; - else - return false; - } - - TypeId clean(TypeId ty) override - { - LUAU_ASSERT(sourceType && replacedType); - - // Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType - // Before returning, memoize the result for later use. - - // Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This - // function returns the identity for things like primitives. - TypeId res = clone(ty); - - if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - for (TypeId& part : t->parts) - { - if (part == sourceType) - part = replacedType; - } - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - // The constituent typepacks are cleaned separately. We just need to walk the generics array. - for (TypeId& g : t->generics) - { - if (g == sourceType) - g = replacedType; - } - } - else if (auto t = getMutable(res)) - { - for (auto& [_key, prop] : t->props) - { - if (prop.type == sourceType) - prop.type = replacedType; - } - } - else if (auto t = getMutable(res)) - { - if (t->table == sourceType) - t->table = replacedType; - if (t->metatable == sourceType) - t->table = replacedType; - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - for (TypeId& option : t->options) - { - if (option == sourceType) - option = replacedType; - } - } - else if (auto t = getMutable(res)) - { - for (TypeId& part : t->parts) - { - if (part == sourceType) - part = replacedType; - } - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else - LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type"); - - replacedTypes[ty] = res; - return res; - } - - TypePackId clean(TypePackId tp) override - { - TypePackId res = clone(tp); - - if (auto pack = getMutable(res)) - { - for (TypeId& type : pack->head) - { - if (type == sourceType) - type = replacedType; - } - } - else if (auto vtp = getMutable(res)) - { - if (vtp->ty == sourceType) - vtp->ty = replacedType; - } - - replacedPacks[tp] = res; - return res; - } - TypeId smartClone(TypeId t) { - if (FFlag::LuauReplaceReplacer) - { - // The new smartClone is just a memoized clone() - // TODO: Remove the Substitution base class and all other methods from this struct. - // Add DenseHashMap newTypes; - t = log->follow(t); - TypeId* res = newTypes.find(t); - if (res) - return *res; - - TypeId result = shallowClone(t, *arena, TxnLog::empty()); - newTypes[t] = result; - newTypes[result] = result; - - return result; - } - else - { - std::optional res = replace(t); - LUAU_ASSERT(res.has_value()); // TODO think about this - if (*res == t) - return clone(t); + t = follow(t); + TypeId* res = newTypes.find(t); + if (res) return *res; - } + + TypeId result = shallowClone(t, *arena, TxnLog::empty()); + newTypes[t] = result; + newTypes[result] = result; + + return result; } }; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 40e14c6..294c479 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -8,6 +8,7 @@ #include "Luau/VisitTypeVar.h" LUAU_FASTFLAG(LuauAlwaysQuantify); +LUAU_FASTFLAG(DebugLuauSharedSelf) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) @@ -158,24 +159,61 @@ struct Quantifier final : TypeVarOnceVisitor void quantify(TypeId ty, TypeLevel level) { - Quantifier q{level}; - q.traverse(ty); - - FunctionTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - if (FFlag::LuauAlwaysQuantify) + if (FFlag::DebugLuauSharedSelf) { - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + ty = follow(ty); + + if (auto ttv = getTableType(ty); ttv && ttv->selfTy) + { + Quantifier selfQ{level}; + selfQ.traverse(*ttv->selfTy); + + Quantifier q{level}; + q.traverse(ty); + + for (const auto& [_, prop] : ttv->props) + { + auto ftv = getMutable(follow(prop.type)); + if (!ftv || !ftv->hasSelf) + continue; + + if (Luau::first(ftv->argTypes) == ttv->selfTy) + { + ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end()); + } + } + } + else if (auto ftv = getMutable(ty)) + { + Quantifier q{level}; + q.traverse(ty); + + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; + } } else { - ftv->generics = q.generics; - ftv->genericPacks = q.genericPacks; - } + Quantifier q{level}; + q.traverse(ty); - if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) - ftv->hasNoGenerics = true; + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + if (FFlag::LuauAlwaysQuantify) + { + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } + else + { + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; + } + } } void quantify(TypeId ty, Scope2* scope) @@ -206,8 +244,8 @@ struct PureQuantifier : Substitution std::vector insertedGenerics; std::vector insertedGenericPacks; - PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope) - : Substitution(log, arena) + PureQuantifier(TypeArena* arena, Scope2* scope) + : Substitution(TxnLog::empty(), arena) , scope(scope) { } @@ -286,7 +324,7 @@ struct PureQuantifier : Substitution TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) { - PureQuantifier quantifier{TxnLog::empty(), arena, scope}; + PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); LUAU_ASSERT(result); @@ -294,8 +332,7 @@ TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) LUAU_ASSERT(ftv); ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); - - // TODO: Set hasNoGenerics. + ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty(); return *result; } diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 66aaee1..247a9dd 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -153,4 +153,19 @@ std::optional Scope2::lookupTypeBinding(const Name& name) return std::nullopt; } +std::optional Scope2::lookupTypePackBinding(const Name& name) +{ + Scope2* s = this; + while (s) + { + auto it = s->typePackBindings.find(name); + if (it != s->typePackBindings.end()) + return it->second; + + s = s->parent; + } + + return std::nullopt; +} + } // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index eb7b9cd..fe940d5 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1367,51 +1367,74 @@ std::string generateName(size_t i) return n; } -std::string toString(const Constraint& c, ToStringOptions& opts) +std::string toString(const Constraint& constraint, ToStringOptions& opts) { - if (const SubtypeConstraint* sc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(sc->subType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(sc->superType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " <: " + superStr.name; - } - else if (const PackSubtypeConstraint* psc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(psc->subPack, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(psc->superPack, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " <: " + superStr.name; - } - else if (const GeneralizationConstraint* gc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(gc->generalizedType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(gc->sourceType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " ~ gen " + superStr.name; - } - else if (const InstantiationConstraint* ic = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(ic->subType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(ic->superType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " ~ inst " + superStr.name; - } - else if (const NameConstraint* nc = Luau::get(c)) - { - ToStringResult namedStr = toStringDetailed(nc->namedType, opts); - opts.nameMap = std::move(namedStr.nameMap); - return "@name(" + namedStr.name + ") = " + nc->name; - } - else - { - LUAU_ASSERT(false); - return ""; - } + auto go = [&opts](auto&& c) { + using T = std::decay_t; + + if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subPack, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superPack, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.generalizedType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.sourceType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ gen " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ inst " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult resultStr = toStringDetailed(c.resultType, opts); + opts.nameMap = std::move(resultStr.nameMap); + ToStringResult operandStr = toStringDetailed(c.operandType, opts); + opts.nameMap = std::move(operandStr.nameMap); + + return resultStr.name + " ~ Unary<" + toString(c.op) + ", " + operandStr.name + ">"; + } + else if constexpr (std::is_same_v) + { + ToStringResult resultStr = toStringDetailed(c.resultType); + opts.nameMap = std::move(resultStr.nameMap); + ToStringResult leftStr = toStringDetailed(c.leftType); + opts.nameMap = std::move(leftStr.nameMap); + ToStringResult rightStr = toStringDetailed(c.rightType); + opts.nameMap = std::move(rightStr.nameMap); + + return resultStr.name + " ~ Binary<" + toString(c.op) + ", " + leftStr.name + ", " + rightStr.name + ">"; + } + else if constexpr (std::is_same_v) + { + ToStringResult namedStr = toStringDetailed(c.namedType, opts); + opts.nameMap = std::move(namedStr.nameMap); + return "@name(" + namedStr.name + ") = " + c.name; + } + else + static_assert(always_false_v, "Non-exhaustive constraint switch"); + }; + + return visit(go, constraint.c); } std::string dump(const Constraint& c) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 63e5800..30e498a 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -6,8 +6,10 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" +#include "Luau/Instantiation.h" #include "Luau/Normalize.h" -#include "Luau/ConstraintGraphBuilder.h" // FIXME move Scope2 into its own header +#include "Luau/TxnLog.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/ToString.h" @@ -19,10 +21,12 @@ struct TypeChecker2 : public AstVisitor const SourceModule* sourceModule; Module* module; InternalErrorReporter ice; // FIXME accept a pointer from Frontend + SingletonTypes& singletonTypes; TypeChecker2(const SourceModule* sourceModule, Module* module) : sourceModule(sourceModule) , module(module) + , singletonTypes(getSingletonTypes()) { } @@ -30,16 +34,30 @@ struct TypeChecker2 : public AstVisitor TypePackId lookupPack(AstExpr* expr) { + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. TypePackId* tp = module->astTypePacks.find(expr); - LUAU_ASSERT(tp); - return follow(*tp); + if (tp) + return follow(*tp); + else + return singletonTypes.anyTypePack; } TypeId lookupType(AstExpr* expr) { + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. TypeId* ty = module->astTypes.find(expr); - LUAU_ASSERT(ty); - return follow(*ty); + if (ty) + return follow(*ty); + + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return flattenPack(*tp); + + return singletonTypes.anyType; } TypeId lookupAnnotation(AstType* annotation) @@ -78,7 +96,7 @@ struct TypeChecker2 : public AstVisitor bestLocation = scopeBounds; } } - else + else if (scopeBounds.begin > location.end) { // TODO: Is this sound? This relies on the fact that scopes are inserted // into the scope list in the order that they appear in the AST. @@ -147,16 +165,14 @@ struct TypeChecker2 : public AstVisitor for (size_t i = 0; i < count; ++i) { AstExpr* lhs = assign->vars.data[i]; - TypeId* lhsType = module->astTypes.find(lhs); - LUAU_ASSERT(lhsType); + TypeId lhsType = lookupType(lhs); AstExpr* rhs = assign->values.data[i]; - TypeId* rhsType = module->astTypes.find(rhs); - LUAU_ASSERT(rhsType); + TypeId rhsType = lookupType(rhs); - if (!isSubtype(*rhsType, *lhsType, ice)) + if (!isSubtype(rhsType, lhsType, ice)) { - reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location); + reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } } @@ -181,7 +197,7 @@ struct TypeChecker2 : public AstVisitor if (!ok) { for (const TypeError& e : u.errors) - module->errors.push_back(e); + reportError(e); } return true; @@ -189,10 +205,14 @@ struct TypeChecker2 : public AstVisitor bool visit(AstExprCall* call) override { + TypeArena arena; + Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}}; + TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); + TypeId instantiatedFunctionType = instantiation.substitute(functionType).value_or(nullptr); + LUAU_ASSERT(functionType); - TypeArena arena; TypePack args; for (const auto& arg : call->args) { @@ -204,7 +224,7 @@ struct TypeChecker2 : public AstVisitor TypePackId argsTp = arena.addTypePack(args); FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(expectedType, functionType, ice)) + if (!isSubtype(expectedType, instantiatedFunctionType, ice)) { unfreeze(module->interfaceTypes); CloneState cloneState; @@ -252,16 +272,12 @@ struct TypeChecker2 : public AstVisitor // leftType must have a property called indexName->index - if (auto ttv = get(leftType)) + std::optional t = findTablePropertyRespectingMeta(module->errors, leftType, indexName->index.value, indexName->location); + if (t) { - auto it = ttv->props.find(indexName->index.value); - if (it == ttv->props.end()) + if (!isSubtype(resultType, *t, ice)) { - reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); - } - else if (!isSubtype(resultType, it->second.type, ice)) - { - reportError(TypeMismatch{resultType, it->second.type}, indexName->location); + reportError(TypeMismatch{resultType, *t}, indexName->location); } } else @@ -277,7 +293,7 @@ struct TypeChecker2 : public AstVisitor TypeId actualType = lookupType(number); TypeId numberType = getSingletonTypes().numberType; - if (!isSubtype(actualType, numberType, ice)) + if (!isSubtype(numberType, actualType, ice)) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -290,7 +306,7 @@ struct TypeChecker2 : public AstVisitor TypeId actualType = lookupType(string); TypeId stringType = getSingletonTypes().stringType; - if (!isSubtype(actualType, stringType, ice)) + if (!isSubtype(stringType, actualType, ice)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -298,6 +314,41 @@ struct TypeChecker2 : public AstVisitor return true; } + /** Extract a TypeId for the first type of the provided pack. + * + * Note that this may require modifying some types. I hope this doesn't cause problems! + */ + TypeId flattenPack(TypePackId pack) + { + pack = follow(pack); + + while (auto tp = get(pack)) + { + if (tp->head.empty() && tp->tail) + pack = *tp->tail; + } + + if (auto ty = first(pack)) + return *ty; + else if (auto vtp = get(pack)) + return vtp->ty; + else if (auto ftp = get(pack)) + { + TypeId result = module->internalTypes.addType(FreeTypeVar{ftp->scope}); + TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); + + TypePack& resultPack = asMutable(pack)->ty.emplace(); + resultPack.head.assign(1, result); + resultPack.tail = freeTail; + + return result; + } + else if (get(pack)) + return singletonTypes.errorRecoveryType(); + else + ice.ice("flattenPack got a weird pack!"); + } + bool visit(AstType* ty) override { return true; @@ -321,6 +372,11 @@ struct TypeChecker2 : public AstVisitor { module->errors.emplace_back(location, sourceModule->name, std::move(data)); } + + void reportError(TypeError e) + { + module->errors.emplace_back(std::move(e)); + } }; void check(const SourceModule& sourceModule, Module* module) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 44635e8..d9486a4 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -38,11 +38,13 @@ LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) +LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) +LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false) namespace Luau { @@ -238,7 +240,7 @@ static bool isMetamethod(const Name& name) { return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode"; + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; } size_t HashBoolNamePair::operator()(const std::pair& pair) const @@ -327,10 +329,19 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule->timeout = true; } + if (FFlag::DebugLuauSharedSelf) + { + for (auto& [ty, scope] : deferredQuantification) + Luau::quantify(ty, scope->level); + deferredQuantification.clear(); + } + if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else + { moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); + } for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); @@ -537,18 +548,43 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } else if (auto fun = (*protoIter)->as()) { + std::optional selfType; std::optional expectedType; - if (!fun->func->self) + if (FFlag::DebugLuauSharedSelf) { if (auto name = fun->name->as()) { - TypeId exprTy = checkExpr(scope, *name->expr).type; - expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + TypeId baseTy = checkExpr(scope, *name->expr).type; + tablify(baseTy); + + if (!fun->func->self) + expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, false); + else if (auto ttv = getMutableTableType(baseTy)) + { + if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) + { + ttv->selfTy = anyIfNonstrict(freshType(ttv->level)); + deferredQuantification.push_back({baseTy, scope}); + } + + selfType = ttv->selfTy; + } + } + } + else + { + if (!fun->func->self) + { + if (auto name = fun->name->as()) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + } } } - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, expectedType); + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, selfType, expectedType); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -560,7 +596,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } else if (auto fun = (*protoIter)->as()) { - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt, std::nullopt); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -2076,7 +2112,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) { - auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); + auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, std::nullopt, expectedType); checkFunctionBody(funScope, funTy, expr); @@ -2296,6 +2332,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); state.log.commit(); + reportErrors(state.errors); + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) retType = errorRecoveryType(retType); @@ -2322,6 +2360,23 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp DenseHashSet seen{nullptr}; + if (FFlag::LuauCheckLenMT && typeCouldHaveMetatable(operandType)) + { + if (auto fnt = findMetatableEntry(operandType, "__len", expr.location)) + { + TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); + TypePackId arguments = addTypePack({operandType}); + TypePackId retTypePack = addTypePack({numberType}); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); + + Unifier state = mkUnifier(expr.location); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); + state.log.commit(); + + reportErrors(state.errors); + } + } + if (!hasLength(operandType, seen, &recursionCount)) reportError(TypeError{expr.location, NotATable{operandType}}); @@ -2530,17 +2585,15 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (!matches) { reportError( expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } - if (leftMetatable) { std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); @@ -3139,8 +3192,8 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T // `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` // to get type `(X) -> X`, then we quantify the free types to get the final // generic type `(a) -> a`. -std::pair TypeChecker::checkFunctionSignature( - const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalName, std::optional expectedType) +std::pair TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, + std::optional originalName, std::optional selfType, std::optional expectedType) { ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); @@ -3241,12 +3294,25 @@ std::pair TypeChecker::checkFunctionSignature( funScope->returnType = retPack; - if (expr.self) + if (FFlag::DebugLuauSharedSelf) { - // TODO: generic self types: CLI-39906 - TypeId selfType = anyIfNonstrict(freshType(funScope)); - funScope->bindings[expr.self] = {selfType, expr.self->location}; - argTypes.push_back(selfType); + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope)); + funScope->bindings[expr.self] = {selfTy, expr.self->location}; + argTypes.push_back(selfTy); + } + } + else + { + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfType = anyIfNonstrict(freshType(funScope)); + funScope->bindings[expr.self] = {selfType, expr.self->location}; + argTypes.push_back(selfType); + } } // Prepare expected argument type iterators if we have an expected function type @@ -4457,25 +4523,43 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location { ty = follow(ty); - const FunctionTypeVar* ftv = get(ty); - - if (FFlag::LuauAlwaysQuantify) + if (FFlag::DebugLuauSharedSelf) { - if (ftv) + if (auto ftv = get(ty)) Luau::quantify(ty, scope->level); + else if (auto ttv = getTableType(ty); ttv && ttv->selfTy) + Luau::quantify(ty, scope->level); + + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } } else { - if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) - Luau::quantify(ty, scope->level); - } + const FunctionTypeVar* ftv = get(ty); - if (FFlag::LuauLowerBoundsCalculation && ftv) - { - auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); - if (!ok) - reportError(location, NormalizationTooComplex{}); - return t; + if (FFlag::LuauAlwaysQuantify) + { + if (ftv) + Luau::quantify(ty, scope->level); + } + else + { + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + } + + if (FFlag::LuauLowerBoundsCalculation && ftv) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } } return ty; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6147e11..0792a35 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -740,7 +740,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I std::optional unificationTooComplex; std::optional firstFailedOption; - // T <: A & B if A <: T and B <: T + // T <: A & B if T <: A and T <: B for (TypeId type : uv->parts) { Unifier innerState = makeChildUnifier(); @@ -765,7 +765,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) { - // A & B <: T if T <: A or T <: B + // A & B <: T if A <: T or B <: T bool found = false; std::optional unificationTooComplex; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 95bce3e..70c9255 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -5,6 +5,9 @@ #include +#include +#include + // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. @@ -14,6 +17,18 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) +LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) + +bool lua_telemetry_parsed_named_non_function_type = false; + +LUAU_FASTFLAGVARIABLE(LuauErrorParseIntegerIssues, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) + +bool lua_telemetry_parsed_out_of_range_bin_integer = false; +bool lua_telemetry_parsed_out_of_range_hex_integer = false; +bool lua_telemetry_parsed_double_prefix_hex_integer = false; + namespace Luau { @@ -1330,7 +1345,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - bool monomorphic = lexer.current().type != '<'; + bool forceFunctionType = lexer.current().type == '<'; Lexeme begin = lexer.current(); @@ -1355,21 +1370,33 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray paramTypes = copy(params); + if (FFlag::LuauFixNamedFunctionParse && !names.empty()) + forceFunctionType = true; + bool returnTypeIntroducer = FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && monomorphic && + if (params.size() == 1 && !varargAnnotation && !forceFunctionType && (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) { + if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) + lua_telemetry_parsed_named_non_function_type = true; + if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; else return {params[0], {}}; } - if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack) + if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && !forceFunctionType && + allowPack) + { + if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) + lua_telemetry_parsed_named_non_function_type = true; + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; + } AstArray> paramNames = copy(names); @@ -2010,7 +2037,63 @@ AstExpr* Parser::parseAssertionExpr() return expr; } -static bool parseNumber(double& result, const char* data) +static const char* parseInteger(double& result, const char* data, int base) +{ + char* end = nullptr; + unsigned long long value = strtoull(data, &end, base); + + if (value == ULLONG_MAX && errno == ERANGE) + { + // 'errno' might have been set before we called 'strtoull', but we don't want the overhead of resetting a TLS variable on each call + // so we only reset it when we get a result that might be an out-of-range error and parse again to make sure + errno = 0; + value = strtoull(data, &end, base); + + if (errno == ERANGE) + { + if (DFFlag::LuaReportParseIntegerIssues) + { + if (base == 2) + lua_telemetry_parsed_out_of_range_bin_integer = true; + else + lua_telemetry_parsed_out_of_range_hex_integer = true; + } + + if (FFlag::LuauErrorParseIntegerIssues) + return "Integer number value is out of range"; + } + } + + result = double(value); + return *end == 0 ? nullptr : "Malformed number"; +} + +static const char* parseNumber(double& result, const char* data) +{ + // binary literal + if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) + return parseInteger(result, data + 2, 2); + + // hexadecimal literal + if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) + { + if (DFFlag::LuaReportParseIntegerIssues && data[2] == '0' && (data[3] == 'x' || data[3] == 'X')) + lua_telemetry_parsed_double_prefix_hex_integer = true; + + if (FFlag::LuauErrorParseIntegerIssues) + return parseInteger(result, data, 16); // keep prefix, it's handled by 'strtoull' + else + return parseInteger(result, data + 2, 16); + } + + char* end = nullptr; + double value = strtod(data, &end); + + result = value; + return *end == 0 ? nullptr : "Malformed number"; +} + +static bool parseNumber_DEPRECATED(double& result, const char* data) { // binary literal if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) @@ -2080,18 +2163,37 @@ AstExpr* Parser::parseSimpleExpr() scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); } - double value = 0; - if (parseNumber(value, scratchData.c_str())) + if (DFFlag::LuaReportParseIntegerIssues || FFlag::LuauErrorParseIntegerIssues) { - nextLexeme(); + double value = 0; + if (const char* error = parseNumber(value, scratchData.c_str())) + { + nextLexeme(); - return allocator.alloc(start, value); + return reportExprError(start, {}, "%s", error); + } + else + { + nextLexeme(); + + return allocator.alloc(start, value); + } } else { - nextLexeme(); + double value = 0; + if (parseNumber_DEPRECATED(value, scratchData.c_str())) + { + nextLexeme(); - return reportExprError(start, {}, "Malformed number"); + return allocator.alloc(start, value); + } + else + { + nextLexeme(); + + return reportExprError(start, {}, "Malformed number"); + } } } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 218bb5d..0cb7e1d 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -276,7 +276,7 @@ enum LuauOpcode // FORGLOOP: adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue // A: target register; generic for loops assume a register layout [generator, state, index, variables...] // D: jump offset (-32768..32767) - // AUX: variable count (1..255) + // AUX: variable count (1..255) in the low 8 bits, high bit indicates whether to use ipairs-style traversal in the fast path // loop variables are adjusted by calling generator(state, index) and expecting it to return a tuple that's copied to the user variables // the first variable is then copied into index; generator/state are immutable, index isn't visible to user code LOP_FORGLOOP, @@ -490,6 +490,9 @@ enum LuauBuiltinFunction // select(_, ...) LBF_SELECT_VARARG, + + // rawlen + LBF_RAWLEN, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index ff75311..6bd24b6 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,8 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauCompileRawlen, false) + namespace Luau { namespace Compile @@ -58,6 +60,8 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) return LBF_RAWGET; if (builtin.isGlobal("rawequal")) return LBF_RAWEQUAL; + if (FFlag::LuauCompileRawlen && builtin.isGlobal("rawlen")) + return LBF_RAWLEN; if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 301cf25..5e2669b 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1302,20 +1302,22 @@ void BytecodeBuilder::validate() const case LOP_FORNPREP: case LOP_FORNLOOP: - VREG(LUAU_INSN_A(insn) + 2); // for loop protocol: A, A+1, A+2 are used for iteration + // for loop protocol: A, A+1, A+2 are used for iteration + VREG(LUAU_INSN_A(insn) + 2); VJUMP(LUAU_INSN_D(insn)); break; case LOP_FORGPREP: - VREG(LUAU_INSN_A(insn) + 2 + 1); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VREG(LUAU_INSN_A(insn) + 2 + 1); VJUMP(LUAU_INSN_D(insn)); break; case LOP_FORGLOOP: - VREG( - LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VREG(LUAU_INSN_A(insn) + 2 + uint8_t(insns[i + 1])); VJUMP(LUAU_INSN_D(insn)); - LUAU_ASSERT(insns[i + 1] >= 1); + LUAU_ASSERT(uint8_t(insns[i + 1]) >= 1); break; case LOP_FORGPREP_INEXT: @@ -1679,7 +1681,8 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_FORGLOOP: - formatAppend(result, "FORGLOOP R%d L%d %d\n", LUAU_INSN_A(insn), targetLabel, *code++); + formatAppend(result, "FORGLOOP R%d L%d %d%s\n", LUAU_INSN_A(insn), targetLabel, uint8_t(*code), int(*code) < 0 ? " [inext]" : ""); + code++; break; case LOP_FORGPREP_INEXT: diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index e732256..d7c8155 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -23,6 +23,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) +LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false) + namespace Luau { @@ -2665,7 +2667,7 @@ struct Compiler if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) { skipOp = LOP_FORGPREP_INEXT; - loopOp = LOP_FORGLOOP_INEXT; + loopOp = FFlag::LuauCompileNoIpairs ? LOP_FORGLOOP : LOP_FORGLOOP_INEXT; } else if (builtin.isGlobal("pairs")) // for .. in pairs(t) { @@ -2709,8 +2711,16 @@ struct Compiler bytecode.emitAD(loopOp, regs, 0); + if (FFlag::LuauCompileNoIpairs) + { + // TODO: remove loopOp as it's a constant now + LUAU_ASSERT(loopOp == LOP_FORGLOOP); + + // FORGLOOP uses aux to encode variable count and fast path flag for ipairs traversal in the high bit + bytecode.emitAux((skipOp == LOP_FORGPREP_INEXT ? 0x80000000 : 0) | uint32_t(stat->vars.size)); + } // note: FORGLOOP needs variable count encoded in AUX field, other loop instructions assume a fixed variable count - if (loopOp == LOP_FORGLOOP) + else if (loopOp == LOP_FORGLOOP) bytecode.emitAux(uint32_t(stat->vars.size)); size_t endLabel = bytecode.emitLabel(); @@ -3341,7 +3351,7 @@ struct Compiler std::vector upvals; }; - struct ReturnVisitor: AstVisitor + struct ReturnVisitor : AstVisitor { Compiler* self; bool returnsOne = true; diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index f791761..4fc5033 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -11,6 +11,8 @@ #include #include +LUAU_FASTFLAG(LuauLenTM) + static void writestring(const char* s, size_t l) { fwrite(s, 1, l, stdout); @@ -178,6 +180,18 @@ static int luaB_rawset(lua_State* L) return 1; } +static int luaB_rawlen(lua_State* L) +{ + if (!FFlag::LuauLenTM) + luaL_error(L, "'rawlen' is not available"); + + int tt = lua_type(L, 1); + luaL_argcheck(L, tt == LUA_TTABLE || tt == LUA_TSTRING, 1, "table or string expected"); + int len = lua_objlen(L, 1); + lua_pushinteger(L, len); + return 1; +} + static int luaB_gcinfo(lua_State* L) { lua_pushinteger(L, lua_gc(L, LUA_GCCOUNT, 0)); @@ -428,6 +442,7 @@ static const luaL_Reg base_funcs[] = { {"rawequal", luaB_rawequal}, {"rawget", luaB_rawget}, {"rawset", luaB_rawset}, + {"rawlen", luaB_rawlen}, {"select", luaB_select}, {"setfenv", luaB_setfenv}, {"setmetatable", luaB_setmetatable}, diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index deaf140..e98660a 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1117,6 +1117,27 @@ static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, Stk return -1; } +static int luauF_rawlen(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + if (ttistable(arg0)) + { + Table* h = hvalue(arg0); + setnvalue(res, double(luaH_getn(h))); + return 1; + } + else if (ttisstring(arg0)) + { + TString* ts = tsvalue(arg0); + setnvalue(res, double(ts->len)); + return 1; + } + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1188,4 +1209,6 @@ luau_FastFunction luauF_table[256] = { luauF_countrz, luauF_select, + + luauF_rawlen, }; diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h index cf905e9..75bb8dc 100644 --- a/VM/src/ldebug.h +++ b/VM/src/ldebug.h @@ -19,6 +19,7 @@ LUAI_FUNC l_noret luaG_concaterror(lua_State* L, StkId p1, StkId p2); LUAI_FUNC l_noret luaG_aritherror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); LUAI_FUNC l_noret luaG_ordererror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); LUAI_FUNC l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2); + LUAI_FUNC LUA_PRINTF_ATTR(2, 3) l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...); LUAI_FUNC void luaG_pusherror(lua_State* L, const char* error); diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index e7df4e5..49982b2 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -39,6 +39,7 @@ const char* const luaT_eventname[] = { "__namecall", "__call", "__iter", + "__len", "__eq", @@ -52,7 +53,6 @@ const char* const luaT_eventname[] = { "__unm", - "__len", "__lt", "__le", "__concat", diff --git a/VM/src/ltm.h b/VM/src/ltm.h index a522394..e11ddb3 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -18,6 +18,7 @@ typedef enum TM_NAMECALL, TM_CALL, TM_ITER, + TM_LEN, TM_EQ, /* last tag method with `fast' access */ @@ -31,7 +32,6 @@ typedef enum TM_UNM, - TM_LEN, TM_LT, TM_LE, TM_CONCAT, diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index e0a9647..85829ca 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauLenTM, false) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -2082,13 +2084,25 @@ static void luau_execute(lua_State* L) // fast-path #1: tables if (ttistable(rb)) { - setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); - VM_NEXT(); + Table* h = hvalue(rb); + + if (!FFlag::LuauLenTM || fastnotm(h->metatable, TM_LEN)) + { + setnvalue(ra, cast_num(luaH_getn(h))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_dolen(L, ra, rb)); + VM_NEXT(); + } } // fast-path #2: strings (not very important but easy to do) else if (ttisstring(rb)) { - setnvalue(ra, cast_num(tsvalue(rb)->len)); + TString* ts = tsvalue(rb); + setnvalue(ra, cast_num(ts->len)); VM_NEXT(); } else @@ -2226,6 +2240,15 @@ static void luau_execute(lua_State* L) VM_PROTECT(luaD_call(L, ra, 3)); L->top = L->ci->top; + + /* recompute ra since stack might have been reallocated */ + ra = VM_REG(LUAU_INSN_A(insn)); + + /* protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP */ + if (ttisnil(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "call")); + } } else if (fasttm(L, mt, TM_CALL)) { @@ -2258,27 +2281,38 @@ static void luau_execute(lua_State* L) uint32_t aux = *pc; // fast-path: builtin table iteration - if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) + // note: ra=nil guarantees ra+1=table and ra+2=userdata because of the setup by FORGPREP* opcodes + // TODO: remove the table check per guarantee above + if (ttisnil(ra) && ttistable(ra + 1)) { Table* h = hvalue(ra + 1); int index = int(reinterpret_cast(pvalue(ra + 2))); int sizearray = h->sizearray; - int sizenode = 1 << h->lsizenode; // clear extra variables since we might have more than two - if (LUAU_UNLIKELY(aux > 2)) + // note: while aux encodes ipairs bit, when set we always use 2 variables, so it's safe to check this via a signed comparison + if (LUAU_UNLIKELY(int(aux) > 2)) for (int i = 2; i < int(aux); ++i) setnilvalue(ra + 3 + i); + // terminate ipairs-style traversal early when encountering nil + if (int(aux) < 0 && (unsigned(index) >= unsigned(sizearray) || ttisnil(&h->array[index]))) + { + pc++; + VM_NEXT(); + } + // first we advance index through the array portion while (unsigned(index) < unsigned(sizearray)) { - if (!ttisnil(&h->array[index])) + TValue* e = &h->array[index]; + + if (!ttisnil(e)) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, &h->array[index]); + setobj2s(L, ra + 4, e); pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2288,6 +2322,8 @@ static void luau_execute(lua_State* L) index++; } + int sizenode = 1 << h->lsizenode; + // then we advance index through the hash portion while (unsigned(index - sizearray) < unsigned(sizenode)) { @@ -2321,7 +2357,7 @@ static void luau_execute(lua_State* L) L->top = ra + 3 + 3; /* func + 2 args (state and index) */ LUAU_ASSERT(L->top <= L->stack_last); - VM_PROTECT(luaD_call(L, ra + 3, aux)); + VM_PROTECT(luaD_call(L, ra + 3, uint8_t(aux))); L->top = L->ci->top; // recompute ra since stack might have been reallocated diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 8a18a4d..b9e762e 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -10,6 +10,9 @@ #include "lnumutils.h" #include +#include + +LUAU_FASTFLAG(LuauLenTM) /* limit for table tag-method chains (to avoid loops) */ #define MAXTAGLOOP 100 @@ -51,7 +54,7 @@ const float* luaV_tovector(const TValue* obj) return nullptr; } -static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) +static StkId callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) { ptrdiff_t result = savestack(L, res); // using stack room beyond top is technically safe here, but for very complicated reasons: @@ -71,6 +74,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 res = restorestack(L, result); L->top--; setobjs2s(L, res, L->top); + return res; } static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) @@ -472,22 +476,56 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) { + if (!FFlag::LuauLenTM) + { + switch (ttype(rb)) + { + case LUA_TTABLE: + { + setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); + break; + } + case LUA_TSTRING: + { + setnvalue(ra, cast_num(tsvalue(rb)->len)); + break; + } + default: + { /* try metamethod */ + if (!call_binTM(L, rb, luaO_nilobject, ra, TM_LEN)) + luaG_typeerror(L, rb, "get length of"); + } + } + return; + } + + const TValue* tm = NULL; switch (ttype(rb)) { case LUA_TTABLE: { - setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); + Table* h = hvalue(rb); + if ((tm = fasttm(L, h->metatable, TM_LEN)) == NULL) + { + setnvalue(ra, cast_num(luaH_getn(h))); + return; + } break; } case LUA_TSTRING: { - setnvalue(ra, cast_num(tsvalue(rb)->len)); - break; + TString* ts = tsvalue(rb); + setnvalue(ra, cast_num(ts->len)); + return; } default: - { /* try metamethod */ - if (!call_binTM(L, rb, luaO_nilobject, ra, TM_LEN)) - luaG_typeerror(L, rb, "get length of"); - } + tm = luaT_gettmbyobj(L, rb, TM_LEN); } + + if (ttisnil(tm)) + luaG_typeerror(L, rb, "get length of"); + + StkId res = callTMres(L, ra, tm, rb, luaO_nilobject); + if (!ttisnumber(res)) + luaG_runerror(L, "'__len' must return a number"); /* note, we can't access rb since stack may have been reallocated */ } diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 66a89f2..d4d5227 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -11,6 +11,7 @@ static const std::string kNames[] = { "__div", "__eq", "__index", + "__iter", "__le", "__len", "__lt", @@ -41,13 +42,18 @@ static const std::string kNames[] = { "ceil", "char", "charpattern", + "clamp", "clock", + "clone", + "close", "codepoint", "codes", "concat", "coroutine", "cos", "cosh", + "countlz", + "countrz", "create", "date", "debug", @@ -63,6 +69,7 @@ static const std::string kNames[] = { "foreachi", "format", "frexp", + "freeze", "function", "gcinfo", "getfenv", @@ -72,8 +79,10 @@ static const std::string kNames[] = { "gmatch", "gsub", "huge", + "info", "insert", "ipairs", + "isfrozen", "isyieldable", "ldexp", "len", @@ -93,6 +102,7 @@ static const std::string kNames[] = { "newproxy", "next", "nil", + "noise", "number", "offset", "os", @@ -121,6 +131,7 @@ static const std::string kNames[] = { "select", "setfenv", "setmetatable", + "sign", "sin", "sinh", "sort", diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 655e48c..fafafd7 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,6 +261,8 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + // basic for loop: variable directly refers to internal iteration index (R2) CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( LOADN R2 1 @@ -313,7 +315,7 @@ L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -L1: FORGLOOP_INEXT R0 L0 +L1: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -347,13 +349,15 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_INEXT R0 L0 -L0: FORGLOOP_INEXT R0 L0 +L0: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -364,7 +368,7 @@ MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 FORGPREP_INEXT R1 L0 -L0: FORGLOOP_INEXT R1 L0 +L0: FORGLOOP R1 L0 2 [inext] RETURN R0 0 )"); @@ -374,7 +378,7 @@ GETUPVAL R0 0 NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_INEXT R0 L0 -L0: FORGLOOP_INEXT R0 L0 +L0: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -2107,6 +2111,8 @@ RETURN R3 -1 TEST_CASE("UpvaluesLoopsBytecode") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + CHECK_EQ("\n" + compileFunction(R"( function test() for i=1,10 do @@ -2169,7 +2175,7 @@ JUMPIFNOT R5 L1 CLOSEUPVALS R3 JUMP L3 L1: CLOSEUPVALS R3 -L2: FORGLOOP_INEXT R0 L0 +L2: FORGLOOP R0 L0 1 [inext] L3: LOADN R0 0 RETURN R0 1 )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 96a2775..3f41514 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -231,6 +231,8 @@ TEST_CASE("Assert") TEST_CASE("Basic") { + ScopedFastFlag sff("LuauLenTM", true); + runConformance("basic.lua"); } @@ -301,6 +303,8 @@ TEST_CASE("Errors") TEST_CASE("Events") { + ScopedFastFlag sff("LuauLenTM", true); + runConformance("events.lua"); } @@ -475,6 +479,8 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { + ScopedFastFlag sff("LuauCheckLenMT", true); + runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index 96b2161..00c3309 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -17,7 +17,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(2 == constraints.size()); @@ -36,7 +36,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(3 == constraints.size()); @@ -54,15 +54,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); ToStringOptions opts; REQUIRE(5 <= constraints.size()); CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts)); - CHECK("b ~ inst *blocked-1*" == toString(*constraints[1], opts)); - CHECK("() -> (c...) <: b" == toString(*constraints[2], opts)); - CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("*blocked-2* ~ inst *blocked-1*" == toString(*constraints[1], opts)); + CHECK("() -> (b...) <: *blocked-2*" == toString(*constraints[2], opts)); + CHECK("b... <: c" == toString(*constraints[3], opts)); CHECK("nil <: a..." == toString(*constraints[4], opts)); } @@ -74,15 +74,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(4 == constraints.size()); ToStringOptions opts; CHECK("string <: a" == toString(*constraints[0], opts)); - CHECK("b ~ inst a" == toString(*constraints[1], opts)); - CHECK("(string) -> (c...) <: b" == toString(*constraints[2], opts)); - CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("*blocked-1* ~ inst a" == toString(*constraints[1], opts)); + CHECK("(string) -> (b...) <: *blocked-1*" == toString(*constraints[2], opts)); + CHECK("b... <: c" == toString(*constraints[3], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") @@ -94,7 +94,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(2 == constraints.size()); @@ -112,15 +112,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(4 == constraints.size()); ToStringOptions opts; CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); - CHECK("c ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); - CHECK("(a) -> (d...) <: c" == toString(*constraints[2], opts)); - CHECK("d... <: b..." == toString(*constraints[3], opts)); + CHECK("*blocked-2* ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); + CHECK("(a) -> (c...) <: *blocked-2*" == toString(*constraints[2], opts)); + CHECK("c... <: b..." == toString(*constraints[3], opts)); } TEST_SUITE_END(); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index 5959f55..f521c66 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -9,7 +9,7 @@ using namespace Luau; -static TypeId requireBinding(Scope2* scope, const char* name) +static TypeId requireBinding(NotNull scope, const char* name) { auto b = linearSearchForBinding(scope, name); LUAU_ASSERT(b.has_value()); @@ -26,12 +26,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId bType = requireBinding(cgb.rootScope, "b"); + TypeId bType = requireBinding(rootScope, "b"); CHECK("number" == toString(bType)); } @@ -45,12 +46,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId idType = requireBinding(cgb.rootScope, "id"); + TypeId idType = requireBinding(rootScope, "id"); CHECK("(a) -> a" == toString(idType)); } @@ -71,14 +73,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); ToStringOptions opts; - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId idType = requireBinding(cgb.rootScope, "b"); + TypeId idType = requireBinding(rootScope, "b"); CHECK("(a) -> number" == toString(idType, opts)); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index ac22f65..c92c445 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -195,12 +195,15 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin sourceModule.reset(new SourceModule); ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); - REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); - CHECK_EQ(result.errors.front().getMessage(), message); + if (!result.errors.empty()) + { + CHECK_EQ(result.errors.front().getMessage(), message); - if (location) - CHECK_EQ(result.errors.front().getLocation(), *location); + if (location) + CHECK_EQ(result.errors.front().getLocation(), *location); + } return result; } @@ -213,11 +216,14 @@ ParseResult Fixture::matchParseErrorPrefix(const std::string& source, const std: sourceModule.reset(new SourceModule); ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); - REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); - const std::string& message = result.errors.front().getMessage(); - CHECK_GE(message.length(), prefix.length()); - CHECK_EQ(prefix, message.substr(0, prefix.size())); + if (!result.errors.empty()) + { + const std::string& message = result.errors.front().getMessage(); + CHECK_GE(message.length(), prefix.length()); + CHECK_EQ(prefix, message.substr(0, prefix.size())); + } return result; } @@ -428,6 +434,7 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() + , cgb(mainModuleName, &arena, NotNull(&ice), frontend.getGlobalScope2()) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { BlockedTypeVar::nextIndex = 0; diff --git a/tests/Fixture.h b/tests/Fixture.h index 0e3735f..1bc573d 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -133,6 +133,7 @@ struct Fixture TestConfigResolver configResolver; std::unique_ptr sourceModule; Frontend frontend; + InternalErrorReporter ice; TypeChecker& typeChecker; std::string decorateWithTypes(const std::string& code); @@ -160,7 +161,7 @@ struct BuiltinsFixture : Fixture struct ConstraintGraphBuilderFixture : Fixture { TypeArena arena; - ConstraintGraphBuilder cgb{&arena}; + ConstraintGraphBuilder cgb; ScopedFastFlag forceTheFlag; diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp index ed1c25e..e77ba78 100644 --- a/tests/NotNull.test.cpp +++ b/tests/NotNull.test.cpp @@ -30,23 +30,23 @@ struct Test int Test::count = 0; -} +} // namespace int foo(NotNull p) { return *p; } -void bar(int* q) -{} +void bar(int* q) {} TEST_SUITE_BEGIN("NotNull"); TEST_CASE("basic_stuff") { - NotNull a = NotNull{new int(55)}; // Does runtime test - NotNull b{new int(55)}; // As above - // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not good. + NotNull a = NotNull{new int(55)}; // Does runtime test + NotNull b{new int(55)}; // As above + // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not + // good. // a = nullptr; // nope diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 878023e..c3c7599 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -6,6 +6,8 @@ #include "doctest.h" +#include + using namespace Luau; namespace @@ -786,33 +788,46 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal") TEST_CASE_FIXTURE(Fixture, "parse_numbers_hexadecimal") { - AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff"); + AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff, 0xffffffffffffffff"); REQUIRE(stat != nullptr); AstStatReturn* str = stat->as()->body.data[0]->as(); - CHECK(str->list.size == 3); + CHECK(str->list.size == 4); CHECK_EQ(str->list.data[0]->as()->value, 0xab); CHECK_EQ(str->list.data[1]->as()->value, 0xAB05); CHECK_EQ(str->list.data[2]->as()->value, 0xFFFF); + CHECK_EQ(str->list.data[3]->as()->value, double(ULLONG_MAX)); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") { - AstStat* stat = parse("return 0b1, 0b0, 0b101010"); + AstStat* stat = parse("return 0b1, 0b0, 0b101010, 0b1111111111111111111111111111111111111111111111111111111111111111"); REQUIRE(stat != nullptr); AstStatReturn* str = stat->as()->body.data[0]->as(); - CHECK(str->list.size == 3); + CHECK(str->list.size == 4); CHECK_EQ(str->list.data[0]->as()->value, 1); CHECK_EQ(str->list.data[1]->as()->value, 0); CHECK_EQ(str->list.data[2]->as()->value, 42); + CHECK_EQ(str->list.data[3]->as()->value, double(ULLONG_MAX)); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") { + ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true}; + CHECK_EQ(getParseError("return 0b123"), "Malformed number"); CHECK_EQ(getParseError("return 123x"), "Malformed number"); CHECK_EQ(getParseError("return 0xg"), "Malformed number"); + CHECK_EQ(getParseError("return 0x0x123"), "Malformed number"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_numbers_range_error") +{ + ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true}; + + CHECK_EQ(getParseError("return 0x10000000000000000"), "Integer number value is out of range"); + CHECK_EQ(getParseError("return 0b10000000000000000000000000000000000000000000000000000000000000000"), "Integer number value is out of range"); } TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") @@ -2111,6 +2126,15 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } +TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") +{ + ScopedFastFlag luauFixNamedFunctionParse{"LuauFixNamedFunctionParse", true}; + + matchParseError("type A = (b: number)", "Expected '->' when parsing function type, got "); + matchParseError("type P = () -> T... type B = P<(x: number, y: string)>", "Expected '->' when parsing function type, got '>'"); + matchParseError("type F = (T...) -> ()", "Expected '->' when parsing function type, got '>'"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e03069a..387e07c 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -409,6 +409,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { + ScopedFastFlag sff2{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -424,7 +426,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name); + CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); CHECK_EQ(0, r.nameMap.typeVars.size()); ToStringOptions opts; @@ -455,11 +457,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") std::string twoResult = toString(tMeta6->props["two"].type, opts); - REQUIRE_EQ("(a) -> number", oneResult.name); - REQUIRE_EQ("(b) -> number", twoResult); + CHECK_EQ("(a) -> number", oneResult.name); + CHECK_EQ("(b) -> number", twoResult); } - TEST_CASE_FIXTURE(Fixture, "toStringErrorPack") { CheckResult result = check(R"( @@ -688,6 +689,10 @@ TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") { + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () @@ -701,9 +706,12 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); } - TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") { + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index d6f0a0c..bdd4d6f 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -73,24 +73,24 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") if (FFlag::DebugLuauDeferredConstraintResolution) { CHECK(result.errors[0] == TypeError{ - Location{{1, 21}, {1, 26}}, - getMainSourceModule()->name, - TypeMismatch{ - getSingletonTypes().numberType, - getSingletonTypes().stringType, - }, - }); + Location{{1, 21}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); } else { CHECK(result.errors[0] == TypeError{ - Location{{1, 8}, {1, 26}}, - getMainSourceModule()->name, - TypeMismatch{ - getSingletonTypes().numberType, - getSingletonTypes().stringType, - }, - }); + Location{{1, 8}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); } } @@ -716,6 +716,10 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") { + ScopedFastFlag sff[] = { + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local B = {} B.bar = 4 @@ -737,7 +741,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni type FutureIntersection = A & B )"); - LUAU_REQUIRE_NO_ERRORS(result); + // TODO: shared self causes this test to break in bizarre ways. + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 3e2ad6d..8a86ee5 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -134,13 +134,13 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_reference_generates_error") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(result.errors[0] == TypeError{ - Location{{1, 17}, {1, 28}}, - getMainSourceModule()->name, - UnknownSymbol{ - "IDoNotExist", - UnknownSymbol::Context::Type, - }, - }); + Location{{1, 17}, {1, 28}}, + getMainSourceModule()->name, + UnknownSymbol{ + "IDoNotExist", + UnknownSymbol::Context::Type, + }, + }); } TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 036a667..401a6c6 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -37,6 +37,27 @@ TEST_CASE_FIXTURE(Fixture, "check_function_bodies") }})); } +TEST_CASE_FIXTURE(Fixture, "cannot_hoist_interior_defns_into_signature") +{ + // This test verifies that the signature does not have access to types + // declared within the body. Under DCR, if the function's inner scope + // encompasses the entire function expression, it would be possible for this + // to type check (but the solver output is somewhat undefined). This test + // ensures that this isn't the case. + CheckResult result = check(R"( + local function f(x: T) + type T = number + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{Location{{1, 28}, {1, 29}}, getMainSourceModule()->name, + UnknownSymbol{ + "T", + UnknownSymbol::Context::Type, + }}); +} + TEST_CASE_FIXTURE(Fixture, "infer_return_type") { CheckResult result = check("function take_five() return 5 end"); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 97ba080..e9e94cf 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -271,13 +271,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") { + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( local x = {} function x:id(x) return x end function x:f(): string return self:id("hello") end function x:g(): number return self:id(37) end )"); - LUAU_REQUIRE_NO_ERRORS(result); + // TODO: Quantification should be doing the conversion, not normalization. + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index fd9b1dd..e6174df 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -461,6 +461,61 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") +{ + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + + local mt = {} + setmetatable(foo, mt) + + mt.__unm = function(val: boolean): string + return "test" + end + + local a = -foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.booleanType); + // given type is the typeof(foo) which is complex to compare against +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") +{ + ScopedFastFlag sff("LuauCheckLenMT", true); + + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + local mt = {} + setmetatable(foo, mt) + + mt.__len = function(val: any): string + return "test" + end + + local a = #foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("number", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); + REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 487e597..059aed2 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -499,4 +499,26 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local T = {} + T.__index = T + + function T.new() + local self = setmetatable({}, T) + return self:ctor() or self + end + + function T:ctor() + -- oops, no return! + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Not all codepaths in this function return '{ @metatable T, {| |} }, a...'.", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 77a2928..eead5b3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1863,6 +1863,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") { + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( --!strict @@ -1890,7 +1892,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") newData:First() )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call") @@ -2868,6 +2870,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, + {"DebugLuauSharedSelf", true}, }; check(R"( @@ -2887,7 +2890,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") end )"); - CHECK_EQ("(t1) -> {| Byte: (b) -> (a...), PeekByte: (c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}", + CHECK_EQ("(t1) -> {| Byte: (a) -> (b...), PeekByte: (a) -> (b...) |} where t1 = {+ byte: (t1, number) -> (b...) +}", toString(requireType("Base64FileReader"))); } @@ -2904,6 +2907,66 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); } +TEST_CASE_FIXTURE(Fixture, "shared_selfs") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local t = {} + t.x = 5 + + function t:m1() return self.x end + function t:m2() return self.y end + + return t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b, x: number |}", toString(requireType("t"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "shared_selfs_from_free_param") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local function f(t) + function t:m1() return self.x end + function t:m2() return self.y end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b +}) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "shared_selfs_through_metatables") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local t = {} + t.__index = t + setmetatable({}, t) + + function t:m1() return self.x end + function t:m2() return self.y end + + return t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ( + toString(requireType("t"), opts), "t1 where t1 = {| __index: t1, m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b |}"); +} + TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") { CheckResult result = check(R"( @@ -2953,4 +3016,58 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_ty CHECK_EQ("Type '{number} | {| [boolean]: number |}' does not have key 'x'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(BuiltinsFixture, "quantify_metatables_of_metatables_of_table") +{ + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + + CheckResult result = check(R"( + local T = {} + + function T:m() + return self.x, self.y + end + + function T:n() + end + + local U = setmetatable({}, {__index = T}) + + local V = setmetatable({}, {__index = U}) + + return V + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ(toString(requireType("V"), opts), "{ @metatable { __index: { @metatable { __index: {| m: ({+ x: a, y: b +}) -> (a, b), n: ({+ x: a, y: b +}) -> () |} }, { } } }, { } }"); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local T = {} + + function T:m() + return self.x + end + + function T:n() + return self.y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| m: ({+ x: a, y: b +}) -> a, n: ({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6a048b2..efdfe0b 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -369,14 +369,14 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") +TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") { CheckResult result = check(R"( do local a = 1 end - print(a) -- oops! + local b = a -- oops! )"); LUAU_REQUIRE_ERROR_COUNT(1, result); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index f803c31..385a045 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -118,10 +118,12 @@ assert((function() return #_G end)() == 0) assert((function() return #{1,2} end)() == 2) assert((function() return #'g' end)() == 1) -assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) - assert((function() local a = 1 a = -a return a end)() == -1) +-- __len metamethod +assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) +assert((function() local t = {} setmetatable(t, { __len = function() return 42 end }) return #t end)() == 42) + -- while/repeat assert((function() local a = 10 local b = 1 while a > 1 do b = b * 2 a = a - 1 end return b end)() == 512) assert((function() local a = 10 local b = 1 repeat b = b * 2 a = a - 1 until a == 1 return b end)() == 512) @@ -889,6 +891,10 @@ assert((function() return table.concat(res, ',') end)() == "6,8,10") +-- typeof and type require an argument +assert(pcall(typeof) == false) +assert(pcall(type) == false) + -- typeof == type in absence of custom userdata assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 32e090a..42f1bed 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -386,4 +386,42 @@ do assert(t.X) -- fails if table flags are set incorrectly end +do + -- verify __len behavior & error handling + local t = {1} + + setmetatable(t, {}) + assert(#t == 1) + + setmetatable(t, { __len = rawlen }) + assert(#t == 1) + + setmetatable(t, { __len = function() return 42 end }) + assert(#t == 42) + + setmetatable(t, { __len = 42 }) + local ok, err = pcall(function() return #t end) + assert(not ok and err:match("attempt to call a number value")) + + setmetatable(t, { __len = function() end }) + local ok, err = pcall(function() return #t end) + assert(not ok and err:match("'__len' must return a number")) + + setmetatable(t, { __len = error }) + local ok, err = pcall(function() return #t end) + assert(not ok and err == t) +end + +-- verify rawlen behavior +do + local t = {1} + setmetatable(t, { __len = 42 }) + + assert(rawlen(t) == 1) + assert(rawlen("foo") == 3) + + local ok, err = pcall(function() return rawlen(42) end) + assert(not ok and err:match("table or string expected")) +end + return 'OK' From 4a95f2201ec1f879dfa3f1b8bc015e0a30df005a Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 7 Jul 2022 18:05:31 -0700 Subject: [PATCH 19/19] Sync to upstream/release/535 --- Analysis/include/Luau/AstQuery.h | 1 + Analysis/include/Luau/TypeInfer.h | 25 +- Analysis/include/Luau/TypePack.h | 1 + Analysis/include/Luau/TypeVar.h | 157 +- Analysis/include/Luau/VisitTypeVar.h | 8 + Analysis/src/AstQuery.cpp | 107 +- Analysis/src/Autocomplete.cpp | 201 +- Analysis/src/BuiltinDefinitions.cpp | 43 +- Analysis/src/Clone.cpp | 12 + Analysis/src/EmbeddedBuiltinDefinitions.cpp | 9 +- Analysis/src/Frontend.cpp | 2 + Analysis/src/Normalize.cpp | 38 +- Analysis/src/Substitution.cpp | 10 +- Analysis/src/ToString.cpp | 23 +- Analysis/src/Transpiler.cpp | 20 +- Analysis/src/TxnLog.cpp | 59 +- Analysis/src/TypeAttach.cpp | 8 + Analysis/src/TypeInfer.cpp | 420 +++- Analysis/src/TypePack.cpp | 51 +- Analysis/src/TypeUtils.cpp | 2 +- Analysis/src/TypeVar.cpp | 269 ++- Analysis/src/Unifier.cpp | 66 +- Ast/src/Parser.cpp | 17 +- CLI/Analyze.cpp | 18 +- CLI/Repl.cpp | 37 + CodeGen/include/Luau/AssemblyBuilderX64.h | 3 + CodeGen/src/AssemblyBuilderX64.cpp | 23 + Makefile | 4 + Sources.cmake | 8 +- VM/src/ltable.cpp | 8 +- VM/src/lvmexecute.cpp | 66 +- bench/bench.py | 37 +- bench/bench_support.lua | 10 + bench/other/LuauPolyfillMap.lua | 961 +++++++++ bench/other/regex.lua | 2089 +++++++++++++++++++ tests/AssemblyBuilderX64.test.cpp | 27 + tests/AstQuery.test.cpp | 33 + tests/Fixture.h | 1 + tests/Frontend.test.cpp | 2 + tests/Module.test.cpp | 2 - tests/Normalize.test.cpp | 42 +- tests/Parser.test.cpp | 1 - tests/ToString.test.cpp | 33 +- tests/Transpiler.test.cpp | 2 +- tests/TypeInfer.anyerror.test.cpp | 8 +- tests/TypeInfer.builtins.test.cpp | 143 +- tests/TypeInfer.functions.test.cpp | 54 +- tests/TypeInfer.generics.test.cpp | 39 +- tests/TypeInfer.loops.test.cpp | 2 +- tests/TypeInfer.modules.test.cpp | 45 +- tests/TypeInfer.operators.test.cpp | 22 + tests/TypeInfer.primitives.test.cpp | 2 +- tests/TypeInfer.provisional.test.cpp | 15 +- tests/TypeInfer.refinements.test.cpp | 45 +- tests/TypeInfer.tables.test.cpp | 14 + tests/TypeInfer.test.cpp | 33 +- tests/TypeInfer.tryUnify.test.cpp | 4 +- tests/TypeInfer.unionTypes.test.cpp | 2 +- tests/TypeInfer.unknownnever.test.cpp | 280 +++ tests/TypePack.test.cpp | 2 - tests/TypeVar.test.cpp | 2 - tests/conformance/vector.lua | 16 + tools/natvis/VM.natvis | 32 +- 63 files changed, 5097 insertions(+), 619 deletions(-) create mode 100644 bench/other/LuauPolyfillMap.lua create mode 100644 bench/other/regex.lua create mode 100644 tests/TypeInfer.unknownnever.test.cpp diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index dfe373a..950a19d 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -63,6 +63,7 @@ private: AstLocal* local = nullptr; }; +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos); std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos); AstNode* findNodeAtPosition(const SourceModule& source, Position pos); AstExpr* findExprAtPosition(const SourceModule& source, Position pos); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 455654d..3fb710b 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -153,7 +153,7 @@ struct TypeChecker const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); TypeId checkBinaryOperation( const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); - WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType = std::nullopt); WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); @@ -180,8 +180,12 @@ struct TypeChecker const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); - WithPredicate checkExprPack(const ScopePtr& scope, const AstExprCall& expr); + + WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); + WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); + std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); @@ -236,10 +240,11 @@ struct TypeChecker void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); - std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); - std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); + std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); + std::optional getIndexTypeFromTypeImpl(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); // Reduces the union to its simplest possible shape. // (A | B) | B | C yields A | B | C @@ -316,11 +321,12 @@ private: TypeIdPredicate mkTruthyPredicate(bool sense); - // Returns nullopt if the predicate filters down the TypeId to 0 options. - std::optional filterMap(TypeId type, TypeIdPredicate predicate); + // TODO: Return TypeId only. + std::optional filterMapImpl(TypeId type, TypeIdPredicate predicate); + std::pair, bool> filterMap(TypeId type, TypeIdPredicate predicate); public: - std::optional pickTypesFromSense(TypeId type, bool sense); + std::pair, bool> pickTypesFromSense(TypeId type, bool sense); private: TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); @@ -345,6 +351,7 @@ private: TypePackId freshTypePack(TypeLevel level); TypeId resolveType(const ScopePtr& scope, const AstType& annotation); + TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, @@ -412,8 +419,12 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; + const TypeId unknownType; + const TypeId neverType; const TypePackId anyTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; private: int checkRecursionCount = 0; diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index c1de242..b17003b 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -173,5 +173,6 @@ std::pair, std::optional> flatten(TypePackId tp, bool isVariadic(TypePackId tp); bool isVariadic(TypePackId tp, const TxnLog& log); +bool containsNever(TypePackId tp); } // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 6ad6b92..fb6093d 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -460,10 +460,18 @@ struct LazyTypeVar std::function thunk; }; +struct UnknownTypeVar +{ +}; + +struct NeverTypeVar +{ +}; + using ErrorTypeVar = Unifiable::Error; using TypeVariant = Unifiable::Variant; + MetatableTypeVar, ClassTypeVar, AnyTypeVar, UnionTypeVar, IntersectionTypeVar, LazyTypeVar, UnknownTypeVar, NeverTypeVar>; struct TypeVar final { @@ -575,8 +583,12 @@ struct SingletonTypes const TypeId trueType; const TypeId falseType; const TypeId anyType; + const TypeId unknownType; + const TypeId neverType; const TypePackId anyTypePack; + const TypePackId neverTypePack; + const TypePackId uninhabitableTypePack; SingletonTypes(); ~SingletonTypes(); @@ -632,12 +644,30 @@ T* getMutable(TypeId tv) return get_if(&asMutable(tv)->ty); } -/* Traverses the UnionTypeVar yielding each TypeId. - * If the iterator encounters a nested UnionTypeVar, it will instead yield each TypeId within. - * - * Beware: the iterator does not currently filter for unique TypeIds. This may change in the future. +const std::vector& getTypes(const UnionTypeVar* utv); +const std::vector& getTypes(const IntersectionTypeVar* itv); +const std::vector& getTypes(const ConstrainedTypeVar* ctv); + +template +struct TypeIterator; + +using UnionTypeVarIterator = TypeIterator; +UnionTypeVarIterator begin(const UnionTypeVar* utv); +UnionTypeVarIterator end(const UnionTypeVar* utv); + +using IntersectionTypeVarIterator = TypeIterator; +IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv); +IntersectionTypeVarIterator end(const IntersectionTypeVar* itv); + +using ConstrainedTypeVarIterator = TypeIterator; +ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv); +ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv); + +/* Traverses the type T yielding each TypeId. + * If the iterator encounters a nested type T, it will instead yield each TypeId within. */ -struct UnionTypeVarIterator +template +struct TypeIterator { using value_type = Luau::TypeId; using pointer = value_type*; @@ -645,33 +675,116 @@ struct UnionTypeVarIterator using difference_type = size_t; using iterator_category = std::input_iterator_tag; - explicit UnionTypeVarIterator(const UnionTypeVar* utv); + explicit TypeIterator(const T* t) + { + LUAU_ASSERT(t); - UnionTypeVarIterator& operator++(); - UnionTypeVarIterator operator++(int); - bool operator!=(const UnionTypeVarIterator& rhs); - bool operator==(const UnionTypeVarIterator& rhs); + const std::vector& types = getTypes(t); + if (!types.empty()) + stack.push_front({t, 0}); - const TypeId& operator*(); + seen.insert(t); + } - friend UnionTypeVarIterator end(const UnionTypeVar* utv); + TypeIterator& operator++() + { + advance(); + descend(); + return *this; + } + + TypeIterator operator++(int) + { + TypeIterator copy = *this; + ++copy; + return copy; + } + + bool operator==(const TypeIterator& rhs) const + { + if (!stack.empty() && !rhs.stack.empty()) + return stack.front() == rhs.stack.front(); + + return stack.empty() && rhs.stack.empty(); + } + + bool operator!=(const TypeIterator& rhs) const + { + return !(*this == rhs); + } + + const TypeId& operator*() + { + LUAU_ASSERT(!stack.empty()); + + descend(); + + auto [t, currentIndex] = stack.front(); + LUAU_ASSERT(t); + const std::vector& types = getTypes(t); + LUAU_ASSERT(currentIndex < types.size()); + + const TypeId& ty = types[currentIndex]; + LUAU_ASSERT(!get(follow(ty))); + return ty; + } + + // Normally, we'd have `begin` and `end` be a template but there's too much trouble + // with templates portability in this area, so not worth it. Thanks MSVC. + friend UnionTypeVarIterator end(const UnionTypeVar*); + friend IntersectionTypeVarIterator end(const IntersectionTypeVar*); + friend ConstrainedTypeVarIterator end(const ConstrainedTypeVar*); private: - UnionTypeVarIterator() = default; + TypeIterator() = default; - // (UnionTypeVar* utv, size_t currentIndex) - using SavedIterInfo = std::pair; + // (T* t, size_t currentIndex) + using SavedIterInfo = std::pair; std::deque stack; - std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. + std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. - void advance(); - void descend(); + void advance() + { + while (!stack.empty()) + { + auto& [t, currentIndex] = stack.front(); + ++currentIndex; + + const std::vector& types = getTypes(t); + if (currentIndex >= types.size()) + stack.pop_front(); + else + break; + } + } + + void descend() + { + while (!stack.empty()) + { + auto [current, currentIndex] = stack.front(); + const std::vector& types = getTypes(current); + if (auto inner = get(follow(types[currentIndex]))) + { + // If we're about to descend into a cyclic type, we should skip over this. + // Ideally this should never happen, but alas it does from time to time. :( + if (seen.find(inner) != seen.end()) + advance(); + else + { + seen.insert(inner); + stack.push_front({inner, 0}); + } + + continue; + } + + break; + } + } }; -UnionTypeVarIterator begin(const UnionTypeVar* utv); -UnionTypeVarIterator end(const UnionTypeVar* utv); - using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index 5fd43f0..ab4a397 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -129,6 +129,14 @@ struct GenericTypeVarVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const UnknownTypeVar& atv) + { + return visit(ty); + } + virtual bool visit(TypeId ty, const NeverTypeVar& atv) + { + return visit(ty); + } virtual bool visit(TypeId ty, const UnionTypeVar& utv) { return visit(ty); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 0522b1f..1124c29 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -17,6 +17,104 @@ namespace Luau namespace { + +struct AutocompleteNodeFinder : public AstVisitor +{ + const Position pos; + std::vector ancestry; + + explicit AutocompleteNodeFinder(Position pos, AstNode* root) + : pos(pos) + { + } + + bool visit(AstExpr* expr) override + { + if (expr->location.begin < pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; + } + + bool visit(AstStat* stat) override + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + return false; + } + + bool visit(AstType* type) override + { + if (type->location.begin < pos && pos <= type->location.end) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(AstTypeError* type) override + { + // For a missing type, match the whole range including the start position + if (type->isMissing && type->location.containsClosed(pos)) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(class AstTypePack* typePack) override + { + return true; + } + + bool visit(AstStatBlock* block) override + { + // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. + if (ancestry.empty()) + { + ancestry.push_back(block); + return true; + } + + // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. + // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // Type annotation error might intersect the block statement when the function header is being written, + // annotation takes priority + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, + // the expression or type wins out. + // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to + // be within the block. + if (block->location.begin == pos && !ancestry.empty()) + { + if (ancestry.back()->asExpr() && !ancestry.back()->is()) + return false; + + if (ancestry.back()->asType()) + return false; + } + + if (block->location.begin <= pos && pos <= block->location.end) + { + ancestry.push_back(block); + return true; + } + return false; + } +}; + struct FindNode : public AstVisitor { const Position pos; @@ -102,6 +200,13 @@ struct FindFullAncestry final : public AstVisitor } // namespace +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) +{ + AutocompleteNodeFinder finder{pos, source.root}; + source.root->visit(&finder); + return finder.ancestry; +} + std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) { const Position end = source.root->location.end; @@ -110,7 +215,7 @@ std::vector findAstAncestryOfPosition(const SourceModule& source, Posi FindFullAncestry finder(pos, end); source.root->visit(&finder); - return std::move(finder.nodes); + return finder.nodes; } AstNode* findNodeAtPosition(const SourceModule& source, Position pos) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 8a63901..cc54d49 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -21,102 +21,6 @@ static const std::unordered_set kStatementStartingKeywords = { namespace Luau { -struct NodeFinder : public AstVisitor -{ - const Position pos; - std::vector ancestry; - - explicit NodeFinder(Position pos, AstNode* root) - : pos(pos) - { - } - - bool visit(AstExpr* expr) override - { - if (expr->location.begin < pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; - } - - bool visit(AstStat* stat) override - { - if (stat->location.begin < pos && pos <= stat->location.end) - { - ancestry.push_back(stat); - return true; - } - return false; - } - - bool visit(AstType* type) override - { - if (type->location.begin < pos && pos <= type->location.end) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(AstTypeError* type) override - { - // For a missing type, match the whole range including the start position - if (type->isMissing && type->location.containsClosed(pos)) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(class AstTypePack* typePack) override - { - return true; - } - - bool visit(AstStatBlock* block) override - { - // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. - if (ancestry.empty()) - { - ancestry.push_back(block); - return true; - } - - // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. - // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // Type annotation error might intersect the block statement when the function header is being written, - // annotation takes priority - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, - // the expression or type wins out. - // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to - // be within the block. - if (block->location.begin == pos && !ancestry.empty()) - { - if (ancestry.back()->asExpr() && !ancestry.back()->is()) - return false; - - if (ancestry.back()->asType()) - return false; - } - - if (block->location.begin <= pos && pos <= block->location.end) - { - ancestry.push_back(block); - return true; - } - return false; - } -}; static bool alreadyHasParens(const std::vector& nodes) { @@ -905,7 +809,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } AstNode* parent = nullptr; - AstType* topType = nullptr; + AstType* topType = nullptr; // TODO: rename? for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) { @@ -1477,21 +1381,20 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (isWithinComment(sourceModule, position)) return {}; - NodeFinder finder{position, sourceModule.root}; - sourceModule.root->visit(&finder); - LUAU_ASSERT(!finder.ancestry.empty()); - AstNode* node = finder.ancestry.back(); + std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); + LUAU_ASSERT(!ancestry.empty()); + AstNode* node = ancestry.back(); AstExprConstantNil dummy{Location{}}; - AstNode* parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) { - finder.ancestry.pop_back(); + ancestry.pop_back(); - node = finder.ancestry.back(); - parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + node = ancestry.back(); + parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; } if (auto indexName = node->as()) @@ -1504,47 +1407,47 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty)) - return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), - finder.ancestry}; + return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), + ancestry}; else - return {autocompleteProps(*module, typeArena, ty, indexType, finder.ancestry), finder.ancestry}; + return {autocompleteProps(*module, typeArena, ty, indexType, ancestry), ancestry}; } else if (auto typeReference = node->as()) { if (typeReference->prefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), finder.ancestry}; + return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), ancestry}; else - return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + return {autocompleteTypeNames(*module, position, ancestry), ancestry}; } else if (node->is()) { - return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + return {autocompleteTypeNames(*module, position, ancestry), ancestry}; } else if (AstStatLocal* statLocal = node->as()) { if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) - return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else return {}; } - else if (AstStatFor* statFor = extractStat(finder.ancestry)) + else if (AstStatFor* statFor = extractStat(ancestry)) { if (!statFor->hasDo || position < statFor->doLocation.begin) { if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || (statFor->step && statFor->step->location.containsClosed(position))) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; return {}; } - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) @@ -1560,7 +1463,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {}; } - return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; } if (!statForIn->hasDo || position <= statForIn->doLocation.begin) @@ -1569,58 +1472,58 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; if (lastExpr->location.containsClosed(position)) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; if (position > lastExpr->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; return {}; // Not sure what this means } } - else if (AstStatForIn* statForIn = extractStat(finder.ancestry)) + else if (AstStatForIn* statForIn = extractStat(ancestry)) { // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. // ex "for f in f do" if (!statForIn->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) { if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; if (statWhile->hasDo && position > statWhile->doLocation.end) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; } - else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + else if (AstStatWhile* statWhile = extractStat(ancestry); statWhile && !statWhile->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - finder.ancestry}; + ancestry}; } else if (AstStatIf* statIf = parent->as(); statIf && node->is()) { if (statIf->condition->is()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; } - else if (AstStatIf* statIf = extractStat(finder.ancestry); + else if (AstStatIf* statIf = extractStat(ancestry); statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; - else if (AstStatRepeat* statRepeat = extractStat(finder.ancestry); statRepeat) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; + else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) { for (const auto& [kind, key, value] : exprTable->items) @@ -1630,7 +1533,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry); + auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1644,9 +1547,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) - autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result); + autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position, result); - return {result, finder.ancestry}; + return {result, ancestry}; } break; @@ -1654,11 +1557,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } else if (isIdentifier(node) && (parent->is() || parent->is())) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; - if (std::optional ret = autocompleteStringParams(sourceModule, module, finder.ancestry, position, callback)) + if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback)) { - return {*ret, finder.ancestry}; + return {*ret, ancestry}; } else if (node->is()) { @@ -1667,14 +1570,14 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto it = module->astExpectedTypes.find(node->asExpr())) autocompleteStringSingleton(*it, false, result); - if (finder.ancestry.size() >= 2) + if (ancestry.size() >= 2) { - if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) { if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry, result); + autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, ancestry, result); } - else if (auto binExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) { if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { @@ -1684,7 +1587,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } } - return {result, finder.ancestry}; + return {result, ancestry}; } if (node->is()) @@ -1693,9 +1596,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } if (node->asExpr()) - return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, ancestry, position), ancestry}; else if (node->asStat()) - return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry}; return {}; } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 2f57e23..aeba2c1 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -9,6 +9,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) +LUAU_FASTFLAG(LuauUnknownAndNeverType) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -222,14 +223,14 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); - // setmetatable({ @metatable MT }, MT) -> { @metatable MT } // clang-format off + // setmetatable(T, MT) -> { @metatable MT, T } addGlobalBinding(typeChecker, "setmetatable", arena.addType( FunctionTypeVar{ {genericMT}, {}, - arena.addTypePack(TypePack{{tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{FFlag::LuauUnknownAndNeverType ? tabTy : tableMetaMT, genericMT}}), arena.addTypePack(TypePack{{tableMetaMT}}) } ), "@luau" @@ -309,6 +310,12 @@ static std::optional> magicFunctionSetMetaTable( { auto [paramPack, _predicates] = withPredicate; + if (FFlag::LuauUnknownAndNeverType) + { + if (size(paramPack) < 2 && finite(paramPack)) + return std::nullopt; + } + TypeArena& arena = typechecker.currentModule->internalTypes; std::vector expectedArgs = typechecker.unTypePack(scope, paramPack, 2, expr.location); @@ -316,6 +323,12 @@ static std::optional> magicFunctionSetMetaTable( TypeId target = follow(expectedArgs[0]); TypeId mt = follow(expectedArgs[1]); + if (FFlag::LuauUnknownAndNeverType) + { + typechecker.tablify(target); + typechecker.tablify(mt); + } + if (const auto& tab = get(target)) { if (target->persistent) @@ -324,7 +337,8 @@ static std::optional> magicFunctionSetMetaTable( } else { - typechecker.tablify(mt); + if (!FFlag::LuauUnknownAndNeverType) + typechecker.tablify(mt); const TableTypeVar* mtTtv = get(mt); MetatableTypeVar mtv{target, mt}; @@ -343,7 +357,10 @@ static std::optional> magicFunctionSetMetaTable( if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - return WithPredicate{}; + if (FFlag::LuauUnknownAndNeverType) + return std::nullopt; + else + return WithPredicate{}; } if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) @@ -390,11 +407,21 @@ static std::optional> magicFunctionAssert( if (head.size() > 0) { - std::optional newhead = typechecker.pickTypesFromSense(head[0], true); - if (!newhead) - head = {typechecker.nilType}; + auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true); + if (FFlag::LuauUnknownAndNeverType) + { + if (get(*ty)) + head = {*ty}; + else + head[0] = *ty; + } else - head[0] = *newhead; + { + if (!ty) + head = {typechecker.nilType}; + else + head[0] = *ty; + } } return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index df4e0a6..88c5031 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -59,6 +59,8 @@ struct TypeCloner void operator()(const UnionTypeVar& t); void operator()(const IntersectionTypeVar& t); void operator()(const LazyTypeVar& t); + void operator()(const UnknownTypeVar& t); + void operator()(const NeverTypeVar& t); }; struct TypePackCloner @@ -310,6 +312,16 @@ void TypeCloner::operator()(const LazyTypeVar& t) defaultClone(t); } +void TypeCloner::operator()(const UnknownTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const NeverTypeVar& t) +{ + defaultClone(t); +} + } // anonymous namespace TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 1b5275f..f93f65d 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauCheckLenMT) namespace Luau @@ -116,8 +117,6 @@ declare function typeof(value: T): string -- `assert` has a magic function attached that will give more detailed type information declare function assert(value: T, errorMessage: string?): T -declare function error(message: T, level: number?) - declare function tostring(value: T): string declare function tonumber(value: T, radix: number?): number? @@ -204,12 +203,18 @@ declare function unpack(tab: {V}, i: number?, j: number?): ...V std::string getBuiltinDefinitionSource() { + std::string result = kBuiltinDefinitionLuaSrc; // TODO: move this into kBuiltinDefinitionLuaSrc if (FFlag::LuauCheckLenMT) result += "declare function rawlen(obj: {[K]: V} | string): number\n"; + if (FFlag::LuauUnknownAndNeverType) + result += "declare function error(message: T, level: number?): never\n"; + else + result += "declare function error(message: T, level: number?)\n"; + return result; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 4cfaa11..9195363 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -496,6 +496,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalastTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); module->scopes.resize(1); } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 8ce7f74..ce8f96c 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -14,7 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -182,7 +182,6 @@ struct Normalize final : TypeVarVisitor { if (!ty->normal) asMutable(ty)->normal = true; - return false; } @@ -193,6 +192,20 @@ struct Normalize final : TypeVarVisitor return false; } + bool visit(TypeId ty, const UnknownTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + + bool visit(TypeId ty, const NeverTypeVar&) override + { + if (!ty->normal) + asMutable(ty)->normal = true; + return false; + } + bool visit(TypeId ty, const ConstrainedTypeVar& ctvRef) override { CHECK_ITERATION_LIMIT(false); @@ -416,7 +429,13 @@ struct Normalize final : TypeVarVisitor std::vector result; for (TypeId part : options) + { + // AnyTypeVar always win the battle no matter what we do, so we're done. + if (FFlag::LuauUnknownAndNeverType && get(follow(part))) + return {part}; + combineIntoUnion(result, part); + } return result; } @@ -427,7 +446,17 @@ struct Normalize final : TypeVarVisitor if (auto utv = get(ty)) { for (TypeId t : utv) + { + // AnyTypeVar always win the battle no matter what we do, so we're done. + if (FFlag::LuauUnknownAndNeverType && get(t)) + { + result = {t}; + return; + } + combineIntoUnion(result, t); + } + return; } @@ -571,8 +600,7 @@ struct Normalize final : TypeVarVisitor */ TypeId combine(Replacer& replacer, TypeId a, TypeId b) { - if (FFlag::LuauNormalizeCombineEqFix) - b = follow(b); + b = follow(b); if (FFlag::LuauNormalizeCombineTableFix && a == b) return a; @@ -592,7 +620,7 @@ struct Normalize final : TypeVarVisitor } else if (auto ttv = getMutable(a)) { - if (FFlag::LuauNormalizeCombineTableFix && !get(FFlag::LuauNormalizeCombineEqFix ? b : follow(b))) + if (FFlag::LuauNormalizeCombineTableFix && !get(b)) return arena.addType(IntersectionTypeVar{{a, b}}); combineIntoTable(replacer, ttv, b); return a; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 9c4ce82..7245403 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,8 +8,10 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauAnyificationMustClone, false) LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) +LUAU_FASTFLAG(LuauUnknownAndNeverType) namespace Luau { @@ -154,7 +156,7 @@ TarjanResult Tarjan::loop() if (currEdge == -1) { ++childCount; - if (childLimit > 0 && childLimit < childCount) + if (childLimit > 0 && (FFlag::LuauUnknownAndNeverType ? childLimit <= childCount : childLimit < childCount)) return TarjanResult::TooManyChildren; stack.push_back(index); @@ -439,6 +441,9 @@ void Substitution::replaceChildren(TypeId ty) if (ignoreChildren(ty)) return; + if (FFlag::LuauAnyificationMustClone && ty->owningArena != arena) + return; + if (FunctionTypeVar* ftv = getMutable(ty)) { ftv->argTypes = replace(ftv->argTypes); @@ -490,6 +495,9 @@ void Substitution::replaceChildren(TypePackId tp) if (ignoreChildren(tp)) return; + if (FFlag::LuauAnyificationMustClone && tp->owningArena != arena) + return; + if (TypePack* tpp = getMutable(tp)) { for (TypeId& tv : tpp->head) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index fe940d5..c67d639 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauUnknownAndNeverType) /* * Prefix generic typenames with gen- @@ -699,6 +700,12 @@ struct TypeVarStringifier void operator()(TypeId, const MetatableTypeVar& mtv) { state.result.invalid = true; + if (!state.exhaustive && mtv.syntheticName) + { + state.emit(*mtv.syntheticName); + return; + } + state.emit("{ @metatable "); stringify(mtv.metatable); state.emit(","); @@ -834,7 +841,7 @@ struct TypeVarStringifier void operator()(TypeId, const ErrorTypeVar& tv) { state.result.error = true; - state.emit("*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); } void operator()(TypeId, const LazyTypeVar& ltv) @@ -843,7 +850,17 @@ struct TypeVarStringifier state.emit("lazy?"); } -}; // namespace + void operator()(TypeId, const UnknownTypeVar& ttv) + { + state.emit("unknown"); + } + + void operator()(TypeId, const NeverTypeVar& ttv) + { + state.emit("never"); + } + +}; struct TypePackStringifier { @@ -947,7 +964,7 @@ struct TypePackStringifier void operator()(TypePackId, const Unifiable::Error& error) { state.result.error = true; - state.emit("*unknown*"); + state.emit(FFlag::LuauUnknownAndNeverType ? "" : "*unknown*"); } void operator()(TypePackId, const VariadicTypePack& pack) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 1577bd6..9feff1c 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -205,20 +205,6 @@ struct Printer } } - void visualizeWithSelf(AstExpr& expr, bool self) - { - if (!self) - return visualize(expr); - - AstExprIndexName* func = expr.as(); - LUAU_ASSERT(func); - - visualize(*func->expr); - writer.symbol(":"); - advance(func->indexLocation.begin); - writer.identifier(func->index.value); - } - void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) { advance(annotation.location.begin); @@ -366,7 +352,7 @@ struct Printer } else if (const auto& a = expr.as()) { - visualizeWithSelf(*a->func, a->self); + visualize(*a->func); writer.symbol("("); bool first = true; @@ -385,7 +371,7 @@ struct Printer else if (const auto& a = expr.as()) { visualize(*a->expr); - writer.symbol("."); + writer.symbol(std::string(1, a->op)); writer.write(a->index.value); } else if (const auto& a = expr.as()) @@ -766,7 +752,7 @@ struct Printer else if (const auto& a = program.as()) { writer.keyword("function"); - visualizeWithSelf(*a->name, a->func->self != nullptr); + visualize(*a->name); visualizeFunctionBody(*a->func); } else if (const auto& a = program.as()) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 4c6d54e..b3f60d3 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,7 +7,7 @@ #include #include -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) +LUAU_FASTFLAG(LuauUnknownAndNeverType) namespace Luau { @@ -81,34 +81,10 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { for (auto& [ty, rep] : typeVarChanges) - { - if (FFlag::LuauNonCopyableTypeVarFields) - { - asMutable(ty)->reassign(rep.get()->pending); - } - else - { - TypeArena* owningArena = ty->owningArena; - TypeVar* mtv = asMutable(ty); - *mtv = rep.get()->pending; - mtv->owningArena = owningArena; - } - } + asMutable(ty)->reassign(rep.get()->pending); for (auto& [tp, rep] : typePackChanges) - { - if (FFlag::LuauNonCopyableTypeVarFields) - { - asMutable(tp)->reassign(rep.get()->pending); - } - else - { - TypeArena* owningArena = tp->owningArena; - TypePackVar* mpv = asMutable(tp); - *mpv = rep.get()->pending; - mpv->owningArena = owningArena; - } - } + asMutable(tp)->reassign(rep.get()->pending); clear(); } @@ -196,9 +172,7 @@ PendingType* TxnLog::queue(TypeId ty) if (!pending) { pending = std::make_unique(*ty); - - if (FFlag::LuauNonCopyableTypeVarFields) - pending->pending.owningArena = nullptr; + pending->pending.owningArena = nullptr; } return pending.get(); @@ -214,9 +188,7 @@ PendingTypePack* TxnLog::queue(TypePackId tp) if (!pending) { pending = std::make_unique(*tp); - - if (FFlag::LuauNonCopyableTypeVarFields) - pending->pending.owningArena = nullptr; + pending->pending.owningArena = nullptr; } return pending.get(); @@ -255,24 +227,14 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { PendingType* newTy = queue(ty); - - if (FFlag::LuauNonCopyableTypeVarFields) - newTy->pending.reassign(replacement); - else - newTy->pending = replacement; - + newTy->pending.reassign(replacement); return newTy; } PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { PendingTypePack* newTp = queue(tp); - - if (FFlag::LuauNonCopyableTypeVarFields) - newTp->pending.reassign(replacement); - else - newTp->pending = replacement; - + newTp->pending.reassign(replacement); return newTp; } @@ -289,7 +251,7 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + LUAU_ASSERT(get(ty) || get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); if (FreeTypeVar* ftv = Luau::getMutable(newTy)) @@ -305,6 +267,11 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { ftv->level = newLevel; } + else if (ConstrainedTypeVar* ctv = Luau::getMutable(newTy)) + { + if (FFlag::LuauUnknownAndNeverType) + ctv->level = newLevel; + } return newTy; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 6cca712..2bc89cf 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -335,6 +335,14 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName("")); } + AstType* operator()(const UnknownTypeVar& ttv) + { + return allocator->alloc(Location(), std::nullopt, AstName{"unknown"}); + } + AstType* operator()(const NeverTypeVar& ttv) + { + return allocator->alloc(Location(), std::nullopt, AstName{"never"}); + } private: Allocator* allocator; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index d9486a4..01939fd 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -31,6 +31,7 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTFLAGVARIABLE(LuauIndexSilenceErrors, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix2, false) @@ -41,10 +42,12 @@ LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) +LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) -LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false) +LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) +LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) namespace Luau { @@ -258,7 +261,11 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(getSingletonTypes().booleanType) , threadType(getSingletonTypes().threadType) , anyType(getSingletonTypes().anyType) + , unknownType(getSingletonTypes().unknownType) + , neverType(getSingletonTypes().neverType) , anyTypePack(getSingletonTypes().anyTypePack) + , neverTypePack(getSingletonTypes().neverTypePack) + , uninhabitableTypePack(getSingletonTypes().uninhabitableTypePack) , duplicateTypeAliases{{false, {}}} { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -269,6 +276,11 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan globalScope->exportedTypeBindings["string"] = TypeFun{{}, stringType}; globalScope->exportedTypeBindings["boolean"] = TypeFun{{}, booleanType}; globalScope->exportedTypeBindings["thread"] = TypeFun{{}, threadType}; + if (FFlag::LuauUnknownAndNeverType) + { + globalScope->exportedTypeBindings["unknown"] = TypeFun{{}, unknownType}; + globalScope->exportedTypeBindings["never"] = TypeFun{{}, neverType}; + } } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -456,6 +468,59 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } } +struct InplaceDemoter : TypeVarOnceVisitor +{ + TypeLevel newLevel; + TypeArena* arena; + + InplaceDemoter(TypeLevel level, TypeArena* arena) + : newLevel(level) + , arena(arena) + { + } + + bool demote(TypeId ty) + { + if (auto level = getMutableLevel(ty)) + { + if (level->subsumesStrict(newLevel)) + { + *level = newLevel; + return true; + } + } + + return false; + } + + bool visit(TypeId ty, const BoundTypeVar& btyRef) override + { + return true; + } + + bool visit(TypeId ty) override + { + if (ty->owningArena != arena) + return false; + return demote(ty); + } + + bool visit(TypePackId tp, const FreeTypePack& ftpRef) override + { + if (tp->owningArena != arena) + return false; + + FreeTypePack* ftp = &const_cast(ftpRef); + if (ftp->level.subsumesStrict(newLevel)) + { + ftp->level = newLevel; + return true; + } + + return false; + } +}; + void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) { int subLevel = 0; @@ -559,7 +624,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A tablify(baseTy); if (!fun->func->self) - expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, false); + expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, /* addErrors= */ false); else if (auto ttv = getMutableTableType(baseTy)) { if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) @@ -579,7 +644,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A if (auto name = fun->name->as()) { TypeId exprTy = checkExpr(scope, *name->expr).type; - expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false); } } } @@ -634,15 +699,8 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - TypeVar* mty = asMutable(follow(type)); - mty->reassign(*errorRecoveryType(anyType)); - } - else - { - *asMutable(type) = *errorRecoveryType(anyType); - } + TypeVar* mty = asMutable(follow(type)); + mty->reassign(*errorRecoveryType(anyType)); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } @@ -1206,7 +1264,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } - if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location)) + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location, /* addErrors= */ true)) { // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments @@ -1253,7 +1311,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (TypeId var : varTypes) unify(varTy, var, forin.location); - if (!get(iterTy) && !get(iterTy) && !get(iterTy)) + if (!get(iterTy) && !get(iterTy) && !get(iterTy) && !get(iterTy)) reportError(firstValue->location, CannotCallNonFunction{iterTy}); return check(loopScope, *forin.body); @@ -1350,7 +1408,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); - if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false)) + if (!getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, /* addErrors= */ false)) { if (ttv || isTableIntersection(exprTy)) reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); @@ -1376,6 +1434,12 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); + if (FFlag::LuauUnknownAndNeverType) + { + InplaceDemoter demoter{funScope->level, ¤tModule->internalTypes}; + demoter.traverse(ty); + } + if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } @@ -1729,7 +1793,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - result = checkExpr(scope, *a); + result = checkExpr(scope, *a, FFlag::LuauBinaryNeedsExpectedTypesToo ? expectedType : std::nullopt); else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1851,41 +1915,56 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp lhsType = stripFromNilAndReport(lhsType, expr.expr->location); - if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) + if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true)) return {*ty}; return {errorRecoveryType(scope)}; } -std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) +std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) { ErrorVec errors; auto result = Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); - reportErrors(errors); + if (!FFlag::LuauIndexSilenceErrors || addErrors) + reportErrors(errors); return result; } -std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location) +std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors) { ErrorVec errors; auto result = Luau::findMetatableEntry(errors, type, entry, location); - reportErrors(errors); + if (!FFlag::LuauIndexSilenceErrors || addErrors) + reportErrors(errors); return result; } std::optional TypeChecker::getIndexTypeFromType( - const ScopePtr& scope, TypeId type, const std::string& name, const Location& location, bool addErrors) + const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) +{ + size_t errorCount = currentModule->errors.size(); + + std::optional result = getIndexTypeFromTypeImpl(scope, type, name, location, addErrors); + + if (FFlag::LuauIndexSilenceErrors && !addErrors) + LUAU_ASSERT(errorCount == currentModule->errors.size()); + + return result; +} + +std::optional TypeChecker::getIndexTypeFromTypeImpl( + const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) { type = follow(type); - if (get(type) || get(type)) + if (get(type) || get(type) || get(type)) return type; tablify(type); if (isString(type)) { - std::optional mtIndex = findMetatableEntry(stringType, "__index", location); + std::optional mtIndex = findMetatableEntry(stringType, "__index", location, addErrors); LUAU_ASSERT(mtIndex); type = *mtIndex; } @@ -1919,7 +1998,7 @@ std::optional TypeChecker::getIndexTypeFromType( return result; } - if (auto found = findTablePropertyRespectingMeta(type, name, location)) + if (auto found = findTablePropertyRespectingMeta(type, name, location, addErrors)) return *found; } else if (const ClassTypeVar* cls = get(type)) @@ -1941,7 +2020,7 @@ std::optional TypeChecker::getIndexTypeFromType( if (get(follow(t))) return t; - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) goodOptions.push_back(*ty); else badOptions.push_back(t); @@ -1972,6 +2051,8 @@ std::optional TypeChecker::getIndexTypeFromType( else { std::vector result = reduceUnion(goodOptions); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + return neverType; if (result.size() == 1) return result[0]; @@ -1987,7 +2068,7 @@ std::optional TypeChecker::getIndexTypeFromType( { RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false)) parts.push_back(*ty); } @@ -2017,6 +2098,9 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) for (TypeId t : types) { t = follow(t); + if (get(t)) + continue; + if (get(t) || get(t)) return {t}; @@ -2028,6 +2112,8 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) { if (FFlag::LuauNormalizeFlagIsConservative) ty = follow(ty); + if (get(ty)) + continue; if (get(ty) || get(ty)) return {ty}; @@ -2041,6 +2127,8 @@ std::vector TypeChecker::reduceUnion(const std::vector& types) for (TypeId ty : r) { ty = follow(ty); + if (get(ty)) + continue; if (get(ty) || get(ty)) return {ty}; @@ -2314,14 +2402,14 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {booleanType, {NotPredicate{std::move(result.predicates)}}}; case AstExprUnary::Minus: { - const bool operandIsAny = get(operandType) || get(operandType); + const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); if (operandIsAny) return {operandType}; if (typeCouldHaveMetatable(operandType)) { - if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location)) + if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location, /* addErrors= */ true)) { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); @@ -2355,14 +2443,14 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp operandType = stripFromNilAndReport(operandType, expr.location); - if (get(operandType)) - return {errorRecoveryType(scope)}; + if (get(operandType) || get(operandType)) + return {!FFlag::LuauUnknownAndNeverType ? errorRecoveryType(scope) : operandType}; DenseHashSet seen{nullptr}; if (FFlag::LuauCheckLenMT && typeCouldHaveMetatable(operandType)) { - if (auto fnt = findMetatableEntry(operandType, "__len", expr.location)) + if (auto fnt = findMetatableEntry(operandType, "__len", expr.location, /* addErrors= */ true)) { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); @@ -2433,6 +2521,9 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b return a; std::vector types = reduceUnion({a, b}); + if (FFlag::LuauUnknownAndNeverType && types.empty()) + return neverType; + if (types.size() == 1) return types[0]; @@ -2485,7 +2576,7 @@ TypeId TypeChecker::checkRelationalOperation( // If we know nothing at all about the lhs type, we can usually say nothing about the result. // The notable exception to this is the equality and inequality operators, which always produce a boolean. - const bool lhsIsAny = get(lhsType) || get(lhsType); + const bool lhsIsAny = get(lhsType) || get(lhsType) || get(lhsType); // Peephole check for `cond and a or b -> type(a)|type(b)` // TODO: Kill this when singleton types arrive. :( @@ -2508,7 +2599,7 @@ TypeId TypeChecker::checkRelationalOperation( if (isNonstrictMode() && (isNil(lhsType) || isNil(rhsType))) return booleanType; - const bool rhsIsAny = get(rhsType) || get(rhsType); + const bool rhsIsAny = get(rhsType) || get(rhsType) || get(rhsType); if (lhsIsAny || rhsIsAny) return booleanType; @@ -2596,7 +2687,7 @@ TypeId TypeChecker::checkRelationalOperation( if (leftMetatable) { - std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); + std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); if (metamethod) { if (const FunctionTypeVar* ftv = get(*metamethod)) @@ -2757,9 +2848,9 @@ TypeId TypeChecker::checkBinaryOperation( }; std::string op = opToMetaTableEntry(expr.op); - if (auto fnt = findMetatableEntry(lhsType, op, expr.location)) + if (auto fnt = findMetatableEntry(lhsType, op, expr.location, /* addErrors= */ true)) return checkMetatableCall(*fnt, lhsType, rhsType); - if (auto fnt = findMetatableEntry(rhsType, op, expr.location)) + if (auto fnt = findMetatableEntry(rhsType, op, expr.location, /* addErrors= */ true)) { // Note the intentionally reversed arguments here. return checkMetatableCall(*fnt, rhsType, lhsType); @@ -2793,27 +2884,27 @@ TypeId TypeChecker::checkBinaryOperation( } } -WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr, std::optional expectedType) { if (expr.op == AstExprBinary::And) { - auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left, expectedType); ScopePtr innerScope = childScope(scope, expr.location); resolve(lhsPredicates, innerScope, true); - auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right, expectedType); return {checkBinaryOperation(scope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) { - auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left, expectedType); ScopePtr innerScope = childScope(scope, expr.location); resolve(lhsPredicates, innerScope, false); - auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right, expectedType); // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. TypeId result = checkBinaryOperation(scope, expr, lhsTy, rhsTy, lhsPredicates); @@ -2824,6 +2915,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; + // For these, passing expectedType is worse than simply forcing them, because their implementation + // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); @@ -2842,6 +2935,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else { + // Expected types are not useful for other binary operators. WithPredicate lhs = checkExpr(scope, *expr.left); WithPredicate rhs = checkExpr(scope, *expr.right); @@ -2896,6 +2990,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {trueType.type}; std::vector types = reduceUnion({trueType.type, falseType.type}); + if (FFlag::LuauUnknownAndNeverType && types.empty()) + return {neverType}; return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; } @@ -2927,7 +3023,10 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& exp TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) { if (std::optional ty = scope->lookup(expr.local)) - return *ty; + { + ty = follow(*ty); + return get(*ty) ? unknownType : *ty; + } reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); return errorRecoveryType(scope); @@ -2941,7 +3040,10 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGloba const auto it = moduleScope->bindings.find(expr.name); if (it != moduleScope->bindings.end()) - return it->second.typeId; + { + TypeId ty = follow(it->second.typeId); + return get(ty) ? unknownType : ty; + } TypeId result = freshType(scope); Binding& binding = moduleScope->bindings[expr.name]; @@ -2962,6 +3064,9 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex if (get(lhs) || get(lhs)) return lhs; + if (get(lhs)) + return unknownType; + tablify(lhs); Name name = expr.index.value; @@ -3023,7 +3128,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (get(lhs)) { - if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, false)) + if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, /* addErrors= */ false)) return *ty; // If intersection has a table part, report that it cannot be extended just as a sealed table @@ -3050,6 +3155,9 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex if (get(exprType) || get(exprType)) return exprType; + if (get(exprType)) + return unknownType; + AstExprConstantString* value = expr.index->as(); if (value) @@ -3156,7 +3264,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T if (!ttv || ttv->state == TableState::Sealed) { - if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, false)) + if (auto ty = getIndexTypeFromType(scope, lhsType, indexName->index.value, indexName->indexLocation, /* addErrors= */ false)) return *ty; return errorRecoveryType(scope); @@ -3228,9 +3336,12 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& } } - // We do not infer type binders, so if a generic function is required we do not propagate - if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) - expectedFunctionType = nullptr; + if (!FFlag::LuauCheckGenericHOFTypes) + { + // We do not infer type binders, so if a generic function is required we do not propagate + if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) + expectedFunctionType = nullptr; + } } auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); @@ -3240,7 +3351,8 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& retPack = resolveTypePack(funScope, *expr.returnAnnotation); else if (FFlag::LuauReturnTypeInferenceInNonstrict ? (!FFlag::LuauLowerBoundsCalculation && isNonstrictMode()) : isNonstrictMode()) retPack = anyTypePack; - else if (expectedFunctionType) + else if (expectedFunctionType && + (!FFlag::LuauCheckGenericHOFTypes || (expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty()))) { auto [head, tail] = flatten(expectedFunctionType->retTypes); @@ -3371,16 +3483,50 @@ std::pair TypeChecker::checkFunctionSignature(const ScopePtr& defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); std::vector genericTys; - genericTys.reserve(generics.size()); - std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { - return el.ty; - }); + // if we have a generic expected function type and no generics, we should use the expected ones. + if (FFlag::LuauCheckGenericHOFTypes) + { + if (expectedFunctionType && generics.empty()) + { + genericTys = expectedFunctionType->generics; + } + else + { + genericTys.reserve(generics.size()); + for (const GenericTypeDefinition& generic : generics) + genericTys.push_back(generic.ty); + } + } + else + { + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + } std::vector genericTps; - genericTps.reserve(genericPacks.size()); - std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { - return el.tp; - }); + // if we have a generic expected function type and no generic typepacks, we should use the expected ones. + if (FFlag::LuauCheckGenericHOFTypes) + { + if (expectedFunctionType && genericPacks.empty()) + { + genericTps = expectedFunctionType->genericPacks; + } + else + { + genericTps.reserve(genericPacks.size()); + for (const GenericTypePackDefinition& generic : genericPacks) + genericTps.push_back(generic.tp); + } + } + else + { + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + } TypeId funTy = addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); @@ -3474,9 +3620,22 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE } WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +{ + if (FFlag::LuauUnknownAndNeverType) + { + WithPredicate result = checkExprPackHelper(scope, expr); + if (containsNever(result.type)) + return {uninhabitableTypePack}; + return result; + } + else + return checkExprPackHelper(scope, expr); +} + +WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) - return checkExprPack(scope, *a); + return checkExprPackHelper(scope, *a); else if (expr.is()) { if (!scope->varargPack) @@ -3739,7 +3898,7 @@ void TypeChecker::checkArgumentList( } } -WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr) { // evaluate type of function // decompose an intersection into its component overloads @@ -3763,7 +3922,7 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons selfType = checkExpr(scope, *indexExpr->expr).type; selfType = stripFromNilAndReport(selfType, expr.func->location); - if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) + if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, /* addErrors= */ true)) { functionType = *propTy; actualFunctionType = instantiate(scope, functionType, expr.func->location); @@ -3813,11 +3972,25 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons if (get(argPack)) return {errorRecoveryTypePack(scope)}; - TypePack* args = getMutable(argPack); - LUAU_ASSERT(args != nullptr); + TypePack* args = nullptr; + if (FFlag::LuauUnknownAndNeverType) + { + if (expr.self) + { + argPack = addTypePack(TypePack{{selfType}, argPack}); + argListResult.type = argPack; + } + args = getMutable(argPack); + LUAU_ASSERT(args); + } + else + { + args = getMutable(argPack); + LUAU_ASSERT(args != nullptr); - if (expr.self) - args->head.insert(args->head.begin(), selfType); + if (expr.self) + args->head.insert(args->head.begin(), selfType); + } std::vector argLocations; argLocations.reserve(expr.args.size + 1); @@ -3876,7 +4049,10 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st else { std::vector result = reduceUnion({*el, ty}); - el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); + if (FFlag::LuauUnknownAndNeverType && result.empty()) + el = neverType; + else + el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); } } }; @@ -3930,6 +4106,9 @@ std::optional> TypeChecker::checkCallOverload(const Sc return {{errorRecoveryTypePack(scope)}}; } + if (get(fn)) + return {{uninhabitableTypePack}}; + if (auto ftv = get(fn)) { // fn is one of the overloads of actualFunctionType, which @@ -3975,7 +4154,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc // Might be a callable table if (const MetatableTypeVar* mttv = get(fn)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) + if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false)) { // Construct arguments with 'self' added in front TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); @@ -4202,6 +4381,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) { + bool uninhabitable = false; TypePackId pack = addTypePack(TypePack{}); PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up? @@ -4232,7 +4412,13 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); - if (std::optional firstTy = first(typePack)) + if (FFlag::LuauUnknownAndNeverType && containsNever(typePack)) + { + // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, ...never) + uninhabitable = true; + continue; + } + else if (std::optional firstTy = first(typePack)) { if (!currentModule->astTypes.find(expr)) currentModule->astTypes[expr] = follow(*firstTy); @@ -4248,6 +4434,13 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); + if (FFlag::LuauUnknownAndNeverType && get(type)) + { + // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, ...never) + uninhabitable = true; + continue; + } + TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; if (instantiateGenerics.size() > i && instantiateGenerics[i]) @@ -4272,6 +4465,8 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons for (TxnLog& log : inverseLogs) log.commit(); + if (FFlag::LuauUnknownAndNeverType && uninhabitable) + return {uninhabitableTypePack}; return {pack, predicates}; } @@ -4830,7 +5025,7 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) }; } -std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) +std::optional TypeChecker::filterMapImpl(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); if (!types.empty()) @@ -4838,7 +5033,21 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } -std::optional TypeChecker::pickTypesFromSense(TypeId type, bool sense) +std::pair, bool> TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) +{ + if (FFlag::LuauUnknownAndNeverType) + { + TypeId ty = filterMapImpl(type, predicate).value_or(neverType); + return {ty, !bool(get(ty))}; + } + else + { + std::optional ty = filterMapImpl(type, predicate); + return {ty, bool(ty)}; + } +} + +std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense) { return filterMap(type, mkTruthyPredicate(sense)); } @@ -4884,6 +5093,13 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level) } TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) +{ + TypeId ty = resolveTypeWorker(scope, annotation); + currentModule->astResolvedTypes[&annotation] = ty; + return ty; +} + +TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& annotation) { if (const auto& lit = annotation.as()) { @@ -5200,9 +5416,10 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypeList TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation) { + TypePackId result; if (const AstTypePackVariadic* variadic = annotation.as()) { - return addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); + result = addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); } else if (const AstTypePackGeneric* generic = annotation.as()) { @@ -5216,10 +5433,12 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return errorRecoveryTypePack(scope); + result = errorRecoveryTypePack(scope); + } + else + { + result = *genericTy; } - - return *genericTy; } else if (const AstTypePackExplicit* explicitTp = annotation.as()) { @@ -5229,14 +5448,17 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack types.push_back(resolveType(scope, *type)); if (auto tailType = explicitTp->typeList.tailType) - return addTypePack(types, resolveTypePack(scope, *tailType)); - - return addTypePack(types); + result = addTypePack(types, resolveTypePack(scope, *tailType)); + else + result = addTypePack(types); } else { ice("Unknown AstTypePack kind"); } + + currentModule->astResolvedTypePacks[&annotation] = result; + return result; } bool ApplyTypeFunction::isDirty(TypeId ty) @@ -5452,10 +5674,18 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const // If we do not have a key, it means we're not trying to discriminate anything, so it's a simple matter of just filtering for a subset. if (!key) { - if (std::optional result = filterMap(*ty, predicate)) + auto [result, ok] = filterMap(*ty, predicate); + if (FFlag::LuauUnknownAndNeverType) + { addRefinement(refis, *target, *result); + } else - addRefinement(refis, *target, errorRecoveryType(scope)); + { + if (ok) + addRefinement(refis, *target, *result); + else + addRefinement(refis, *target, errorRecoveryType(scope)); + } return; } @@ -5471,17 +5701,29 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const { std::optional discriminantTy; if (auto field = Luau::get(*key)) // need to fully qualify Luau::get because of ADL. - discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), false); + discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), /* addErrors= */ false); else LUAU_ASSERT(!"Unhandled LValue alternative?"); if (!discriminantTy) return; // Do nothing. An error was already reported, as per usual. - if (std::optional result = filterMap(*discriminantTy, predicate)) + auto [result, ok] = filterMap(*discriminantTy, predicate); + if (FFlag::LuauUnknownAndNeverType) { - viableTargetOptions.insert(option); - viableChildOptions.insert(*result); + if (!get(*result)) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } + } + else + { + if (ok) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } } } @@ -5560,7 +5802,7 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV continue; else if (auto field = get(key)) { - found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + found = getIndexTypeFromType(scope, *found, field->key, Location(), /* addErrors= */ false); if (!found) return std::nullopt; // Turns out this type doesn't have the property at all. We're done. } @@ -5740,6 +5982,9 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r auto mkFilter = [](ConditionFunc f, std::optional other = std::nullopt) -> SenseToTypeIdPredicate { return [f, other](bool sense) -> TypeIdPredicate { return [f, other, sense](TypeId ty) -> std::optional { + if (FFlag::LuauUnknownAndNeverType && sense && get(ty)) + return other.value_or(ty); + if (f(ty) == sense) return ty; @@ -5847,8 +6092,15 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp for (size_t i = 0; i < expectedLength; ++i) expectedPack->head.push_back(freshType(scope)); + size_t oldErrorsSize = currentModule->errors.size(); + unify(tp, expectedTypePack, location); + // HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but + // we want to tie up free types to be error types, so we do this instead. + if (FFlag::LuauUnknownAndNeverType) + currentModule->errors.resize(oldErrorsSize); + for (TypeId& tp : expectedPack->head) tp = follow(tp); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 82451bd..d454448 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) - namespace Luau { @@ -40,19 +38,10 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) TypePackVar& TypePackVar::operator=(const TypePackVar& rhs) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - LUAU_ASSERT(owningArena == rhs.owningArena); - LUAU_ASSERT(!rhs.persistent); + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); - reassign(rhs); - } - else - { - ty = rhs.ty; - persistent = rhs.persistent; - owningArena = rhs.owningArena; - } + reassign(rhs); return *this; } @@ -294,6 +283,16 @@ std::optional first(TypePackId tp, bool ignoreHiddenVariadics) return std::nullopt; } +TypePackVar* asMutable(TypePackId tp) +{ + return const_cast(tp); +} + +TypePack* asMutable(const TypePack* tp) +{ + return const_cast(tp); +} + bool isEmpty(TypePackId tp) { tp = follow(tp); @@ -360,13 +359,25 @@ bool isVariadic(TypePackId tp, const TxnLog& log) return false; } -TypePackVar* asMutable(TypePackId tp) +bool containsNever(TypePackId tp) { - return const_cast(tp); + auto it = begin(tp); + auto endIt = end(tp); + + while (it != endIt) + { + if (get(follow(*it))) + return true; + ++it; + } + + if (auto tail = it.tail()) + { + if (auto vtp = get(*tail); vtp && get(follow(vtp->ty))) + return true; + } + + return false; } -TypePack* asMutable(const TypePack* tp) -{ - return const_cast(tp); -} } // namespace Luau diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 3d97e6e..66b38cf 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -24,7 +24,7 @@ std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::str const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) { - errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); + errors.push_back(TypeError{location, GenericError{"Metatable was not a table"}}); return std::nullopt; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index ade70d7..f884ad7 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,7 +23,9 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauNonCopyableTypeVarFields) +LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauDeduceGmatchReturnTypes, false) +LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) namespace Luau { @@ -31,6 +33,9 @@ namespace Luau std::optional> magicFunctionFormat( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); + TypeId follow(TypeId t) { return follow(t, [](TypeId t) { @@ -173,8 +178,8 @@ bool maybeString(TypeId ty) { ty = follow(ty); - if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) - return true; + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + return true; if (auto utv = get(ty)) return std::any_of(begin(utv), end(utv), maybeString); @@ -194,7 +199,7 @@ bool isOptional(TypeId ty) ty = follow(ty); - if (get(ty)) + if (get(ty) || (FFlag::LuauUnknownAndNeverType && get(ty))) return true; auto utv = get(ty); @@ -334,6 +339,28 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { + if (FFlag::LuauMaybeGenericIntersectionTypes) + { + ty = follow(ty); + + if (get(ty)) + return true; + + if (auto ttv = get(ty)) + { + // TODO: recurse on table types CLI-39914 + (void)ttv; + return true; + } + + if (auto itv = get(ty)) + { + return std::any_of(begin(itv), end(itv), maybeGeneric); + } + + return isGeneric(ty); + } + ty = follow(ty); if (get(ty)) return true; @@ -646,20 +673,10 @@ TypeVar& TypeVar::operator=(TypeVariant&& rhs) TypeVar& TypeVar::operator=(const TypeVar& rhs) { - if (FFlag::LuauNonCopyableTypeVarFields) - { - LUAU_ASSERT(owningArena == rhs.owningArena); - LUAU_ASSERT(!rhs.persistent); + LUAU_ASSERT(owningArena == rhs.owningArena); + LUAU_ASSERT(!rhs.persistent); - reassign(rhs); - } - else - { - ty = rhs.ty; - persistent = rhs.persistent; - normal = rhs.normal; - owningArena = rhs.owningArena; - } + reassign(rhs); return *this; } @@ -676,10 +693,14 @@ static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persist static TypeVar trueType_{SingletonTypeVar{BooleanSingleton{true}}, /*persistent*/ true}; static TypeVar falseType_{SingletonTypeVar{BooleanSingleton{false}}, /*persistent*/ true}; static TypeVar anyType_{AnyTypeVar{}, /*persistent*/ true}; +static TypeVar unknownType_{UnknownTypeVar{}, /*persistent*/ true}; +static TypeVar neverType_{NeverTypeVar{}, /*persistent*/ true}; static TypeVar errorType_{ErrorTypeVar{}, /*persistent*/ true}; -static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; -static TypePackVar errorTypePack_{Unifiable::Error{}}; +static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, /*persistent*/ true}; +static TypePackVar errorTypePack_{Unifiable::Error{}, /*persistent*/ true}; +static TypePackVar neverTypePack_{VariadicTypePack{&neverType_}, /*persistent*/ true}; +static TypePackVar uninhabitableTypePack_{TypePack{{&neverType_}, &neverTypePack_}, /*persistent*/ true}; SingletonTypes::SingletonTypes() : nilType(&nilType_) @@ -690,7 +711,11 @@ SingletonTypes::SingletonTypes() , trueType(&trueType_) , falseType(&falseType_) , anyType(&anyType_) + , unknownType(&unknownType_) + , neverType(&neverType_) , anyTypePack(&anyTypePack_) + , neverTypePack(&neverTypePack_) + , uninhabitableTypePack(&uninhabitableTypePack_) , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); @@ -738,6 +763,7 @@ TypeId SingletonTypes::makeStringMetatable() const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, @@ -911,6 +937,8 @@ const TypeLevel* getLevel(TypeId ty) return &ttv->level; else if (auto ftv = get(ty)) return &ftv->level; + else if (auto ctv = get(ty)) + return &ctv->level; else return nullptr; } @@ -965,94 +993,19 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) return false; } -UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) +const std::vector& getTypes(const UnionTypeVar* utv) { - LUAU_ASSERT(utv); - - if (!utv->options.empty()) - stack.push_front({utv, 0}); - - seen.insert(utv); + return utv->options; } -UnionTypeVarIterator& UnionTypeVarIterator::operator++() +const std::vector& getTypes(const IntersectionTypeVar* itv) { - advance(); - descend(); - return *this; + return itv->parts; } -UnionTypeVarIterator UnionTypeVarIterator::operator++(int) +const std::vector& getTypes(const ConstrainedTypeVar* ctv) { - UnionTypeVarIterator copy = *this; - ++copy; - return copy; -} - -bool UnionTypeVarIterator::operator!=(const UnionTypeVarIterator& rhs) -{ - return !(*this == rhs); -} - -bool UnionTypeVarIterator::operator==(const UnionTypeVarIterator& rhs) -{ - if (!stack.empty() && !rhs.stack.empty()) - return stack.front() == rhs.stack.front(); - - return stack.empty() && rhs.stack.empty(); -} - -const TypeId& UnionTypeVarIterator::operator*() -{ - LUAU_ASSERT(!stack.empty()); - - descend(); - - auto [utv, currentIndex] = stack.front(); - LUAU_ASSERT(utv); - LUAU_ASSERT(currentIndex < utv->options.size()); - - const TypeId& ty = utv->options[currentIndex]; - LUAU_ASSERT(!get(follow(ty))); - return ty; -} - -void UnionTypeVarIterator::advance() -{ - while (!stack.empty()) - { - auto& [utv, currentIndex] = stack.front(); - ++currentIndex; - - if (currentIndex >= utv->options.size()) - stack.pop_front(); - else - break; - } -} - -void UnionTypeVarIterator::descend() -{ - while (!stack.empty()) - { - auto [utv, currentIndex] = stack.front(); - if (auto innerUnion = get(follow(utv->options[currentIndex]))) - { - // If we're about to descend into a cyclic UnionTypeVar, we should skip over this. - // Ideally this should never happen, but alas it does from time to time. :( - if (seen.find(innerUnion) != seen.end()) - advance(); - else - { - seen.insert(innerUnion); - stack.push_front({innerUnion, 0}); - } - - continue; - } - - break; - } + return ctv->parts; } UnionTypeVarIterator begin(const UnionTypeVar* utv) @@ -1065,6 +1018,27 @@ UnionTypeVarIterator end(const UnionTypeVar* utv) return UnionTypeVarIterator{}; } +IntersectionTypeVarIterator begin(const IntersectionTypeVar* itv) +{ + return IntersectionTypeVarIterator{itv}; +} + +IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) +{ + return IntersectionTypeVarIterator{}; +} + +ConstrainedTypeVarIterator begin(const ConstrainedTypeVar* ctv) +{ + return ConstrainedTypeVarIterator{ctv}; +} + +ConstrainedTypeVarIterator end(const ConstrainedTypeVar* ctv) +{ + return ConstrainedTypeVarIterator{}; +} + + static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs"; @@ -1144,6 +1118,101 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } +static std::vector parsePatternString(TypeChecker& typechecker, const char* data, size_t size) +{ + std::vector result; + int depth = 0; + bool parsingSet = false; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + ++i; + if (!parsingSet && i < size && data[i] == 'b') + i += 2; + } + else if (!parsingSet && data[i] == '[') + { + parsingSet = true; + if (i + 1 < size && data[i + 1] == ']') + i += 1; + } + else if (parsingSet && data[i] == ']') + { + parsingSet = false; + } + else if (data[i] == '(') + { + if (parsingSet) + continue; + + if (i + 1 < size && data[i + 1] == ')') + { + i++; + result.push_back(typechecker.numberType); + continue; + } + + ++depth; + result.push_back(typechecker.stringType); + } + else if (data[i] == ')') + { + if (parsingSet) + continue; + + --depth; + + if (depth < 0) + break; + } + } + + if (depth != 0 || parsingSet) + return std::vector(); + + if (result.empty()) + result.push_back(typechecker.stringType); + + return result; +} + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + if (!FFlag::LuauDeduceGmatchReturnTypes) + return std::nullopt; + + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() != 2) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t index = expr.self ? 0 : 1; + if (expr.args.size > index) + pattern = expr.args.data[index]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); + + const TypePackId emptyPack = arena.addTypePack({}); + const TypePackId returnList = arena.addTypePack(returnTypes); + const TypeId iteratorType = arena.addType(FunctionTypeVar{emptyPack, returnList}); + return WithPredicate{arena.addTypePack({iteratorType})}; +} + std::vector filterMap(TypeId type, TypeIdPredicate predicate) { type = follow(type); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 0792a35..44a3b85 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -47,33 +48,6 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor } } - // TODO cycle and operator() need to be clipped when FFlagLuauUseVisitRecursionLimit is clipped - template - void cycle(TID) - { - } - template - bool operator()(TID ty, const T&) - { - return visit(ty); - } - bool operator()(TypeId ty, const FreeTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const FunctionTypeVar& ftv) - { - return visit(ty, ftv); - } - bool operator()(TypeId ty, const TableTypeVar& ttv) - { - return visit(ty, ttv); - } - bool operator()(TypePackId tp, const FreeTypePack& ftp) - { - return visit(tp, ftp); - } - bool visit(TypeId ty) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -103,6 +77,15 @@ struct PromoteTypeLevels final : TypeVarOnceVisitor return true; } + bool visit(TypeId ty, const ConstrainedTypeVar&) override + { + if (!FFlag::LuauUnknownAndNeverType) + return visit(ty); + + promote(ty, log.getMutable(ty)); + return true; + } + bool visit(TypeId ty, const FunctionTypeVar&) override { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -445,6 +428,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } else if (subFree) { + if (FFlag::LuauUnknownAndNeverType) + { + // Normally, if the subtype is free, it should not be bound to any, unknown, or error types. + // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. + if (get(superTy)) + return; + } + TypeLevel subLevel = subFree->level; occursCheck(subTy, superTy); @@ -468,7 +459,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool return; } - if (get(superTy) || get(superTy)) + if (get(superTy) || get(superTy) || get(superTy)) return tryUnifyWithAny(subTy, superTy); if (get(subTy)) @@ -485,6 +476,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(subTy)) return tryUnifyWithAny(superTy, subTy); + if (get(subTy)) + return tryUnifyWithAny(superTy, subTy); + auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before @@ -1862,6 +1856,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas if (state.log.getMutable(ty)) { + // TODO: Only bind if the anyType isn't any, unknown, or error (?) state.log.replace(ty, BoundTypeVar{anyType}); } else if (auto fun = state.log.getMutable(ty)) @@ -1901,22 +1896,27 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { - LUAU_ASSERT(get(anyTy) || get(anyTy)); + LUAU_ASSERT(get(anyTy) || get(anyTy) || get(anyTy) || get(anyTy)); // These types are not visited in general loop below if (get(subTy) || get(subTy) || get(subTy)) return; - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); - - const TypePackId anyTP = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId anyTp; + if (FFlag::LuauUnknownAndNeverType) + anyTp = types->addTypePack(TypePackVar{VariadicTypePack{anyTy}}); + else + { + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); + anyTp = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + } std::vector queue = {subTy}; sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, getSingletonTypes().anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, FFlag::LuauUnknownAndNeverType ? anyTy : getSingletonTypes().anyType, anyTp); } void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 70c9255..b7fa788 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -15,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) -LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) @@ -1134,10 +1133,9 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() { - if (options.allowTypeAnnotations && - (lexer.current().type == ':' || (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow))) + if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) { - if (FFlag::LuauReturnTypeTokenConfusion && lexer.current().type == Lexeme::SkinnyArrow) + if (lexer.current().type == Lexeme::SkinnyArrow) report(lexer.current().location, "Function return type annotations are written after ':' instead of '->'"); nextLexeme(); @@ -1373,12 +1371,10 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) if (FFlag::LuauFixNamedFunctionParse && !names.empty()) forceFunctionType = true; - bool returnTypeIntroducer = - FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; + bool returnTypeIntroducer = lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':'; // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && !forceFunctionType && - (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) + if (params.size() == 1 && !varargAnnotation && !forceFunctionType && !returnTypeIntroducer) { if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) lua_telemetry_parsed_named_non_function_type = true; @@ -1389,8 +1385,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {params[0], {}}; } - if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && !forceFunctionType && - allowPack) + if (!forceFunctionType && !returnTypeIntroducer && allowPack) { if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) lua_telemetry_parsed_named_non_function_type = true; @@ -1409,7 +1404,7 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray' instead of ':'"); lexer.next(); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 4bc8cab..479eb16 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -8,6 +8,10 @@ #include "FileUtils.h" +#ifdef CALLGRIND +#include +#endif + LUAU_FASTFLAG(DebugLuauTimeTracing) LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) @@ -112,6 +116,7 @@ static void displayHelp(const char* argv0) printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --mode=strict: default to strict mode when typechecking\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -178,9 +183,9 @@ struct CliConfigResolver : Luau::ConfigResolver mutable std::unordered_map configCache; mutable std::vector> configErrors; - CliConfigResolver() + CliConfigResolver(Luau::Mode mode) { - defaultConfig.mode = Luau::Mode::Nonstrict; + defaultConfig.mode = mode; } const Luau::Config& getConfig(const Luau::ModuleName& name) const override @@ -229,6 +234,7 @@ int main(int argc, char** argv) } ReportFormat format = ReportFormat::Default; + Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; for (int i = 1; i < argc; ++i) @@ -240,6 +246,8 @@ int main(int argc, char** argv) format = ReportFormat::Luacheck; else if (strcmp(argv[i], "--formatter=gnu") == 0) format = ReportFormat::Gnu; + else if (strcmp(argv[i], "--mode=strict") == 0) + mode = Luau::Mode::Strict; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; else if (strcmp(argv[i], "--timetrace") == 0) @@ -258,12 +266,16 @@ int main(int argc, char** argv) frontendOptions.retainFullTypeGraphs = annotate; CliFileResolver fileResolver; - CliConfigResolver configResolver; + CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); +#ifdef CALLGRIND + CALLGRIND_ZERO_STATS; +#endif + std::vector files = getSourceFiles(argc, argv); int failed = 0; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 83060f5..5fe12be 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -21,6 +21,10 @@ #include #endif +#ifdef CALLGRIND +#include +#endif + #include LUAU_FASTFLAG(DebugLuauTimeTracing) @@ -166,6 +170,36 @@ static int lua_collectgarbage(lua_State* L) luaL_error(L, "collectgarbage must be called with 'count' or 'collect'"); } +#ifdef CALLGRIND +static int lua_callgrind(lua_State* L) +{ + const char* option = luaL_checkstring(L, 1); + + if (strcmp(option, "running") == 0) + { + int r = RUNNING_ON_VALGRIND; + lua_pushboolean(L, r); + return 1; + } + + if (strcmp(option, "zero") == 0) + { + CALLGRIND_ZERO_STATS; + return 0; + } + + if (strcmp(option, "dump") == 0) + { + const char* name = luaL_checkstring(L, 2); + + CALLGRIND_DUMP_STATS_AT(name); + return 0; + } + + luaL_error(L, "callgrind must be called with one of 'running', 'zero', 'dump'"); +} +#endif + void setupState(lua_State* L) { luaL_openlibs(L); @@ -174,6 +208,9 @@ void setupState(lua_State* L) {"loadstring", lua_loadstring}, {"require", lua_require}, {"collectgarbage", lua_collectgarbage}, +#ifdef CALLGRIND + {"callgrind", lua_callgrind}, +#endif {NULL, NULL}, }; diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index c5979d3..65883b4 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -58,6 +58,9 @@ public: void jmp(Label& label); void jmp(OperandX64 op); + void call(Label& label); + void call(OperandX64 op); + // AVX void vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vaddps(OperandX64 dst, OperandX64 src1, OperandX64 src2); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 27e0178..2634722 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -286,11 +286,34 @@ void AssemblyBuilderX64::jmp(OperandX64 op) if (logText) log("jmp", op); + placeRex(op); place(0xff); placeModRegMem(op, 4); commit(); } +void AssemblyBuilderX64::call(Label& label) +{ + place(0xe8); + placeLabel(label); + + if (logText) + log("call", label); + + commit(); +} + +void AssemblyBuilderX64::call(OperandX64 op) +{ + if (logText) + log("call", op); + + placeRex(op); + place(0xff); + placeModRegMem(op, 2); + commit(); +} + void AssemblyBuilderX64::vaddpd(OperandX64 dst, OperandX64 src1, OperandX64 src2) { placeAvx("vaddpd", dst, src1, src2, 0x58, false, AVX_0F, AVX_66); diff --git a/Makefile b/Makefile index d4ef7c7..a0a5220 100644 --- a/Makefile +++ b/Makefile @@ -93,6 +93,10 @@ ifeq ($(config),fuzz) LDFLAGS+=-fsanitize=address,fuzzer endif +ifneq ($(CALLGRIND),) + CXXFLAGS+=-DCALLGRIND=$(CALLGRIND) +endif + # target-specific flags $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include diff --git a/Sources.cmake b/Sources.cmake index f261cba..44bed8f 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -247,12 +247,15 @@ if(TARGET Luau.UnitTest) tests/IostreamOptional.h tests/ScopedFlags.h tests/Fixture.cpp + tests/AssemblyBuilderX64.test.cpp tests/AstQuery.test.cpp tests/AstVisitor.test.cpp tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp + tests/ConstraintGraphBuilder.test.cpp + tests/ConstraintSolver.test.cpp tests/CostModel.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp @@ -262,8 +265,7 @@ if(TARGET Luau.UnitTest) tests/Module.test.cpp tests/NonstrictMode.test.cpp tests/Normalize.test.cpp - tests/ConstraintGraphBuilder.test.cpp - tests/ConstraintSolver.test.cpp + tests/NotNull.test.cpp tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/RuntimeLimits.test.cpp @@ -295,11 +297,11 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.tryUnify.test.cpp tests/TypeInfer.typePacks.cpp tests/TypeInfer.unionTypes.test.cpp + tests/TypeInfer.unknownnever.test.cpp tests/TypePack.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp tests/VisitTypeVar.test.cpp - tests/AssemblyBuilderX64.test.cpp tests/main.cpp) endif() diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 2316cc3..79e6591 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -108,9 +108,9 @@ static LuaNode* hashvec(const Table* t, const float* v) memcpy(i, v, sizeof(i)); // convert -0 to 0 to make sure they hash to the same value - i[0] = (i[0] == 0x8000000) ? 0 : i[0]; - i[1] = (i[1] == 0x8000000) ? 0 : i[1]; - i[2] = (i[2] == 0x8000000) ? 0 : i[2]; + i[0] = (i[0] == 0x80000000) ? 0 : i[0]; + i[1] = (i[1] == 0x80000000) ? 0 : i[1]; + i[2] = (i[2] == 0x80000000) ? 0 : i[2]; // scramble bits to make sure that integer coordinates have entropy in lower bits i[0] ^= i[0] >> 17; @@ -121,7 +121,7 @@ static LuaNode* hashvec(const Table* t, const float* v) unsigned int h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791); #if LUA_VECTOR_SIZE == 4 - i[3] = (i[3] == 0x8000000) ? 0 : i[3]; + i[3] = (i[3] == 0x80000000) ? 0 : i[3]; i[3] ^= i[3] >> 17; h ^= i[3] * 39916801; #endif diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 85829ca..02b3931 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -640,20 +640,16 @@ static void luau_execute(lua_State* L) VM_PATCH_C(pc - 2, L->cachedslot); VM_NEXT(); } - else - { - // slow-path, may invoke Lua calls via __index metamethod - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - VM_NEXT(); - } - } - else - { - // slow-path, may invoke Lua calls via __index metamethod - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - VM_NEXT(); + + // fall through to slow path } + + // fall through to slow path } + + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + VM_NEXT(); } VM_CASE(LOP_SETTABLEKS) @@ -753,19 +749,13 @@ static void luau_execute(lua_State* L) setobj2s(L, ra, &h->array[unsigned(index - 1)]); VM_NEXT(); } - else - { - // slow-path: handles out of bounds array lookups and non-integer numeric keys - VM_PROTECT(luaV_gettable(L, rb, rc, ra)); - VM_NEXT(); - } - } - else - { - // slow-path: handles non-array table lookup as well as __index MT calls - VM_PROTECT(luaV_gettable(L, rb, rc, ra)); - VM_NEXT(); + + // fall through to slow path } + + // slow-path: handles out of bounds array lookups, non-integer numeric keys, non-array table lookup, __index MT calls + VM_PROTECT(luaV_gettable(L, rb, rc, ra)); + VM_NEXT(); } VM_CASE(LOP_SETTABLE) @@ -790,19 +780,13 @@ static void luau_execute(lua_State* L) luaC_barriert(L, h, ra); VM_NEXT(); } - else - { - // slow-path: handles out of bounds array assignments and non-integer numeric keys - VM_PROTECT(luaV_settable(L, rb, rc, ra)); - VM_NEXT(); - } - } - else - { - // slow-path: handles non-array table access as well as __newindex MT calls - VM_PROTECT(luaV_settable(L, rb, rc, ra)); - VM_NEXT(); + + // fall through to slow path } + + // slow-path: handles out of bounds array assignments, non-integer numeric keys, non-array table access, __newindex MT calls + VM_PROTECT(luaV_settable(L, rb, rc, ra)); + VM_NEXT(); } VM_CASE(LOP_GETTABLEN) @@ -822,6 +806,8 @@ static void luau_execute(lua_State* L) setobj2s(L, ra, &h->array[c]); VM_NEXT(); } + + // fall through to slow path } // slow-path: handles out of bounds array lookups @@ -849,6 +835,8 @@ static void luau_execute(lua_State* L) luaC_barriert(L, h, ra); VM_NEXT(); } + + // fall through to slow path } // slow-path: handles out of bounds array lookups @@ -2176,8 +2164,10 @@ static void luau_execute(lua_State* L) if (!ttisnumber(ra + 0) || !ttisnumber(ra + 1) || !ttisnumber(ra + 2)) { // slow-path: can convert arguments to numbers and trigger Lua errors - // Note: this doesn't reallocate stack so we don't need to recompute ra - VM_PROTECT(luau_prepareFORN(L, ra + 0, ra + 1, ra + 2)); + // Note: this doesn't reallocate stack so we don't need to recompute ra/base + VM_PROTECT_PC(); + + luau_prepareFORN(L, ra + 0, ra + 1, ra + 2); } double limit = nvalue(ra + 0); diff --git a/bench/bench.py b/bench/bench.py index 67fc8cf..bb3ea5f 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -40,6 +40,7 @@ argumentParser.add_argument('--results', dest='results',type=str,nargs='*',help= argumentParser.add_argument('--run-test', action='store', default=None, help='Regex test filter') argumentParser.add_argument('--extra-loops', action='store',type=int,default=0, help='Amount of times to loop over one test (one test already performs multiple runs)') argumentParser.add_argument('--filename', action='store',type=str,default='bench', help='File name for graph and results file') +argumentParser.add_argument('--callgrind', dest='callgrind',action='store_const',const=1,default=0,help='Use callgrind to run benchmarks') if matplotlib != None: argumentParser.add_argument('--absolute', dest='absolute',action='store_const',const=1,default=0,help='Display absolute values instead of relative (enabled by default when benchmarking a single VM)') @@ -55,6 +56,9 @@ argumentParser.add_argument('--no-print-influx-debugging', action='store_false', argumentParser.add_argument('--no-print-final-summary', action='store_false', dest='print_final_summary', help="Don't print a table summarizing the results after all tests are run") +# Assume 2.5 IPC on a 4 GHz CPU; this is obviously incorrect but it allows us to display simulated instruction counts using regular time units +CALLGRIND_INSN_PER_SEC = 2.5 * 4e9 + def arrayRange(count): result = [] @@ -71,6 +75,21 @@ def arrayRangeOffset(count, offset): return result +def getCallgrindOutput(lines): + result = [] + name = None + + for l in lines: + if l.startswith("desc: Trigger: Client Request: "): + name = l[31:].strip() + elif l.startswith("summary: ") and name != None: + insn = int(l[9:]) + # Note: we only run each bench once under callgrind so we only report a single time per run; callgrind instruction count variance is ~0.01% so it might as well be zero + result += "|><|" + name + "|><|" + str(insn / CALLGRIND_INSN_PER_SEC * 1000.0) + "||_||" + name = None + + return "".join(result) + def getVmOutput(cmd): if os.name == "nt": try: @@ -79,6 +98,16 @@ def getVmOutput(cmd): exit(1) except: return "" + elif arguments.callgrind: + try: + subprocess.check_call("valgrind --tool=callgrind --callgrind-out-file=callgrind.out --combine-dumps=yes --dump-line=no " + cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, cwd=scriptdir) + path = os.path.join(scriptdir, "callgrind.out") + with open(path, "r") as file: + lines = file.readlines() + os.unlink(path) + return getCallgrindOutput(lines) + except: + return "" else: with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=scriptdir) as p: # Try to lock to a single processor @@ -375,12 +404,12 @@ def analyzeResult(subdir, main, comparisons): continue - pooledStdDev = math.sqrt((main.unbiasedEst + compare.unbiasedEst) / 2) + if main.count > 1 and stats: + pooledStdDev = math.sqrt((main.unbiasedEst + compare.unbiasedEst) / 2) - tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) - degreesOfFreedom = 2 * main.count - 2 + tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) + degreesOfFreedom = 2 * main.count - 2 - if stats: # Two-tailed distribution with 95% conf. tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) diff --git a/bench/bench_support.lua b/bench/bench_support.lua index 171b8da..a9608ec 100644 --- a/bench/bench_support.lua +++ b/bench/bench_support.lua @@ -5,6 +5,16 @@ bench.runs = 20 bench.extraRuns = 4 function bench.runCode(f, description) + -- Under Callgrind, run the test only once and measure just the execution cost + if callgrind and callgrind("running") then + if collectgarbage then collectgarbage() end + + callgrind("zero") + f() -- unfortunately we can't easily separate setup cost from runtime cost in f unless it calls callgrind() + callgrind("dump", description) + return + end + local timeTable = {} for i = 1,bench.runs + bench.extraRuns do diff --git a/bench/other/LuauPolyfillMap.lua b/bench/other/LuauPolyfillMap.lua new file mode 100644 index 0000000..1f957d4 --- /dev/null +++ b/bench/other/LuauPolyfillMap.lua @@ -0,0 +1,961 @@ +-- This file is part of the Roblox luau-polyfill repository and is licensed under MIT License; see LICENSE.txt for details +-- #region Array +-- Array related +local Array = {} +local Object = {} +local Map = {} + +type Array = { [number]: T } +type callbackFn = (element: V, key: K, map: Map) -> () +type callbackFnWithThisArg = (thisArg: Object, value: V, key: K, map: Map) -> () +type Map = { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + forEach: (self: Map, callback: callbackFn | callbackFnWithThisArg, thisArg: Object?) -> (), + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + [K]: V, + _map: { [K]: V }, + _array: { [number]: K }, +} +type mapFn = (element: T, index: number) -> U +type mapFnWithThisArg = (thisArg: any, element: T, index: number) -> U +type Object = { [string]: any } +type Table = { [T]: V } +type Tuple = Array + +local Set = {} + +-- #region Array +function Array.isArray(value: any): boolean + if typeof(value) ~= "table" then + return false + end + if next(value) == nil then + -- an empty table is an empty array + return true + end + + local length = #value + + if length == 0 then + return false + end + + local count = 0 + local sum = 0 + for key in pairs(value) do + if typeof(key) ~= "number" then + return false + end + if key % 1 ~= 0 or key < 1 then + return false + end + count += 1 + sum += key + end + + return sum == (count * (count + 1) / 2) +end + +function Array.from( + value: string | Array | Object, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? +): Array + if value == nil then + error("cannot create array from a nil value") + end + local valueType = typeof(value) + + local array = {} + + if valueType == "table" and Array.isArray(value) then + if mapFn then + for i = 1, #(value :: Array) do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, (value :: Array)[i], i) + else + array[i] = (mapFn :: mapFn)((value :: Array)[i], i) + end + end + else + for i = 1, #(value :: Array) do + array[i] = (value :: Array)[i] + end + end + elseif instanceOf(value, Set) then + if mapFn then + for i, v in (value :: any):ipairs() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, v, i) + else + array[i] = (mapFn :: mapFn)(v, i) + end + end + else + for i, v in (value :: any):ipairs() do + array[i] = v + end + end + elseif instanceOf(value, Map) then + if mapFn then + for i, v in (value :: any):ipairs() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, v, i) + else + array[i] = (mapFn :: mapFn)(v, i) + end + end + else + for i, v in (value :: any):ipairs() do + array[i] = v + end + end + elseif valueType == "string" then + if mapFn then + for i = 1, (value :: string):len() do + if thisArg ~= nil then + array[i] = (mapFn :: mapFnWithThisArg)(thisArg, (value :: any):sub(i, i), i) + else + array[i] = (mapFn :: mapFn)((value :: any):sub(i, i), i) + end + end + else + for i = 1, (value :: string):len() do + array[i] = (value :: any):sub(i, i) + end + end + end + + return array +end + +type callbackFnArrayMap = (element: T, index: number, array: Array) -> U +type callbackFnWithThisArgArrayMap = (thisArg: V, element: T, index: number, array: Array) -> U + +-- Implements Javascript's `Array.prototype.map` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/map +function Array.map( + t: Array, + callback: callbackFnArrayMap | callbackFnWithThisArgArrayMap, + thisArg: V? +): Array + if typeof(t) ~= "table" then + error(string.format("Array.map called on %s", typeof(t))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local len = #t + local A = {} + local k = 1 + + while k <= len do + local kValue = t[k] + + if kValue ~= nil then + local mappedValue + + if thisArg ~= nil then + mappedValue = (callback :: callbackFnWithThisArgArrayMap)(thisArg, kValue, k, t) + else + mappedValue = (callback :: callbackFnArrayMap)(kValue, k, t) + end + + A[k] = mappedValue + end + k += 1 + end + + return A +end + +type Function = (any, any, number, any) -> any + +-- Implements Javascript's `Array.prototype.reduce` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/reduce +function Array.reduce(array: Array, callback: Function, initialValue: any?): any + if typeof(array) ~= "table" then + error(string.format("Array.reduce called on %s", typeof(array))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local length = #array + + local value + local initial = 1 + + if initialValue ~= nil then + value = initialValue + else + initial = 2 + if length == 0 then + error("reduce of empty array with no initial value") + end + value = array[1] + end + + for i = initial, length do + value = callback(value, array[i], i, array) + end + + return value +end + +type callbackFnArrayForEach = (element: T, index: number, array: Array) -> () +type callbackFnWithThisArgArrayForEach = (thisArg: U, element: T, index: number, array: Array) -> () + +-- Implements Javascript's `Array.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/forEach +function Array.forEach( + t: Array, + callback: callbackFnArrayForEach | callbackFnWithThisArgArrayForEach, + thisArg: U? +): () + if typeof(t) ~= "table" then + error(string.format("Array.forEach called on %s", typeof(t))) + end + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + local len = #t + local k = 1 + + while k <= len do + local kValue = t[k] + + if thisArg ~= nil then + (callback :: callbackFnWithThisArgArrayForEach)(thisArg, kValue, k, t) + else + (callback :: callbackFnArrayForEach)(kValue, k, t) + end + + if #t < len then + -- don't iterate on removed items, don't iterate more than original length + len = #t + end + k += 1 + end +end +-- #endregion + +-- #region Set +Set.__index = Set + +type callbackFnSet = (value: T, key: T, set: Set) -> () +type callbackFnWithThisArgSet = (thisArg: Object, value: T, key: T, set: Set) -> () + +export type Set = { + size: number, + -- method definitions + add: (self: Set, T) -> Set, + clear: (self: Set) -> (), + delete: (self: Set, T) -> boolean, + forEach: (self: Set, callback: callbackFnSet | callbackFnWithThisArgSet, thisArg: Object?) -> (), + has: (self: Set, T) -> boolean, + ipairs: (self: Set) -> any, +} + +type Iterable = { ipairs: (any) -> any } + +function Set.new(iterable: Array | Set | Iterable | string | nil): Set + local array = {} + local map = {} + if iterable ~= nil then + local arrayIterable: Array + -- ROBLOX TODO: remove type casting from (iterable :: any).ipairs in next release + if typeof(iterable) == "table" then + if Array.isArray(iterable) then + arrayIterable = Array.from(iterable :: Array) + elseif typeof((iterable :: Iterable).ipairs) == "function" then + -- handle in loop below + elseif _G.__DEV__ then + error("cannot create array from an object-like table") + end + elseif typeof(iterable) == "string" then + arrayIterable = Array.from(iterable :: string) + else + error(("cannot create array from value of type `%s`"):format(typeof(iterable))) + end + + if arrayIterable then + for _, element in ipairs(arrayIterable) do + if not map[element] then + map[element] = true + table.insert(array, element) + end + end + elseif typeof(iterable) == "table" and typeof((iterable :: Iterable).ipairs) == "function" then + for _, element in (iterable :: Iterable):ipairs() do + if not map[element] then + map[element] = true + table.insert(array, element) + end + end + end + end + + return (setmetatable({ + size = #array, + _map = map, + _array = array, + }, Set) :: any) :: Set +end + +function Set:add(value) + if not self._map[value] then + -- Luau FIXME: analyze should know self is Set which includes size as a number + self.size = self.size :: number + 1 + self._map[value] = true + table.insert(self._array, value) + end + return self +end + +function Set:clear() + self.size = 0 + table.clear(self._map) + table.clear(self._array) +end + +function Set:delete(value): boolean + if not self._map[value] then + return false + end + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number - 1 + self._map[value] = nil + local index = table.find(self._array, value) + if index then + table.remove(self._array, index) + end + return true +end + +-- Implements Javascript's `Map.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set/forEach +function Set:forEach(callback: callbackFnSet | callbackFnWithThisArgSet, thisArg: Object?): () + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + return Array.forEach(self._array, function(value: T) + if thisArg ~= nil then + (callback :: callbackFnWithThisArgSet)(thisArg, value, value, self) + else + (callback :: callbackFnSet)(value, value, self) + end + end) +end + +function Set:has(value): boolean + return self._map[value] ~= nil +end + +function Set:ipairs() + return ipairs(self._array) +end + +-- #endregion Set + +-- #region Object +function Object.entries(value: string | Object | Array): Array + assert(value :: any ~= nil, "cannot get entries from a nil value") + local valueType = typeof(value) + + local entries: Array> = {} + if valueType == "table" then + for key, keyValue in pairs(value :: Object) do + -- Luau FIXME: Luau should see entries as Array, given object is [string]: any, but it sees it as Array> despite all the manual annotation + table.insert(entries, { key :: string, keyValue :: any }) + end + elseif valueType == "string" then + for i = 1, string.len(value :: string) do + entries[i] = { tostring(i), string.sub(value :: string, i, i) } + end + end + + return entries +end + +-- #endregion + +-- #region instanceOf + +-- ROBLOX note: Typed tbl as any to work with strict type analyze +-- polyfill for https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/instanceof +function instanceOf(tbl: any, class) + assert(typeof(class) == "table", "Received a non-table as the second argument for instanceof") + + if typeof(tbl) ~= "table" then + return false + end + + local ok, hasNew = pcall(function() + return class.new ~= nil and tbl.new == class.new + end) + if ok and hasNew then + return true + end + + local seen = { tbl = true } + + while tbl and typeof(tbl) == "table" do + tbl = getmetatable(tbl) + if typeof(tbl) == "table" then + tbl = tbl.__index + + if tbl == class then + return true + end + end + + -- if we still have a valid table then check against seen + if typeof(tbl) == "table" then + if seen[tbl] then + return false + end + seen[tbl] = true + end + end + + return false +end +-- #endregion + +function Map.new(iterable: Array>?): Map + local array = {} + local map = {} + if iterable ~= nil then + local arrayFromIterable + local iterableType = typeof(iterable) + if iterableType == "table" then + if #iterable > 0 and typeof(iterable[1]) ~= "table" then + error("cannot create Map from {K, V} form, it must be { {K, V}... }") + end + + arrayFromIterable = Array.from(iterable) + else + error(("cannot create array from value of type `%s`"):format(iterableType)) + end + + for _, entry in ipairs(arrayFromIterable) do + local key = entry[1] + if _G.__DEV__ then + if key == nil then + error("cannot create Map from a table that isn't an array.") + end + end + local val = entry[2] + -- only add to array if new + if map[key] == nil then + table.insert(array, key) + end + -- always assign + map[key] = val + end + end + + return (setmetatable({ + size = #array, + _map = map, + _array = array, + }, Map) :: any) :: Map +end + +function Map:set(key: K, value: V): Map + -- preserve initial insertion order + if self._map[key] == nil then + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number + 1 + table.insert(self._array, key) + end + -- always update value + self._map[key] = value + return self +end + +function Map:get(key) + return self._map[key] +end + +function Map:clear() + local table_: any = table + self.size = 0 + table_.clear(self._map) + table_.clear(self._array) +end + +function Map:delete(key): boolean + if self._map[key] == nil then + return false + end + -- Luau FIXME: analyze should know self is Map which includes size as a number + self.size = self.size :: number - 1 + self._map[key] = nil + local index = table.find(self._array, key) + if index then + table.remove(self._array, index) + end + return true +end + +-- Implements Javascript's `Map.prototype.forEach` as defined below +-- https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Map/forEach +function Map:forEach(callback: callbackFn | callbackFnWithThisArg, thisArg: Object?): () + if typeof(callback) ~= "function" then + error("callback is not a function") + end + + return Array.forEach(self._array, function(key: K) + local value: V = self._map[key] :: V + + if thisArg ~= nil then + (callback :: callbackFnWithThisArg)(thisArg, value, key, self) + else + (callback :: callbackFn)(value, key, self) + end + end) +end + +function Map:has(key): boolean + return self._map[key] ~= nil +end + +function Map:keys() + return self._array +end + +function Map:values() + return Array.map(self._array, function(key) + return self._map[key] + end) +end + +function Map:entries() + return Array.map(self._array, function(key) + return { key, self._map[key] } + end) +end + +function Map:ipairs() + return ipairs(self:entries()) +end + +function Map.__index(self, key) + local mapProp = rawget(Map, key) + if mapProp ~= nil then + return mapProp + end + + return Map.get(self, key) +end + +function Map.__newindex(table_, key, value) + table_:set(key, value) +end + +local function coerceToMap(mapLike: Map | Table): Map + return instanceOf(mapLike, Map) and mapLike :: Map -- ROBLOX: order is preservered + or Map.new(Object.entries(mapLike)) -- ROBLOX: order is not preserved +end + +-- local function coerceToTable(mapLike: Map | Table): Table +-- if not instanceOf(mapLike, Map) then +-- return mapLike +-- end + +-- -- create table from map +-- return Array.reduce(mapLike:entries(), function(tbl, entry) +-- tbl[entry[1]] = entry[2] +-- return tbl +-- end, {}) +-- end + +-- #region Tests to verify it works as expected +local function it(description: string, fn: () -> ()) + local ok, result = pcall(fn) + + if not ok then + error("Failed test: " .. description .. "\n" .. result) + end +end + +local AN_ITEM = "bar" +local ANOTHER_ITEM = "baz" + +-- #region [Describe] "Map" +-- #region [Child Describe] "constructors" +it("creates an empty array", function() + local foo = Map.new() + assert(foo.size == 0) +end) + +it("creates a Map from an array", function() + local foo = Map.new({ + { AN_ITEM, "foo" }, + { ANOTHER_ITEM, "val" }, + }) + assert(foo.size == 2) + assert(foo:has(AN_ITEM) == true) + assert(foo:has(ANOTHER_ITEM) == true) +end) + +it("creates a Map from an array with duplicate keys", function() + local foo = Map.new({ + { AN_ITEM, "foo1" }, + { AN_ITEM, "foo2" }, + }) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == "foo2") + + assert(#foo:keys() == 1 and foo:keys()[1] == AN_ITEM) + assert(#foo:values() == 1 and foo:values()[1] == "foo2") + assert(#foo:entries() == 1) + assert(#foo:entries()[1] == 2) + + assert(foo:entries()[1][1] == AN_ITEM) + assert(foo:entries()[1][2] == "foo2") +end) + +it("preserves the order of keys first assignment", function() + local foo = Map.new({ + { AN_ITEM, "foo1" }, + { ANOTHER_ITEM, "bar" }, + { AN_ITEM, "foo2" }, + }) + assert(foo.size == 2) + assert(foo:get(AN_ITEM) == "foo2") + assert(foo:get(ANOTHER_ITEM) == "bar") + + assert(foo:keys()[1] == AN_ITEM) + assert(foo:keys()[2] == ANOTHER_ITEM) + assert(foo:values()[1] == "foo2") + assert(foo:values()[2] == "bar") + assert(foo:entries()[1][1] == AN_ITEM) + assert(foo:entries()[1][2] == "foo2") + assert(foo:entries()[2][1] == ANOTHER_ITEM) + assert(foo:entries()[2][2] == "bar") +end) +-- #endregion + +-- #region [Child Describe] "type" +it("instanceOf return true for an actual Map object", function() + local foo = Map.new() + assert(instanceOf(foo, Map) == true) +end) + +it("instanceOf return false for an regular plain object", function() + local foo = {} + assert(instanceOf(foo, Map) == false) +end) +-- #endregion + +-- #region [Child Describe] "set" +it("returns the Map object", function() + local foo = Map.new() + assert(foo:set(1, "baz") == foo) +end) + +it("increments the size if the element is added for the first time", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo.size == 1) +end) + +it("does not increment the size the second time an element is added", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(AN_ITEM, "val") + assert(foo.size == 1) +end) + +it("sets values correctly to true/false", function() + -- Luau FIXME: Luau insists that arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + foo:set(AN_ITEM, false) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == false) + + foo:set(AN_ITEM, true) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == true) + + foo:set(AN_ITEM, false) + assert(foo.size == 1) + assert(foo:get(AN_ITEM) == false) +end) + +-- #endregion + +-- #region [Child Describe] "get" +it("returns value of item from provided key", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:get(AN_ITEM) == "foo") +end) + +it("returns nil if the item is not in the Map", function() + local foo = Map.new() + assert(foo:get(AN_ITEM) == nil) +end) +-- #endregion + +-- #region [Child Describe] "clear" +it("sets the size to zero", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:clear() + assert(foo.size == 0) +end) + +it("removes the items from the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:clear() + assert(foo:has(AN_ITEM) == false) +end) +-- #endregion + +-- #region [Child Describe] "delete" +it("removes the items from the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(AN_ITEM) + assert(foo:has(AN_ITEM) == false) +end) + +it("returns true if the item was in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:delete(AN_ITEM) == true) +end) + +it("returns false if the item was not in the Map", function() + local foo = Map.new() + assert(foo:delete(AN_ITEM) == false) +end) + +it("decrements the size if the item was in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(AN_ITEM) + assert(foo.size == 0) +end) + +it("does not decrement the size if the item was not in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:delete(ANOTHER_ITEM) + assert(foo.size == 1) +end) + +it("deletes value set to false", function() + -- Luau FIXME: Luau insists arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + + foo:delete(AN_ITEM) + + assert(foo.size == 0) + assert(foo:get(AN_ITEM) == nil) +end) +-- #endregion + +-- #region [Child Describe] "has" +it("returns true if the item is in the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + assert(foo:has(AN_ITEM) == true) +end) + +it("returns false if the item is not in the Map", function() + local foo = Map.new() + assert(foo:has(AN_ITEM) == false) +end) + +it("returns correctly with value set to false", function() + -- Luau FIXME: Luau insists arrays can't be mixed type + local foo = Map.new({ { AN_ITEM, false :: any } }) + + assert(foo:has(AN_ITEM) == true) +end) +-- #endregion + +-- #region [Child Describe] "keys / values / entries" +it("returns array of elements", function() + local myMap = Map.new() + myMap:set(AN_ITEM, "foo") + myMap:set(ANOTHER_ITEM, "val") + + assert(myMap:keys()[1] == AN_ITEM) + assert(myMap:keys()[2] == ANOTHER_ITEM) + + assert(myMap:values()[1] == "foo") + assert(myMap:values()[2] == "val") + + assert(myMap:entries()[1][1] == AN_ITEM) + assert(myMap:entries()[1][2] == "foo") + assert(myMap:entries()[2][1] == ANOTHER_ITEM) + assert(myMap:entries()[2][2] == "val") +end) +-- #endregion + +-- #region [Child Describe] "__index" +it("can access fields directly without using get", function() + local typeName = "size" + + local foo = Map.new({ + { AN_ITEM, "foo" }, + { ANOTHER_ITEM, "val" }, + { typeName, "buzz" }, + }) + + assert(foo.size == 3) + assert(foo[AN_ITEM] == "foo") + assert(foo[ANOTHER_ITEM] == "val") + assert(foo:get(typeName) == "buzz") +end) +-- #endregion + +-- #region [Child Describe] "__newindex" +it("can set fields directly without using set", function() + local foo = Map.new() + + assert(foo.size == 0) + + foo[AN_ITEM] = "foo" + foo[ANOTHER_ITEM] = "val" + foo.fizz = "buzz" + + assert(foo.size == 3) + assert(foo:get(AN_ITEM) == "foo") + assert(foo:get(ANOTHER_ITEM) == "val") + assert(foo:get("fizz") == "buzz") +end) +-- #endregion + +-- #region [Child Describe] "ipairs" +local function makeArray(...) + local array = {} + for _, item in ... do + table.insert(array, item) + end + return array +end + +it("iterates on the elements by their insertion order", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + assert(makeArray(foo:ipairs())[1][1] == AN_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "foo") + assert(makeArray(foo:ipairs())[2][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[2][2] == "val") +end) + +it("does not iterate on removed elements", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + foo:delete(AN_ITEM) + assert(makeArray(foo:ipairs())[1][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "val") +end) + +it("iterates on elements if the added back to the Map", function() + local foo = Map.new() + foo:set(AN_ITEM, "foo") + foo:set(ANOTHER_ITEM, "val") + foo:delete(AN_ITEM) + foo:set(AN_ITEM, "food") + assert(makeArray(foo:ipairs())[1][1] == ANOTHER_ITEM) + assert(makeArray(foo:ipairs())[1][2] == "val") + assert(makeArray(foo:ipairs())[2][1] == AN_ITEM) + assert(makeArray(foo:ipairs())[2][2] == "food") +end) +-- #endregion + +-- #region [Child Describe] "Integration Tests" +-- it("MDN Examples", function() +-- local myMap = Map.new() :: Map + +-- local keyString = "a string" +-- local keyObj = {} +-- local keyFunc = function() end + +-- -- setting the values +-- myMap:set(keyString, "value associated with 'a string'") +-- myMap:set(keyObj, "value associated with keyObj") +-- myMap:set(keyFunc, "value associated with keyFunc") + +-- assert(myMap.size == 3) + +-- -- getting the values +-- assert(myMap:get(keyString) == "value associated with 'a string'") +-- assert(myMap:get(keyObj) == "value associated with keyObj") +-- assert(myMap:get(keyFunc) == "value associated with keyFunc") + +-- assert(myMap:get("a string") == "value associated with 'a string'") + +-- assert(myMap:get({}) == nil) -- nil, because keyObj !== {} +-- assert(myMap:get(function() -- nil because keyFunc !== function () {} +-- end) == nil) +-- end) + +it("handles non-traditional keys", function() + local myMap = Map.new() :: Map + + local falseKey = false + local trueKey = true + local negativeKey = -1 + local emptyKey = "" + + myMap:set(falseKey, "apple") + myMap:set(trueKey, "bear") + myMap:set(negativeKey, "corgi") + myMap:set(emptyKey, "doge") + + assert(myMap.size == 4) + + assert(myMap:get(falseKey) == "apple") + assert(myMap:get(trueKey) == "bear") + assert(myMap:get(negativeKey) == "corgi") + assert(myMap:get(emptyKey) == "doge") + + myMap:delete(falseKey) + myMap:delete(trueKey) + myMap:delete(negativeKey) + myMap:delete(emptyKey) + + assert(myMap.size == 0) +end) +-- #endregion + +-- #endregion [Describe] "Map" + +-- #region [Describe] "coerceToMap" +it("returns the same object if instance of Map", function() + local map = Map.new() + assert(coerceToMap(map) == map) + + map = Map.new({}) + assert(coerceToMap(map) == map) + + map = Map.new({ { AN_ITEM, "foo" } }) + assert(coerceToMap(map) == map) +end) +-- #endregion [Describe] "coerceToMap" + +-- #endregion Tests to verify it works as expected diff --git a/bench/other/regex.lua b/bench/other/regex.lua new file mode 100644 index 0000000..eb659a5 --- /dev/null +++ b/bench/other/regex.lua @@ -0,0 +1,2089 @@ +--[[ + PCRE2-based RegEx implemention for Luau + Version 1.0.0a2 + BSD 2-Clause Licence + Copyright © 2020 - Blockzez (devforum /u/Blockzez and github.com/Blockzez) + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +]] +--[[ Settings ]]-- +-- You can change them here +local options = { + -- The maximum cache size for regex so the patterns are cached so it doesn't recompile the pattern + -- The only accepted value are number values >= 0, strings that can be automatically coered to numbers that are >= 0, false and nil + -- Do note that empty regex patterns (comment-only patterns included) are never cached regardless + -- The default is 256 + cacheSize = 256, + + -- A boolean that determines whether this use unicode data + -- If this value evalulates to false, you can remove _unicodechar_category, _scripts and _xuc safely and it'll now error if: + -- - You try to compile a RegEx with unicode flag + -- - You try to use the \p pattern + -- The default is true + unicodeData = false, +}; + +-- +local u_categories = options.unicodeData and require(script:WaitForChild("_unicodechar_category")); +local chr_scripts = options.unicodeData and require(script:WaitForChild("_scripts")); +local xuc_chr = options.unicodeData and require(script:WaitForChild("_xuc")); +local proxy = setmetatable({ }, { __mode = 'k' }); +local re, re_m, match_m = { }, { }, { }; +local lockmsg; + +--[[ Functions ]]-- +local function to_str_arr(self, init) + if init then + self = string.sub(self, utf8.offset(self, init)); + end; + local len = utf8.len(self); + if len <= 1999 then + return { n = len, s = self, utf8.codepoint(self, 1, #self) }; + end; + local clen = math.ceil(len / 1999); + local ret = table.create(len); + local p = 1; + for i = 1, clen do + local c = table.pack(utf8.codepoint(self, utf8.offset(self, i * 1999 - 1998), utf8.offset(self, i * 1999 - (i == clen and 1998 - ((len - 1) % 1999 + 1) or - 1)) - 1)); + table.move(c, 1, c.n, p, ret); + p += c.n; + end; + ret.s, ret.n = self, len; + return ret; +end; + +local function from_str_arr(self) + local len = self.n or #self; + if len <= 7997 then + return utf8.char(table.unpack(self)); + end; + local clen = math.ceil(len / 7997); + local r = table.create(clen); + for i = 1, clen do + r[i] = utf8.char(table.unpack(self, i * 7997 - 7996, i * 7997 - (i == clen and 7997 - ((len - 1) % 7997 + 1) or 0))); + end; + return table.concat(r); +end; + +local function utf8_sub(self, i, j) + j = utf8.offset(self, j); + return string.sub(self, utf8.offset(self, i), j and j - 1); +end; + +-- +local flag_map = { + a = 'anchored', i = 'caseless', m = 'multiline', s = 'dotall', u = 'unicode', U = 'ungreedy', x ='extended', +}; + +local posix_class_names = { + alnum = true, alpha = true, ascii = true, blank = true, cntrl = true, digit = true, graph = true, lower = true, print = true, punct = true, space = true, upper = true, word = true, xdigit = true, +}; + +local escape_chars = { + -- grouped + -- digit, spaces and words + [0x44] = { "class", "digit", true }, [0x53] = { "class", "space", true }, [0x57] = { "class", "word", true }, + [0x64] = { "class", "digit", false }, [0x73] = { "class", "space", false }, [0x77] = { "class", "word", false }, + -- horizontal/vertical whitespace and newline + [0x48] = { "class", "blank", true }, [0x56] = { "class", "vertical_tab", true }, + [0x68] = { "class", "blank", false }, [0x76] = { "class", "vertical_tab", false }, + [0x4E] = { 0x4E }, [0x52] = { 0x52 }, + + -- not grouped + [0x42] = 0x08, + [0x6E] = 0x0A, [0x72] = 0x0D, [0x74] = 0x09, +}; + +local b_escape_chars = { + -- word boundary and not word boundary + [0x62] = { 0x62, { "class", "word", false } }, [0x42] = { 0x42, { "class", "word", false } }, + + -- keep match out + [0x4B] = { 0x4B }, + + -- start & end of string + [0x47] = { 0x47 }, [0x4A] = { 0x4A }, [0x5A] = { 0x5A }, [0x7A] = { 0x7A }, +}; + +local valid_categories = { + C = true, Cc = true, Cf = true, Cn = true, Co = true, Cs = true, + L = true, Ll = true, Lm = true, Lo = true, Lt = true, Lu = true, + M = true, Mc = true, Me = true, Mn = true, + N = true, Nd = true, Nl = true, No = true, + P = true, Pc = true, Pd = true, Pe = true, Pf = true, Pi = true, Po = true, Ps = true, + S = true, Sc = true, Sk = true, Sm = true, So = true, + Z = true, Zl = true, Zp = true, Zs = true, + + Xan = true, Xps = true, Xsp = true, Xuc = true, Xwd = true, +}; + +local class_ascii_punct = { + [0x21] = true, [0x22] = true, [0x23] = true, [0x24] = true, [0x25] = true, [0x26] = true, [0x27] = true, [0x28] = true, [0x29] = true, [0x2A] = true, [0x2B] = true, [0x2C] = true, [0x2D] = true, [0x2E] = true, [0x2F] = true, + [0x3A] = true, [0x3B] = true, [0x3C] = true, [0x3D] = true, [0x3E] = true, [0x3F] = true, [0x40] = true, [0x5B] = true, [0x5C] = true, [0x5D] = true, [0x5E] = true, [0x5F] = true, [0x60] = true, [0x7B] = true, [0x7C] = true, + [0x7D] = true, [0x7E] = true, +}; + +local end_str = { 0x24 }; +local dot = { 0x2E }; +local beginning_str = { 0x5E }; +local alternation = { 0x7C }; + +local function check_re(re_type, name, func) + if re_type == "Match" then + return function(...) + local arg_n = select('#', ...); + if arg_n < 1 then + error("missing argument #1 (Match expected)", 2); + end; + local arg0, arg1 = ...; + if not (proxy[arg0] and proxy[arg0].name == "Match") then + error(string.format("invalid argument #1 to %q (Match expected, got %s)", name, typeof(arg0)), 2); + else + arg0 = proxy[arg0]; + end; + if name == "group" or name == "span" then + if arg1 == nil then + arg1 = 0; + end; + end; + return func(arg0, arg1); + end; + end; + return function(...) + local arg_n = select('#', ...); + if arg_n < 1 then + error("missing argument #1 (RegEx expected)", 2); + elseif arg_n < 2 then + error("missing argument #2 (string expected)", 2); + end; + local arg0, arg1, arg2, arg3, arg4, arg5 = ...; + if not (proxy[arg0] and proxy[arg0].name == "RegEx") then + if type(arg0) ~= "string" and type(arg0) ~= "number" then + error(string.format("invalid argument #1 to %q (RegEx expected, got %s)", name, typeof(arg0)), 2); + end; + arg0 = re.fromstring(arg0); + elseif name == "sub" then + if type(arg2) == "number" then + arg2 ..= ''; + elseif type(arg2) ~= "string" then + error(string.format("invalid argument #3 to 'sub' (string expected, got %s)", typeof(arg2)), 2); + end; + elseif type(arg1) == "number" then + arg1 ..= ''; + elseif type(arg1) ~= "string" then + error(string.format("invalid argument #2 to %q (string expected, got %s)", name, typeof(arg1)), 2); + end; + if name ~= "sub" and name ~= "split" then + local init_type = typeof(arg2); + if init_type ~= 'nil' then + arg2 = tonumber(arg2); + if not arg2 then + error(string.format("invalid argument #3 to %q (number expected, got %s)", name, init_type), 2); + elseif arg2 < 0 then + arg2 = #arg1 + math.floor(arg2 + 0.5) + 1; + else + arg2 = math.max(math.floor(arg2 + 0.5), 1); + end; + end; + end; + arg0 = proxy[arg0]; + if name == "match" or name == "matchiter" then + arg3 = ...; + elseif name == "sub" then + arg5 = ...; + end; + return func(arg0, arg1, arg2, arg3, arg4, arg5); + end; +end; + +--[[ Matches ]]-- +local function match_tostr(self) + local spans = proxy[self].spans; + local s_start, s_end = spans[0][1], spans[0][2]; + if s_end <= s_start then + return string.format("Match (%d..%d, empty)", s_start, s_end - 1); + end; + return string.format("Match (%d..%d): %s", s_start, s_end - 1, utf8_sub(spans.input, s_start, s_end)); +end; + +local function new_match(span_arr, group_id, re, str) + span_arr.source, span_arr.input = re, str; + local object = newproxy(true); + local object_mt = getmetatable(object); + object_mt.__metatable = lockmsg; + object_mt.__index = setmetatable(span_arr, match_m); + object_mt.__tostring = match_tostr; + + proxy[object] = { name = "Match", spans = span_arr, group_id = group_id }; + return object; +end; + +match_m.group = check_re('Match', 'group', function(self, group_id) + local span = self.spans[type(group_id) == "number" and group_id or self.group_id[group_id]]; + if not span then + return nil; + end; + return utf8_sub(self.spans.input, span[1], span[2]); +end); + +match_m.span = check_re('Match', 'span', function(self, group_id) + local span = self.spans[type(group_id) == "number" and group_id or self.group_id[group_id]]; + if not span then + return nil; + end; + return span[1], span[2] - 1; +end); + +match_m.groups = check_re('Match', 'groups', function(self) + local spans = self.spans; + if spans.n > 0 then + local ret = table.create(spans.n); + for i = 0, spans.n do + local v = spans[i]; + if v then + ret[i] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + return table.unpack(ret, 1, spans.n); + end; + return utf8_sub(spans.input, spans[0][1], spans[0][2]); +end); + +match_m.groupdict = check_re('Match', 'groupdict', function(self) + local spans = self.spans; + local ret = { }; + for k, v in pairs(self.group_id) do + v = spans[v]; + if v then + ret[k] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + return ret; +end); + +match_m.grouparr = check_re('Match', 'groupdict', function(self) + local spans = self.spans; + local ret = table.create(spans.n); + for i = 0, spans.n do + local v = spans[i]; + if v then + ret[i] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + ret.n = spans.n; + return ret; +end); + +-- +local line_verbs = { + CR = 0, LF = 1, CRLF = 2, ANYRLF = 3, ANY = 4, NUL = 5, +}; +local function is_newline(str_arr, i, verb_flags) + local line_verb_n = verb_flags.newline; + local chr = str_arr[i]; + if line_verb_n == 0 then + -- carriage return + return chr == 0x0D; + elseif line_verb_n == 2 then + -- carriage return followed by line feed + return chr == 0x0A and str_arr[i - 1] == 0x20; + elseif line_verb_n == 3 then + -- any of the above + return chr == 0x0A or chr == 0x0D; + elseif line_verb_n == 4 then + -- any of Unicode newlines + return chr == 0x0A or chr == 0x0B or chr == 0x0C or chr == 0x0D or chr == 0x85 or chr == 0x2028 or chr == 0x2029; + elseif line_verb_n == 5 then + -- null + return chr == 0; + end; + -- linefeed + return chr == 0x0A; +end; + + +local function tkn_char_match(tkn_part, str_arr, i, flags, verb_flags) + local chr = str_arr[i]; + if not chr then + return false; + elseif flags.ignoreCase and chr >= 0x61 and chr <= 0x7A then + chr -= 0x20; + end; + if type(tkn_part) == "number" then + return tkn_part == chr; + elseif tkn_part[1] == "charset" then + for _, v in ipairs(tkn_part[3]) do + if tkn_char_match(v, str_arr, i, flags, verb_flags) then + return not tkn_part[2]; + end; + end; + return tkn_part[2]; + elseif tkn_part[1] == "range" then + return chr >= tkn_part[2] and chr <= tkn_part[3] or flags.ignoreCase and chr >= 0x41 and chr <= 0x5A and (chr + 0x20) >= tkn_part[2] and (chr + 0x20) <= tkn_part[3]; + elseif tkn_part[1] == "class" then + local char_class = tkn_part[2]; + local negate = tkn_part[3]; + local match = false; + -- if and elseifs :( + -- Might make these into tables in the future + if char_class == "xdigit" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x46 or chr >= 0x61 and chr <= 0x66; + elseif char_class == "ascii" then + match = chr <= 0x7F; + -- cannot be accessed through POSIX classes + elseif char_class == "vertical_tab" then + match = chr >= 0x0A and chr <= 0x0D or chr == 0x2028 or chr == 0x2029; + -- + elseif flags.unicode then + local current_category = u_categories[chr] or 'Cn'; + local first_category = current_category:sub(1, 1); + if char_class == "alnum" then + match = first_category == 'L' or current_category == 'Nl' or current_category == 'Nd'; + elseif char_class == "alpha" then + match = first_category == 'L' or current_category == 'Nl'; + elseif char_class == "blank" then + match = current_category == 'Zs' or chr == 0x09; + elseif char_class == "cntrl" then + match = current_category == 'Cc'; + elseif char_class == "digit" then + match = current_category == 'Nd'; + elseif char_class == "graph" then + match = first_category ~= 'P' and first_category ~= 'C'; + elseif char_class == "lower" then + match = current_category == 'Ll'; + elseif char_class == "print" then + match = first_category ~= 'C'; + elseif char_class == "punct" then + match = first_category == 'P'; + elseif char_class == "space" then + match = first_category == 'Z' or chr >= 0x09 and chr <= 0x0D; + elseif char_class == "upper" then + match = current_category == 'Lu'; + elseif char_class == "word" then + match = first_category == 'L' or current_category == 'Nl' or current_category == 'Nd' or current_category == 'Pc'; + end; + elseif char_class == "alnum" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A; + elseif char_class == "alpha" then + match = chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A; + elseif char_class == "blank" then + match = chr == 0x09 or chr == 0x20; + elseif char_class == "cntrl" then + match = chr <= 0x1F or chr == 0x7F; + elseif char_class == "digit" then + match = chr >= 0x30 and chr <= 0x39; + elseif char_class == "graph" then + match = chr >= 0x21 and chr <= 0x7E; + elseif char_class == "lower" then + match = chr >= 0x61 and chr <= 0x7A; + elseif char_class == "print" then + match = chr >= 0x20 and chr <= 0x7E; + elseif char_class == "punct" then + match = class_ascii_punct[chr]; + elseif char_class == "space" then + match = chr >= 0x09 and chr <= 0x0D or chr == 0x20; + elseif char_class == "upper" then + match = chr >= 0x41 and chr <= 0x5A; + elseif char_class == "word" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A or chr == 0x5F; + end; + if negate then + return not match; + end; + return match; + elseif tkn_part[1] == "category" then + local chr_category = u_categories[chr] or 'Cn'; + local category_v = tkn_part[3]; + local category_len = #category_v; + if category_len == 3 then + local match = false; + if category_v == "Xan" or category_v == "Xwd" then + match = chr_category:find("^[LN]") or category_v == "Xwd" and chr == 0x5F; + elseif category_v == "Xps" or category_v == "Xsp" then + match = chr_category:sub(1, 1) == 'Z' or chr >= 0x09 and chr <= 0x0D; + elseif category_v == "Xuc" then + match = tkn_char_match(xuc_chr, str_arr, i, flags, verb_flags); + end; + if tkn_part[2] then + return not match; + end + return match; + elseif chr_category:sub(1, category_len) == category_v then + return not tkn_part[2]; + end; + return tkn_part[2]; + elseif tkn_part[1] == 0x2E then + return flags.dotAll or not is_newline(str_arr, i, verb_flags); + elseif tkn_part[1] == 0x4E then + return not is_newline(str_arr, i, verb_flags); + elseif tkn_part[1] == 0x52 then + if verb_flags.newline_seq == 0 then + -- CR, LF or CRLF + return chr == 0x0A or chr == 0x0D; + end; + -- any unicode newline + return chr == 0x0A or chr == 0x0B or chr == 0x0C or chr == 0x0D or chr == 0x85 or chr == 0x2028 or chr == 0x2029; + end; + return false; +end; + +local function find_alternation(token, i, count) + while true do + local v = token[i]; + local is_table = type(v) == "table"; + if v == alternation then + return i, count; + elseif is_table and v[1] == 0x28 then + if count then + count += v.count; + end; + i = v[3]; + elseif is_table and v[1] == "quantifier" and type(v[5]) == "table" and v[5][1] == 0x28 then + if count then + count += v[5].count; + end; + i = v[5][3]; + elseif not v or is_table and v[1] == 0x29 then + return nil, count; + elseif count then + if is_table and v[1] == "quantifier" then + count += v[3]; + else + count += 1; + end; + end; + i += 1; + end; +end; + +local function re_rawfind(token, str_arr, init, flags, verb_flags, as_bool) + local tkn_i, str_i, start_i = 0, init, init; + local states = { }; + while tkn_i do + if tkn_i == 0 then + tkn_i += 1; + local next_alt = find_alternation(token, tkn_i); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + continue; + end; + local ctkn = token[tkn_i]; + local tkn_type = type(ctkn) == "table" and ctkn[1]; + if not ctkn then + break; + elseif ctkn == "ACCEPT" then + local not_lookaround = true; + local close_i = tkn_i; + repeat + close_i += 1; + local is_table = type(token[close_i]) == "table"; + local close_i_tkn = token[close_i]; + if is_table and (close_i_tkn[1] == 0x28 or close_i_tkn[1] == "quantifier" and type(close_i_tkn[5]) == "table" and close_i_tkn[5][1] == 0x28) then + close_i = close_i_tkn[1] == "quantifier" and close_i_tkn[5][3] or close_i_tkn[3]; + elseif is_table and close_i_tkn[1] == 0x29 and (close_i_tkn[4] == 0x21 or close_i_tkn[4] == 0x3D) then + not_lookaround = false; + tkn_i = close_i; + break; + end; + until not close_i_tkn; + if not_lookaround then + break; + end; + elseif ctkn == "PRUNE" or ctkn == "SKIP" then + table.insert(states, 1, { ctkn, str_i }); + tkn_i += 1; + elseif tkn_type == 0x28 then + table.insert(states, 1, { "group", tkn_i, str_i, nil, ctkn[2], ctkn[3], ctkn[4] }); + tkn_i += 1; + local next_alt, count = find_alternation(token, tkn_i, (ctkn[4] == 0x21 or ctkn[4] == 0x3D) and ctkn[5] and 0); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + if count then + str_i -= count; + end; + elseif tkn_type == 0x29 and ctkn[4] ~= 0x21 then + if ctkn[4] == 0x21 or ctkn[4] == 0x3D then + while true do + local selected_match_start; + local selected_state = table.remove(states, 1); + if selected_state[1] == "group" and selected_state[2] == ctkn[3] then + if (ctkn[4] == 0x21 or ctkn[4] == 0x3D) and not ctkn[5] then + str_i = selected_state[3]; + end; + if selected_match_start then + table.insert(states, 1, selected_match_start); + end; + break; + elseif selected_state[1] == "matchStart" and not selected_match_start and ctkn[4] == 0x3D then + selected_match_start = selected_state; + end; + end; + elseif ctkn[4] == 0x3E then + repeat + local selected_state = table.remove(states, 1); + until not selected_state or selected_state[1] == "group" and selected_state[2] == ctkn[3]; + else + for i, v in ipairs(states) do + if v[1] == "group" and v[2] == ctkn[3] then + if v.jmp then + -- recursive match + tkn_i = v.jmp; + end; + v[4] = str_i; + if v[7] == "quantifier" and v[10] + 1 < v[9] then + if token[ctkn[3]][4] ~= "lazy" or v[10] + 1 < v[8] then + tkn_i = ctkn[3]; + end; + local ctkn1 = token[ctkn[3]]; + local new_group = { "group", v[2], str_i, nil, ctkn1[5][2], ctkn1[5][3], "quantifier", ctkn1[2], ctkn1[3], v[10] + 1, v[11], ctkn1[4] }; + table.insert(states, 1, new_group); + if v[11] then + table.insert(states, 1, { "alternation", v[11], str_i }); + end; + end; + break; + end; + end; + end; + tkn_i += 1; + elseif tkn_type == 0x4B then + table.insert(states, 1, { "matchStart", str_i }); + tkn_i += 1; + elseif tkn_type == 0x7C then + local close_i = tkn_i; + repeat + close_i += 1; + local is_table = type(token[close_i]) == "table"; + local close_i_tkn = token[close_i]; + if is_table and (close_i_tkn[1] == 0x28 or close_i_tkn[1] == "quantifier" and type(close_i_tkn[5]) == "table" and close_i_tkn[5][1] == 0x28) then + close_i = close_i_tkn[1] == "quantifier" and close_i_tkn[5][3] or close_i_tkn[3]; + end; + until is_table and close_i_tkn[1] == 0x29 or not close_i_tkn; + if token[close_i] then + for _, v in ipairs(states) do + if v[1] == "group" and v[6] == close_i then + tkn_i = v[6]; + break; + end; + end; + else + tkn_i = close_i; + end; + elseif tkn_type == "recurmatch" then + table.insert(states, 1, { "group", ctkn[3], str_i, nil, nil, token[ctkn[3]][3], nil, jmp = tkn_i }); + tkn_i = ctkn[3] + 1; + local next_alt, count = find_alternation(token, tkn_i); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + else + local match; + if ctkn == "FAIL" then + match = false; + elseif tkn_type == 0x29 then + repeat + local selected_state = table.remove(states, 1); + until selected_state[1] == "group" and selected_state[2] == ctkn[3]; + elseif tkn_type == "quantifier" then + if type(ctkn[5]) == "table" and ctkn[5][1] == 0x28 then + local next_alt = find_alternation(token, tkn_i + 1); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + table.insert(states, next_alt and 2 or 1, { "group", tkn_i, str_i, nil, ctkn[5][2], ctkn[5][3], "quantifier", ctkn[2], ctkn[3], 0, next_alt, ctkn[4] }); + if ctkn[4] == "lazy" and ctkn[2] == 0 then + tkn_i = ctkn[5][3]; + end; + match = true; + else + local start_i, end_i; + local pattern_count = 1; + local is_backref = type(ctkn[5]) == "table" and ctkn[5][1] == "backref"; + if is_backref then + pattern_count = 0; + local group_n = ctkn[5][2]; + for _, v in ipairs(states) do + if v[1] == "group" and v[5] == group_n then + start_i, end_i = v[3], v[4]; + pattern_count = end_i - start_i; + break; + end; + end; + end; + local min_max_i = str_i + ctkn[2] * pattern_count; + local mcount = 0; + while mcount < ctkn[3] do + if is_backref then + if start_i and end_i then + local org_i = str_i; + if utf8_sub(str_arr.s, start_i, end_i) ~= utf8_sub(str_arr.s, org_i, str_i + pattern_count) then + break; + end; + else + break; + end; + elseif not tkn_char_match(ctkn[5], str_arr, str_i, flags, verb_flags) then + break; + end; + str_i += pattern_count; + mcount += 1; + end; + match = mcount >= ctkn[2]; + if match and ctkn[4] ~= "possessive" then + if ctkn[4] == "lazy" then + min_max_i, str_i = str_i, min_max_i; + end; + table.insert(states, 1, { "quantifier", tkn_i, str_i, math.min(min_max_i, str_arr.n + 1), (ctkn[4] == "lazy" and 1 or -1) * pattern_count }); + end; + end; + elseif tkn_type == "backref" then + local start_i, end_i; + local group_n = ctkn[2]; + for _, v in ipairs(states) do + if v[1] == "group" and v[5] == group_n then + start_i, end_i = v[3], v[4]; + break; + end; + end; + if start_i and end_i then + local org_i = str_i; + str_i += end_i - start_i; + match = utf8_sub(str_arr.s, start_i, end_i) == utf8_sub(str_arr.s, org_i, str_i); + end; + else + local chr = str_arr[str_i]; + if tkn_type == 0x24 or tkn_type == 0x5A or tkn_type == 0x7A then + match = str_i == str_arr.n + 1 or tkn_type == 0x24 and flags.multiline and is_newline(str_arr, str_i + 1, verb_flags) or tkn_type == 0x5A and str_i == str_arr.n and is_newline(str_arr, str_i, verb_flags); + elseif tkn_type == 0x5E or tkn_type == 0x41 or tkn_type == 0x47 then + match = str_i == 1 or tkn_type == 0x5E and flags.multiline and is_newline(str_arr, str_i - 1, verb_flags) or tkn_type == 0x47 and str_i == init; + elseif tkn_type == 0x42 or tkn_type == 0x62 then + local start_m = str_i == 1 or flags.multiline and is_newline(str_arr, str_i - 1, verb_flags); + local end_m = str_i == str_arr.n + 1 or flags.multiline and is_newline(str_arr, str_i, verb_flags); + local w_m = tkn_char_match(ctkn[2], str_arr[str_i - 1], flags) and 0 or tkn_char_match(ctkn[2], chr, flags) and 1; + if w_m == 0 then + match = end_m or not tkn_char_match(ctkn[2], chr, flags); + elseif w_m then + match = start_m or not tkn_char_match(ctkn[2], str_arr[str_i - 1], flags); + end; + if tkn_type == 0x42 then + match = not match; + end; + else + match = tkn_char_match(ctkn, str_arr, str_i, flags, verb_flags); + str_i += 1; + end; + end; + if not match then + while true do + local prev_type, prev_state = states[1] and states[1][1], states[1]; + if not prev_type or prev_type == "PRUNE" or prev_type == "SKIP" then + if prev_type then + table.clear(states); + end; + if start_i > str_arr.n then + if as_bool then + return false; + end; + return nil; + end; + start_i = prev_type == "SKIP" and prev_state[2] or start_i + 1; + tkn_i, str_i = 0, start_i; + break; + elseif prev_type == "alternation" then + tkn_i, str_i = prev_state[2], prev_state[3]; + local next_alt, count = find_alternation(token, tkn_i + 1); + if next_alt then + prev_state[2] = next_alt; + else + table.remove(states, 1); + end; + if count then + str_i -= count; + end; + break; + elseif prev_type == "group" then + if prev_state[7] == "quantifier" then + if prev_state[12] == "greedy" and prev_state[10] >= prev_state[8] + or prev_state[12] == "lazy" and prev_state[10] < prev_state[9] and not prev_state[13] then + tkn_i, str_i = prev_state[12] == "greedy" and prev_state[6] or prev_state[2], prev_state[3]; + if prev_state[12] == "greedy" then + table.remove(states, 1); + break; + elseif prev_state[10] >= prev_state[8] then + prev_state[13] = true; + break; + end; + end; + elseif prev_state[7] == 0x21 then + table.remove(states, 1); + tkn_i, str_i = prev_state[6], prev_state[3]; + break; + end; + elseif prev_type == "quantifier" then + if math.sign(prev_state[4] - prev_state[3]) == math.sign(prev_state[5]) then + prev_state[3] += prev_state[5]; + tkn_i, str_i = prev_state[2], prev_state[3]; + break; + end; + end; + -- keep match out state and recursive state, can be safely removed + -- prevents infinite loop + table.remove(states, 1); + end; + end; + tkn_i += 1; + end; + end; + if as_bool then + return true; + end; + local match_start_ran = false; + local span = table.create(token.group_n); + span[0], span.n = { start_i, str_i }, token.group_n; + for _, v in ipairs(states) do + if v[1] == "matchStart" and not match_start_ran then + span[0][1], match_start_ran = v[2], true; + elseif v[1] == "group" and v[5] and not span[v[5]] then + span[v[5]] = { v[3], v[4] }; + end; + end; + return span; +end; + +--[[ Methods ]]-- +re_m.test = check_re('RegEx', 'test', function(self, str, init) + return re_rawfind(self.token, to_str_arr(str, init), 1, self.flags, self.verb_flags, true); +end); + +re_m.match = check_re('RegEx', 'match', function(self, str, init, source) + local span = re_rawfind(self.token, to_str_arr(str, init), 1, self.flags, self.verb_flags, false); + if not span then + return nil; + end; + return new_match(span, self.group_id, source, str); +end); + +re_m.matchall = check_re('RegEx', 'matchall', function(self, str, init, source) + str = to_str_arr(str, init); + local i = 1; + return function() + local span = i <= str.n + 1 and re_rawfind(self.token, str, i, self.flags, self.verb_flags, false); + if not span then + return nil; + end; + i = span[0][2] + (span[0][1] >= span[0][2] and 1 or 0); + return new_match(span, self.group_id, source, str.s); + end; +end); + +local function insert_tokenized_sub(repl_r, str, span, tkn) + for _, v in ipairs(tkn) do + if type(v) == "table" then + if v[1] == "condition" then + if span[v[2]] then + if v[3] then + insert_tokenized_sub(repl_r, str, span, v[3]); + else + table.move(str, span[v[2]][1], span[v[2]][2] - 1, #repl_r + 1, repl_r); + end; + elseif v[4] then + insert_tokenized_sub(repl_r, str, span, v[4]); + end; + else + table.move(v, 1, #v, #repl_r + 1, repl_r); + end; + elseif span[v] then + table.move(str, span[v][1], span[v][2] - 1, #repl_r + 1, repl_r); + end; + end; + repl_r.n = #repl_r; + return repl_r; +end; + +re_m.sub = check_re('RegEx', 'sub', function(self, repl, str, n, repl_flag_str, source) + if repl_flag_str ~= nil and type(repl_flag_str) ~= "number" and type(repl_flag_str) ~= "string" then + error(string.format("invalid argument #5 to 'sub' (string expected, got %s)", typeof(repl_flag_str)), 3); + end + local repl_flags = { + l = false, o = false, u = false, + }; + for f in string.gmatch(repl_flag_str or '', utf8.charpattern) do + if repl_flags[f] ~= false then + error("invalid regular expression substitution flag " .. f, 3); + end; + repl_flags[f] = true; + end; + local repl_type = type(repl); + if repl_type == "number" then + repl ..= ''; + elseif repl_type ~= "string" and repl_type ~= "function" and (not repl_flags.o or repl_type ~= "table") then + error(string.format("invalid argument #2 to 'sub' (string/function%s expected, got %s)", repl_flags.o and "/table" or '', typeof(repl)), 3); + end; + if tonumber(n) then + n = tonumber(n); + if n <= -1 or n ~= n then + n = math.huge; + end; + elseif n ~= nil then + error(string.format("invalid argument #4 to 'sub' (number expected, got %s)", typeof(n)), 3); + else + n = math.huge; + end; + if n < 1 then + return str, 0; + end; + local min_repl_n = 0; + if repl_type == "string" then + repl = to_str_arr(repl); + if not repl_flags.l then + local i1 = 0; + local repl_r = table.create(3); + local group_n = self.token.group_n; + local conditional_c = { }; + while i1 < repl.n do + local i2 = i1; + repeat + i2 += 1; + until not repl[i2] or repl[i2] == 0x24 or repl[i2] == 0x5C or (repl[i2] == 0x3A or repl[i2] == 0x7D) and conditional_c[1]; + min_repl_n += i2 - i1 - 1; + if i2 - i1 > 1 then + table.insert(repl_r, table.move(repl, i1 + 1, i2 - 1, 1, table.create(i2 - i1 - 1))); + end; + if repl[i2] == 0x3A then + local current_conditional_c = conditional_c[1]; + if current_conditional_c[2] then + error("malformed substitution pattern", 3); + end; + current_conditional_c[2] = table.move(repl_r, current_conditional_c[3], #repl_r, 1, table.create(#repl_r + 1 - current_conditional_c[3])); + for i3 = #repl_r, current_conditional_c[3], -1 do + repl_r[i3] = nil; + end; + elseif repl[i2] == 0x7D then + local current_conditional_c = table.remove(conditional_c, 1); + local second_c = table.move(repl_r, current_conditional_c[3], #repl_r, 1, table.create(#repl_r + 1 - current_conditional_c[3])); + for i3 = #repl_r, current_conditional_c[3], -1 do + repl_r[i3] = nil; + end; + table.insert(repl_r, { "condition", current_conditional_c[1], current_conditional_c[2] ~= true and (current_conditional_c[2] or second_c), current_conditional_c[2] and second_c }); + elseif repl[i2] then + i2 += 1; + local subst_c = repl[i2]; + if not subst_c then + if repl[i2 - 1] == 0x5C then + error("replacement string must not end with a trailing backslash", 3); + end; + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, repl[i2 - 1]); + else + table.insert(repl_r, { repl[i2 - 1] }); + end; + elseif subst_c == 0x5C and repl[i2 - 1] == 0x24 then + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, 0x24); + else + table.insert(repl_r, { 0x24 }); + end; + i2 -= 1; + min_repl_n += 1; + elseif subst_c == 0x30 then + table.insert(repl_r, 0); + elseif subst_c > 0x30 and subst_c <= 0x39 then + local start_i2 = i2; + local group_i = subst_c - 0x30; + while repl[i2 + 1] and repl[i2 + 1] >= 0x30 and repl[i2 + 1] <= 0x39 do + group_i ..= repl[i2 + 1] - 0x30; + i2 += 1; + end; + group_i = tonumber(group_i); + if not repl_flags.u and group_i > group_n then + error("reference to non-existent subpattern", 3); + end; + table.insert(repl_r, group_i); + elseif subst_c == 0x7B and repl[i2 - 1] == 0x24 then + i2 += 1; + local start_i2 = i2; + while repl[i2] and + (repl[i2] >= 0x30 and repl[i2] <= 0x39 + or repl[i2] >= 0x41 and repl[i2] <= 0x5A + or repl[i2] >= 0x61 and repl[i2] <= 0x7A + or repl[i2] == 0x5F) do + i2 += 1; + end; + if (repl[i2] == 0x7D or repl[i2] == 0x3A and (repl[i2 + 1] == 0x2B or repl[i2 + 1] == 0x2D)) and i2 ~= start_i2 then + local group_k = utf8_sub(repl.s, start_i2, i2); + if repl[start_i2] >= 0x30 and repl[start_i2] <= 0x39 then + group_k = tonumber(group_k); + if not repl_flags.u and group_k > group_n then + error("reference to non-existent subpattern", 3); + end; + else + group_k = self.group_id[group_k]; + if not repl_flags.u and (not group_k or group_k > group_n) then + error("reference to non-existent subpattern", 3); + end; + end; + if repl[i2] == 0x3A then + i2 += 1; + table.insert(conditional_c, { group_k, repl[i2] == 0x2D, #repl_r + 1 }); + else + table.insert(repl_r, group_k); + end; + else + error("malformed substitution pattern", 3); + end; + else + local c_escape_char; + if repl[i2 - 1] == 0x24 then + if subst_c ~= 0x24 then + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, 0x24); + else + table.insert(repl_r, { 0x24 }); + end; + end; + else + c_escape_char = escape_chars[repl[i2]]; + if type(c_escape_char) ~= "number" then + c_escape_char = nil; + end; + end; + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, c_escape_char or repl[i2]); + else + table.insert(repl_r, { c_escape_char or repl[i2] }); + end; + min_repl_n += 1; + end; + end; + i1 = i2; + end; + if conditional_c[1] then + error("malformed substitution pattern", 3); + end; + if not repl_r[2] and type(repl_r[1]) == "table" and repl_r[1][1] ~= "condition" then + repl, repl.n = repl_r[1], #repl_r[1]; + else + repl, repl_type = repl_r, "subst_string"; + end; + end; + end; + str = to_str_arr(str); + local incr, i0, count = 0, 1, 0; + while i0 <= str.n + incr + 1 do + local span = re_rawfind(self.token, str, i0, self.flags, self.verb_flags, false); + if not span then + break; + end; + local repl_r; + if repl_type == "string" then + repl_r = repl; + elseif repl_type == "subst_string" then + repl_r = insert_tokenized_sub(table.create(min_repl_n), str, span, repl); + else + local re_match; + local repl_c; + if repl_type == "table" then + re_match = utf8_sub(str.s, span[0][1], span[0][2]); + repl_c = repl[re_match]; + else + re_match = new_match(span, self.group_id, source, str.s); + repl_c = repl(re_match); + end; + if repl_c == re_match or repl_flags.o and not repl_c then + local repl_n = span[0][2] - span[0][1]; + repl_r = table.move(str, span[0][1], span[0][2] - 1, 1, table.create(repl_n)); + repl_r.n = repl_n; + elseif type(repl_c) == "string" then + repl_r = to_str_arr(repl_c); + elseif type(repl_c) == "number" then + repl_r = to_str_arr(repl_c .. ''); + elseif repl_flags.o then + error(string.format("invalid replacement value (a %s)", type(repl_c)), 3); + else + repl_r = { n = 0 }; + end; + end; + local match_len = span[0][2] - span[0][1]; + local repl_len = math.min(repl_r.n, match_len); + for i1 = 0, repl_len - 1 do + str[span[0][1] + i1] = repl_r[i1 + 1]; + end; + local i1 = span[0][1] + repl_len; + i0 = span[0][2]; + if match_len > repl_r.n then + for i2 = 1, match_len - repl_r.n do + table.remove(str, i1); + incr -= 1; + i0 -= 1; + end; + elseif repl_r.n > match_len then + for i2 = 1, repl_r.n - match_len do + table.insert(str, i1 + i2 - 1, repl_r[repl_len + i2]); + incr += 1; + i0 += 1; + end; + end; + if match_len <= 0 then + i0 += 1; + end; + count += 1; + if n < count + 1 then + break; + end; + end; + return from_str_arr(str), count; +end); + +re_m.split = check_re('RegEx', 'split', function(self, str, n) + if tonumber(n) then + n = tonumber(n); + if n <= -1 or n ~= n then + n = math.huge; + end; + elseif n ~= nil then + error(string.format("invalid argument #3 to 'split' (number expected, got %s)", typeof(n)), 3); + else + n = math.huge; + end; + str = to_str_arr(str); + local i, count = 1, 0; + local ret = { }; + local prev_empty = 0; + while i <= str.n + 1 do + count += 1; + local span = n >= count and re_rawfind(self.token, str, i, self.flags, self.verb_flags, false); + if not span then + break; + end; + table.insert(ret, utf8_sub(str.s, i - prev_empty, span[0][1])); + prev_empty = span[0][1] >= span[0][2] and 1 or 0; + i = span[0][2] + prev_empty; + end; + table.insert(ret, string.sub(str.s, utf8.offset(str.s, i - prev_empty))); + return ret; +end); + +-- +local function re_index(self, index) + return re_m[index] or proxy[self].flags[index]; +end; + +local function re_tostr(self) + return proxy[self].pattern_repr .. proxy[self].flag_repr; +end; +-- + +local other_valid_group_char = { + -- non-capturing group + [0x3A] = true, + -- lookarounds + [0x21] = true, [0x3D] = true, + -- atomic + [0x3E] = true, + -- branch reset + [0x7C] = true, +}; + +local function tokenize_ptn(codes, flags) + if flags.unicode and not options.unicodeData then + return "options.unicodeData cannot be turned off while having unicode flag"; + end; + local i, len = 1, codes.n; + local group_n = 0; + local outln, group_id, verb_flags = { }, { }, { + newline = 1, newline_seq = 1, not_empty = 0, + }; + while i <= len do + local c = codes[i]; + if c == 0x28 then + -- Match + local ret; + if codes[i + 1] == 0x2A then + i += 2; + local start_i = i; + while codes[i] + and (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F or codes[i] == 0x3A) do + i += 1; + end; + if codes[i] ~= 0x29 and codes[i - 1] ~= 0x3A then + -- fallback as normal and ( can't be repeated + return "quantifier doesn't follow a repeatable pattern"; + end; + local selected_verb = utf8_sub(codes.s, start_i, i); + if selected_verb == "positive_lookahead:" or selected_verb == "negative_lookhead:" + or selected_verb == "positive_lookbehind:" or selected_verb == "negative_lookbehind:" + or selected_verb:find("^[pn]l[ab]:$") then + ret = { 0x28, nil, nil, selected_verb:find('^n') and 0x21 or 0x3D, selected_verb:find('b', 3, true) and 1 }; + elseif selected_verb == "atomic:" then + ret = { 0x28, nil, nil, 0x3E, nil }; + elseif selected_verb == "ACCEPT" or selected_verb == "FAIL" or selected_verb == 'F' or selected_verb == "PRUNE" or selected_verb == "SKIP" then + ret = selected_verb == 'F' and "FAIL" or selected_verb; + else + if line_verbs[selected_verb] then + verb_flags.newline = selected_verb; + elseif selected_verb == "BSR_ANYCRLF" or selected_verb == "BSR_UNICODE" then + verb_flags.newline_seq = selected_verb == "BSR_UNICODE" and 1 or 0; + elseif selected_verb == "NOTEMPTY" or selected_verb == "NOTEMPTY_ATSTART" then + verb_flags.not_empty = selected_verb == "NOTEMPTY" and 1 or 2; + else + return "unknown or malformed verb"; + end; + if outln[1] then + return "this verb must be placed at the beginning of the regex"; + end; + end; + elseif codes[i + 1] == 0x3F then + -- ? syntax + i += 2; + if codes[i] == 0x23 then + -- comments + i = table.find(codes, 0x29, i); + if not i then + return "unterminated parenthetical"; + end; + i += 1; + continue; + elseif not codes[i] then + return "unterminated parenthetical"; + end; + ret = { 0x28, nil, nil, codes[i], nil }; + if codes[i] == 0x30 and codes[i + 1] == 0x29 then + -- recursive match entire pattern + ret[1], ret[2], ret[3], ret[5] = "recurmatch", 0, 0, nil; + elseif codes[i] > 0x30 and codes[i] <= 0x39 then + -- recursive match + local org_i = i; + i += 1; + while codes[i] >= 0x30 and codes[i] <= 0x30 do + i += 1; + end; + if codes[i] ~= 0x29 then + return "invalid group structure"; + end; + ret[1], ret[2], ret[4] = "recurmatch", tonumber(utf8_sub(codes.s, org_i, i)), nil; + elseif codes[i] == 0x3C and codes[i + 1] == 0x21 or codes[i + 1] == 0x3D then + -- lookbehinds + i += 1; + ret[4], ret[5] = codes[i], 1; + elseif codes[i] == 0x7C then + -- branch reset + ret[5] = group_n; + elseif codes[i] == 0x50 or codes[i] == 0x3C or codes[i] == 0x27 then + if codes[i] == 0x50 then + i += 1; + end; + if codes[i] == 0x3D then + -- backref + local start_i = i + 1; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if not codes[i] then + return "unterminated parenthetical"; + elseif codes[i] ~= 0x29 or i == start_i then + return "invalid group structure"; + end; + ret = { "backref", utf8_sub(codes.s, start_i, i) }; + elseif codes[i] == 0x3C or codes[i - 1] ~= 0x50 and codes[i] == 0x27 then + -- named capture + local delimiter = codes[i] == 0x27 and 0x27 or 0x3E; + local start_i = i + 1; + i += 1; + if codes[i] == 0x29 then + return "missing character in subpattern"; + elseif codes[i] >= 0x30 and codes[i] <= 0x39 then + return "subpattern name must not begin with a digit"; + elseif not (codes[i] >= 0x41 and codes[i] <= 0x5A or codes[i] >= 0x61 and codes[i] <= 0x7A or codes[i] == 0x5F) then + return "invalid character in subpattern"; + end; + i += 1; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if not codes[i] then + return "unterminated parenthetical"; + elseif codes[i] ~= delimiter then + return "invalid character in subpattern"; + end; + local name = utf8_sub(codes.s, start_i, i); + group_n += 1; + if (group_id[name] or group_n) ~= group_n then + return "subpattern name already exists"; + end; + for name1, group_n1 in pairs(group_id) do + if name ~= name1 and group_n == group_n1 then + return "different names for subpatterns of the same number aren't permitted"; + end; + end; + group_id[name] = group_n; + ret[2], ret[4] = group_n, nil; + else + return "invalid group structure"; + end; + elseif not other_valid_group_char[codes[i]] then + return "invalid group structure"; + end; + else + group_n += 1; + ret = { 0x28, group_n, nil, nil }; + end; + if ret then + table.insert(outln, ret); + end; + elseif c == 0x29 then + -- Close parenthesis + local i1 = #outln + 1; + local lookbehind_c = -1; + local current_lookbehind_c = 0; + local max_c, group_c = 0, 0; + repeat + i1 -= 1; + local v, is_table = outln[i1], type(outln[i1]) == "table"; + if is_table and v[1] == 0x28 then + group_c += 1; + if current_lookbehind_c and v.count then + current_lookbehind_c += v.count; + end; + if not v[3] then + if v[4] == 0x7C then + group_n = v[5] + math.max(max_c, group_c); + end; + if current_lookbehind_c ~= lookbehind_c and lookbehind_c ~= -1 then + lookbehind_c = nil; + else + lookbehind_c = current_lookbehind_c; + end; + break; + end; + elseif v == alternation then + if current_lookbehind_c ~= lookbehind_c and lookbehind_c ~= -1 then + lookbehind_c, current_lookbehind_c = nil, nil; + else + lookbehind_c, current_lookbehind_c = current_lookbehind_c, 0; + end; + max_c, group_c = math.max(max_c, group_c), 0; + elseif current_lookbehind_c then + if is_table and v[1] == "quantifier" then + if v[2] == v[3] then + current_lookbehind_c += v[2]; + else + current_lookbehind_c = nil; + end; + else + current_lookbehind_c += 1; + end; + end; + until i1 < 1; + if i1 < 1 then + return "unmatched ) in regular expression"; + end; + local v = outln[i1]; + local outln_len_p_1 = #outln + 1; + local ret = { 0x29, v[2], i1, v[4], v[5], count = lookbehind_c }; + if (v[4] == 0x21 or v[4] == 0x3D) and v[5] and not lookbehind_c then + return "lookbehind assertion is not fixed width"; + end; + v[3] = outln_len_p_1; + table.insert(outln, ret); + elseif c == 0x2E then + table.insert(outln, dot); + elseif c == 0x5B then + -- Character set + local negate, char_class = false, nil; + i += 1; + local start_i = i; + if codes[i] == 0x5E then + negate = true; + i += 1; + elseif codes[i] == 0x2E or codes[i] == 0x3A or codes[i] == 0x3D then + -- POSIX character classes + char_class = codes[i]; + end; + local ret; + if codes[i] == 0x5B or codes[i] == 0x5C then + ret = { }; + else + ret = { codes[i] }; + i += 1; + end; + while codes[i] ~= 0x5D do + if not codes[i] then + return "unterminated character class"; + elseif codes[i] == 0x2D and ret[1] and type(ret[1]) == "number" then + if codes[i + 1] == 0x5D then + table.insert(ret, 1, 0x2D); + else + i += 1; + local ret_c = codes[i]; + if ret_c == 0x5B then + if codes[i + 1] == 0x2E or codes[i + 1] == 0x3A or codes[i + 1] == 0x3D then + -- Check for POSIX character class, name does not matter + local i1 = i + 2; + repeat + i1 = table.find(codes, 0x5D, i1); + until not i1 or codes[i1 - 1] ~= 0x5C; + if not i1 then + return "unterminated character class"; + elseif codes[i1 - 1] == codes[i + 1] and i1 - 1 ~= i + 1 then + return "invalid range in character class"; + end; + end; + if ret[1] > 0x5B then + return "invalid range in character class"; + end; + elseif ret_c == 0x5C then + i += 1; + if codes[i] == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + ret_c = radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0; + elseif codes[i] >= 0x30 and codes[i] <= 0x37 then + local radix0, radix1, radix2 = codes[i] - 0x30, nil, nil; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + ret_c = radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0; + else + ret_c = escape_chars[codes[i]] or codes[i]; + if type(ret_c) ~= "number" then + return "invalid range in character class"; + end; + end; + elseif ret[1] > ret_c then + return "invalid range in character class"; + end; + ret[1] = { "range", ret[1], ret_c }; + end; + elseif codes[i] == 0x5B then + if codes[i + 1] == 0x2E or codes[i + 1] == 0x3A or codes[i + 1] == 0x3D then + local i1 = i + 2; + repeat + i1 = table.find(codes, 0x5D, i1); + until not i1 or codes[i1 - 1] ~= 0x5C; + if not i1 then + return "unterminated character class"; + elseif codes[i1 - 1] ~= codes[i + 1] or i1 - 1 == i + 1 then + table.insert(ret, 1, 0x5B); + elseif codes[i1 - 1] == 0x2E or codes[i1 - 1] == 0x3D then + return "POSIX collating elements aren't supported"; + elseif codes[i1 - 1] == 0x3A then + -- I have no plans to support escape codes (\) in character class names + local negate = codes[i + 3] == 0x5E; + local class_name = utf8_sub(codes.s, i + (negate and 3 or 2), i1 - 1); + -- If not valid then throw an error + if not posix_class_names[class_name] then + return "unknown POSIX class name"; + end; + table.insert(ret, 1, { "class", class_name, negate }); + i = i1; + end; + else + table.insert(ret, 1, 0x5B); + end; + elseif codes[i] == 0x5C then + i += 1; + if codes[i] == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] == 0x7B then + i += 1; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed hexadecimal character"; + elseif i - org_i > 4 then + return "character offset too large"; + end; + table.insert(ret, 1, tonumber(utf8_sub(codes.s, org_i, i), 16)); + else + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(ret, 1, radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0); + end; + elseif codes[i] >= 0x30 and codes[i] <= 0x37 then + local radix0, radix1, radix2 = codes[i] - 0x30, nil, nil; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(ret, 1, radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0); + elseif codes[i] == 0x45 then + -- intentionally left blank, \E that's not preceded \Q is ignored + elseif codes[i] == 0x51 then + local start_i = i + 1; + repeat + i = table.find(codes, 0x5C, i + 1); + until not i or codes[i + 1] == 0x45; + table.move(codes, start_i, i and i - 1 or #codes, #outln + 1, outln); + if not i then + break; + end; + i += 1; + elseif codes[i] == 0x4E then + if codes[i + 1] == 0x7B and codes[i + 2] == 0x55 and codes[i + 3] == 0x2B and flags.unicode then + i += 4; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == start_i then + return "malformed Unicode code point"; + end; + local code_point = tonumber(utf8_sub(codes.s, start_i, i)); + table.insert(ret, 1, code_point); + else + return "invalid escape sequence"; + end; + elseif codes[i] == 0x50 or codes[i] == 0x70 then + if not options.unicodeData then + return "options.unicodeData cannot be turned off when using \\p"; + end; + i += 1; + if codes[i] ~= 0x7B then + local c_name = utf8.char(codes[i] or 0); + if not valid_categories[c_name] then + return "unknown or malformed script name"; + end; + table.insert(ret, 1, { "category", false, c_name }); + else + local negate = codes[i] == 0x50; + i += 1; + if codes[i] == 0x5E then + i += 1; + negate = not negate; + end; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if codes[i] ~= 0x7D then + return "unknown or malformed script name"; + end; + local c_name = utf8_sub(codes.s, start_i, i); + local script_set = chr_scripts[c_name]; + if script_set then + table.insert(ret, 1, { "charset", negate, script_set }); + elseif not valid_categories[c_name] then + return "unknown or malformed script name"; + else + table.insert(ret, 1, { "category", negate, c_name }); + end; + end; + elseif codes[i] == 0x6F then + i += 1; + if codes[i] ~= 0x7B then + return "malformed octal code"; + end; + i += 1; + local org_i = i; + while codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed octal code"; + end; + local ret_chr = tonumber(utf8_sub(codes.s, org_i, i), 8); + if ret_chr > 0xFFFF then + return "character offset too large"; + end; + table.insert(ret, 1, ret_chr); + else + local esc_char = escape_chars[codes[i]]; + table.insert(ret, 1, type(esc_char) == "string" and { "class", esc_char, false } or esc_char or codes[i]); + end; + elseif flags.ignoreCase and codes[i] >= 0x61 and codes[i] <= 0x7A then + table.insert(ret, 1, codes[i] - 0x20); + else + table.insert(ret, 1, codes[i]); + end; + i += 1; + end; + if codes[i - 1] == char_class and i - 1 ~= start_i then + return char_class == 0x3A and "POSIX named classes are only support within a character set" or "POSIX collating elements aren't supported"; + end; + if not ret[2] and not negate then + table.insert(outln, ret[1]); + else + table.insert(outln, { "charset", negate, ret }); + end; + elseif c == 0x5C then + -- Escape char + i += 1; + local escape_c = codes[i]; + if not escape_c then + return "pattern may not end with a trailing backslash"; + elseif escape_c >= 0x30 and escape_c <= 0x39 then + local org_i = i; + while codes[i + 1] and codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39 do + i += 1; + end; + local escape_d = tonumber(utf8_sub(codes.s, org_i, i + 1)); + if escape_d > group_n and i ~= org_i then + i = org_i; + local radix0, radix1, radix2; + if codes[i] <= 0x37 then + radix0 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + end; + table.insert(outln, radix0 and (radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0) or codes[org_i]); + else + table.insert(outln, { "backref", escape_d }); + end; + elseif escape_c == 0x45 then + -- intentionally left blank, \E that's not preceded \Q is ignored + elseif escape_c == 0x51 then + local start_i = i + 1; + repeat + i = table.find(codes, 0x5C, i + 1); + until not i or codes[i + 1] == 0x45; + table.move(codes, start_i, i and i - 1 or #codes, #outln + 1, outln); + if not i then + break; + end; + i += 1; + elseif escape_c == 0x4E then + if codes[i + 1] == 0x7B and codes[i + 2] == 0x55 and codes[i + 3] == 0x2B and flags.unicode then + i += 4; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == start_i then + return "malformed Unicode code point"; + end; + local code_point = tonumber(utf8_sub(codes.s, start_i, i)); + table.insert(outln, code_point); + else + table.insert(outln, escape_chars[0x4E]); + end; + elseif escape_c == 0x50 or escape_c == 0x70 then + if not options.unicodeData then + return "options.unicodeData cannot be turned off when using \\p"; + end; + i += 1; + if codes[i] ~= 0x7B then + local c_name = utf8.char(codes[i] or 0); + if not valid_categories[c_name] then + return "unknown or malformed script name"; + end; + table.insert(outln, { "category", false, c_name }); + else + local negate = escape_c == 0x50; + i += 1; + if codes[i] == 0x5E then + i += 1; + negate = not negate; + end; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if codes[i] ~= 0x7D then + return "unknown or malformed script name"; + end; + local c_name = utf8_sub(codes.s, start_i, i); + local script_set = chr_scripts[c_name]; + if script_set then + table.insert(outln, { "charset", negate, script_set }); + elseif not valid_categories[c_name] then + return "unknown or malformed script name"; + else + table.insert(outln, { "category", negate, c_name }); + end; + end; + elseif escape_c == 0x67 and (codes[i + 1] == 0x7B or codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39) then + local is_grouped = false; + i += 1; + if codes[i] == 0x7B then + i += 1; + is_grouped = true; + elseif codes[i] < 0x30 or codes[i] > 0x39 then + return "malformed reference code"; + end; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if is_grouped and codes[i] ~= 0x7D then + return "malformed reference code"; + end; + local ref_name = tonumber(utf8_sub(codes.s, org_i, i + (is_grouped and 0 or 1))); + table.insert(outln, { "backref", ref_name }); + if not is_grouped then + i -= 1; + end; + elseif escape_c == 0x6F then + i += 1; + if codes[i + 1] ~= 0x7B then + return "malformed octal code"; + end + i += 1; + local org_i = i; + while codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed octal code"; + end; + local ret_chr = tonumber(utf8_sub(codes.s, org_i, i), 8); + if ret_chr > 0xFFFF then + return "character offset too large"; + end; + table.insert(outln, ret_chr); + elseif escape_c == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] == 0x7B then + i += 1; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed hexadecimal code"; + elseif i - org_i > 4 then + return "character offset too large"; + end; + table.insert(outln, tonumber(utf8_sub(codes.s, org_i, i), 16)); + else + if codes[i] and (codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66) then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and (codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66) then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(outln, radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0); + end; + else + local esc_char = b_escape_chars[escape_c] or escape_chars[escape_c]; + table.insert(outln, esc_char or escape_c); + end; + elseif c == 0x2A or c == 0x2B or c == 0x3F or c == 0x7B then + -- Quantifier + local start_q, end_q; + if c == 0x7B then + local org_i = i + 1; + local start_i; + while codes[i + 1] and (codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39 or codes[i + 1] == 0x2C and not start_i and i + 1 ~= org_i) do + i += 1; + if codes[i] == 0x2C then + start_i = i; + end; + end; + if codes[i + 1] == 0x7D then + i += 1; + if not start_i then + start_q = tonumber(utf8_sub(codes.s, org_i, i)); + end_q = start_q; + else + start_q, end_q = tonumber(utf8_sub(codes.s, org_i, start_i)), start_i + 1 == i and math.huge or tonumber(utf8_sub(codes.s, start_i + 1, i)); + if end_q < start_q then + return "numbers out of order in {} quantifier"; + end; + end; + else + table.move(codes, org_i - 1, i, #outln + 1, outln); + end; + else + start_q, end_q = c == 0x2B and 1 or 0, c == 0x3F and 1 or math.huge; + end; + if start_q then + local quantifier_type = flags.ungreedy and "lazy" or "greedy"; + if codes[i + 1] == 0x2B or codes[i + 1] == 0x3F then + i += 1; + quantifier_type = codes[i] == 0x2B and "possessive" or flags.ungreedy and "greedy" or "lazy"; + end; + local outln_len = #outln; + local last_outln_value = outln[outln_len]; + if not last_outln_value or type(last_outln_value) == "table" and (last_outln_value[1] == "quantifier" or last_outln_value[1] == 0x28 or b_escape_chars[last_outln_value[1]]) + or last_outln_value == alternation or type(last_outln_value) == "string" then + return "quantifier doesn't follow a repeatable pattern"; + end; + if end_q == 0 then + table.remove(outln); + elseif start_q ~= 1 or end_q ~= 1 then + if type(last_outln_value) == "table" and last_outln_value[1] == 0x29 then + outln_len = last_outln_value[3]; + end; + outln[outln_len] = { "quantifier", start_q, end_q, quantifier_type, outln[outln_len] }; + end; + end; + elseif c == 0x7C then + -- Alternation + table.insert(outln, alternation); + local i1 = #outln; + repeat + i1 -= 1; + local v1, is_table = outln[i1], type(outln[i1]) == "table"; + if is_table and v1[1] == 0x29 then + i1 = outln[i1][3]; + elseif is_table and v1[1] == 0x28 then + if v1[4] == 0x7C then + group_n = v1[5]; + end; + break; + end; + until not v1; + elseif c == 0x24 or c == 0x5E then + table.insert(outln, c == 0x5E and beginning_str or end_str); + elseif flags.ignoreCase and c >= 0x61 and c <= 0x7A then + table.insert(outln, c - 0x20); + elseif flags.extended and (c >= 0x09 and c <= 0x0D or c == 0x20 or c == 0x23) then + if c == 0x23 then + repeat + i += 1; + until not codes[i] or codes[i] == 0x0A or codes[i] == 0x0D; + end; + else + table.insert(outln, c); + end; + i += 1; + end; + local max_group_n = 0; + for i, v in ipairs(outln) do + if type(v) == "table" and (v[1] == 0x28 or v[1] == "quantifier" and type(v[5]) == "table" and v[5][1] == 0x28) then + if v[1] == "quantifier" then + v = v[5]; + end; + if not v[3] then + return "unterminated parenthetical"; + elseif v[2] then + max_group_n = math.max(max_group_n, v[2]); + end; + elseif type(v) == "table" and (v[1] == "backref" or v[1] == "recurmatch") then + if not group_id[v[2]] and (type(v[2]) ~= "number" or v[2] > group_n) then + return "reference to a non-existent or invalid subpattern"; + elseif v[1] == "recurmatch" and v[2] ~= 0 then + for i1, v1 in ipairs(outln) do + if type(v1) == "table" and v1[1] == 0x28 and v1[2] == v[2] then + v[3] = i1; + break; + end; + end; + elseif type(v[2]) == "string" then + v[2] = group_id[v[2]]; + end; + end; + end; + outln.group_n = max_group_n; + return outln, group_id, verb_flags; +end; + +if not tonumber(options.cacheSize) then + error(string.format("expected number for options.cacheSize, got %s", typeof(options.cacheSize)), 2); +end; +local cacheSize = math.floor(options.cacheSize or 0) ~= 0 and tonumber(options.cacheSize); +local cache_pattern, cache_pattern_names; +if not cacheSize then +elseif cacheSize < 0 or cacheSize ~= cacheSize then + error("cache size cannot be a negative number or a NaN", 2); +elseif cacheSize == math.huge then + cache_pattern, cache_pattern_names = { nil }, { nil }; +elseif cacheSize >= 2 ^ 32 then + error("cache size too large", 2); +else + cache_pattern, cache_pattern_names = table.create(options.cacheSize), table.create(options.cacheSize); +end; +if cacheSize then + function re.pruge() + table.clear(cache_pattern_names); + table.clear(cache_pattern); + end; +end; + +local function new_re(str_arr, flags, flag_repr, pattern_repr) + local tokenized_ptn, group_id, verb_flags; + local cache_format = cacheSize and string.format("%s|%s", str_arr.s, flag_repr); + local cached_token = cacheSize and cache_pattern[table.find(cache_pattern_names, cache_format)]; + if cached_token then + tokenized_ptn, group_id, verb_flags = table.unpack(cached_token, 1, 3); + else + tokenized_ptn, group_id, verb_flags = tokenize_ptn(str_arr, flags); + if type(tokenized_ptn) == "string" then + error(tokenized_ptn, 2); + end; + if cacheSize and tokenized_ptn[1] then + table.insert(cache_pattern_names, 1, cache_format); + table.insert(cache_pattern, 1, { tokenized_ptn, group_id, verb_flags }); + if cacheSize ~= math.huge then + table.remove(cache_pattern_names, cacheSize + 1); + table.remove(cache_pattern, cacheSize + 1); + end; + end; + end; + + local object = newproxy(true); + proxy[object] = { name = "RegEx", flags = flags, flag_repr = flag_repr, pattern_repr = pattern_repr, token = tokenized_ptn, group_id = group_id, verb_flags = verb_flags }; + local object_mt = getmetatable(object); + object_mt.__index = setmetatable(flags, re_m); + object_mt.__tostring = re_tostr; + object_mt.__metatable = lockmsg; + + return object; +end; + +local function escape_fslash(pre) + return (#pre % 2 == 0 and '\\' or '') .. pre .. '.'; +end; + +local function sort_flag_chr(a, b) + return a:lower() < b:lower(); +end; + +function re.new(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local ptn, flags_str = ...; + if type(ptn) == "number" then + ptn ..= ''; + elseif type(ptn) ~= "string" then + error(string.format("invalid argument #1 (string expected, got %s)", typeof(ptn)), 2); + end; + if type(flags_str) ~= "string" and type(flags_str) ~= "number" and flags_str ~= nil then + error(string.format("invalid argument #2 (string expected, got %s)", typeof(flags_str)), 2); + end; + + local flags = { + anchored = false, caseless = false, multiline = false, dotall = false, unicode = false, ungreedy = false, extended = false, + }; + local flag_repr = { }; + for f in string.gmatch(flags_str or '', utf8.charpattern) do + if flags[flag_map[f]] ~= false then + error("invalid regular expression flag " .. f, 3); + end; + flags[flag_map[f]] = true; + table.insert(flag_repr, f); + end; + table.sort(flag_repr, sort_flag_chr); + flag_repr = table.concat(flag_repr); + return new_re(to_str_arr(ptn), flags, flag_repr, string.format("/%s/", ptn:gsub("(\\*)/", escape_fslash))); +end; + +function re.fromstring(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local ptn = ...; + if type(ptn) == "number" then + ptn ..= ''; + elseif type(ptn) ~= "string" then + error(string.format("invalid argument #1 (string expected, got %s)", typeof(ptn), 2)); + end; + local str_arr = to_str_arr(ptn); + local delimiter = str_arr[1]; + if not delimiter then + error("empty regex", 2); + elseif delimiter == 0x5C or (delimiter >= 0x30 and delimiter <= 0x39) or (delimiter >= 0x41 and delimiter <= 0x5A) or (delimiter >= 0x61 and delimiter <= 0x7A) then + error("delimiter must not be alphanumeric or a backslash", 2); + end; + + local i0 = 1; + repeat + i0 = table.find(str_arr, delimiter, i0 + 1); + if not i0 then + error(string.format("no ending delimiter ('%s') found", utf8.char(delimiter)), 2); + end; + local escape_count = 1; + while str_arr[i0 - escape_count] == 0x5C do + escape_count += 1; + end; + until escape_count % 2 == 1; + + local flags = { + anchored = false, caseless = false, multiline = false, dotall = false, unicode = false, ungreedy = false, extended = false, + }; + local flag_repr = { }; + while str_arr.n > i0 do + local f = utf8.char(table.remove(str_arr)); + str_arr.n -= 1; + if flags[flag_map[f]] ~= false then + error("invalid regular expression flag " .. f, 3); + end; + flags[flag_map[f]] = true; + table.insert(flag_repr, f); + end; + table.sort(flag_repr, sort_flag_chr); + flag_repr = table.concat(flag_repr); + table.remove(str_arr, 1); + table.remove(str_arr); + str_arr.n -= 2; + str_arr.s = string.sub(str_arr.s, 2, 1 + str_arr.n); + return new_re(str_arr, flags, flag_repr, string.sub(ptn, 1, 2 + str_arr.n)); +end; + +local re_escape_line_chrs = { + ['\0'] = '\\x00', ['\n'] = '\\n', ['\t'] = '\\t', ['\r'] = '\\r', ['\f'] = '\\f', +}; + +function re.escape(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local str, extended, delimiter = ...; + if type(str) == "number" then + str ..= ''; + elseif type(str) ~= "string" then + error(string.format("invalid argument #1 to 'escape' (string expected, got %s)", typeof(str)), 2); + end; + if delimiter == nil then + delimiter = ''; + elseif type(delimiter) == "number" then + delimiter ..= ''; + elseif type(delimiter) ~= "string" then + error(string.format("invalid argument #3 to 'escape' (string expected, got %s)", typeof(delimiter)), 2); + end; + if utf8.len(delimiter) > 1 or delimiter:match("^[%a\\]$") then + error("delimiter have not be alphanumeric", 2); + end; + return (string.gsub(str, "[\0\f\n\r\t]", re_escape_line_chrs):gsub(string.format("[\\%s#()%%%%*+.?[%%]^{|%s]", extended and '%s' or '', (delimiter:find'^[%%%]]$' and '%' or '') .. delimiter), "\\%1")); +end; + +function re.type(...) + if select('#', ...) == 0 then + error("missing argument #1", 2); + end; + return proxy[...] and proxy[...].name; +end; + +for k, f in pairs(re_m) do + re[k] = f; +end; + +re_m = { __index = re_m }; + +lockmsg = re.fromstring([[/The\s*metatable\s*is\s*(?:locked|inaccessible)(?#Nice try :])/i]]); +getmetatable(lockmsg).__metatable = lockmsg; + +local function readonly_table() + error("Attempt to modify a readonly table", 2); +end; + +match_m = { + __index = match_m, + __metatable = lockmsg, + __newindex = readonly_table, +}; + +re.Match = setmetatable({ }, match_m); + +return setmetatable({ }, { + __index = re, + __metatable = lockmsg, + __newindex = readonly_table, +}); diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 7f863c6..15813ae 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -213,6 +213,16 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") SINGLE_COMPARE(lea(rax, qword[r13 + r12 * 4 + 4]), 0x4b, 0x8d, 0x44, 0xa5, 0x04); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") +{ + SINGLE_COMPARE(jmp(rax), 0x48, 0xff, 0xe0); + SINGLE_COMPARE(jmp(r14), 0x49, 0xff, 0xe6); + SINGLE_COMPARE(jmp(qword[r14 + rdx * 4]), 0x49, 0xff, 0x24, 0x96); + SINGLE_COMPARE(call(rax), 0x48, 0xff, 0xd0); + SINGLE_COMPARE(call(r14), 0x49, 0xff, 0xd6); + SINGLE_COMPARE(call(qword[r14 + rdx * 4]), 0x49, 0xff, 0x14, 0x96); +} + TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") { // Jump back @@ -260,6 +270,23 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "ControlFlow") {0xe9, 0x04, 0x00, 0x00, 0x00, 0x48, 0x83, 0xe7, 0x3e}); } +TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "LabelCall") +{ + check( + [](AssemblyBuilderX64& build) { + Label fnB; + + build.and_(rcx, 0x3e); + build.call(fnB); + build.ret(); + + build.setLabel(fnB); + build.lea(rax, qword[rcx + 0x1f]); + build.ret(); + }, + {0x48, 0x83, 0xe1, 0x3e, 0xe8, 0x01, 0x00, 0x00, 0x00, 0xc3, 0x48, 0x8d, 0x41, 0x1f, 0xc3}); +} + TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") { SINGLE_COMPARE(vaddpd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0xa9, 0x58, 0xc6); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index f001750..6ec1426 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -105,4 +105,37 @@ if true then REQUIRE(parentStat->is()); } +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_at_number_const") +{ + check(R"( +print(3.) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 8)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_in_workspace_dot") +{ + check(R"( +print(workspace.) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 16)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "ac_ast_ancestry_in_workspace_colon") +{ + check(R"( +print(workspace:) + )"); + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*getMainSourceModule(), Position(1, 16)); + REQUIRE_GE(ancestry.size(), 2); + REQUIRE(ancestry.back()->is()); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.h b/tests/Fixture.h index 1bc573d..4bd6f1e 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -128,6 +128,7 @@ struct Fixture std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; + ScopedFastFlag sff_UnknownNever{"LuauUnknownAndNeverType", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index b9c2470..4f33139 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -791,6 +791,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") CHECK_EQ(0, module->internalTypes.typeVars.size()); CHECK_EQ(0, module->internalTypes.typePacks.size()); CHECK_EQ(0, module->astTypes.size()); + CHECK_EQ(0, module->astResolvedTypes.size()); + CHECK_EQ(0, module->astResolvedTypePacks.size()); } TEST_CASE_FIXTURE(FrontendFixture, "it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 7c2f4d1..dd94e9d 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -301,8 +301,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - fileResolver.source["Module/A"] = R"( export type A = B type B = A diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index a474b6e..fb0a899 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -1055,8 +1055,6 @@ export type t1 = { a: typeof(string.byte) } TEST_CASE_FIXTURE(Fixture, "intersection_combine_on_bound_self") { - ScopedFastFlag luauNormalizeCombineEqFix{"LuauNormalizeCombineEqFix", true}; - CheckResult result = check(R"( export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,})) )"); @@ -1064,6 +1062,46 @@ export type t0 = (((any)&({_:l0.t0,n0:t0,_G:any,}))&({_:any,}))&(((any)&({_:l0.t LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "normalize_unions_containing_never") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = string | never + local foo: Foo + )"); + + CHECK_EQ("string", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "normalize_unions_containing_unknown") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = string | unknown + local foo: Foo + )"); + + CHECK_EQ("unknown", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "any_wins_the_battle_over_unknown_in_unions") +{ + ScopedFastFlag sff{"LuauLowerBoundsCalculation", true}; + + CheckResult result = check(R"( + type Foo = unknown | any + local foo: Foo + + type Bar = any | unknown + local bar: Bar + )"); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever") { ScopedFastFlag sff[]{ diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index c3c7599..c517853 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2648,7 +2648,6 @@ type Z = { a: string | T..., b: number } TEST_CASE_FIXTURE(Fixture, "recover_function_return_type_annotations") { - ScopedFastFlag sff{"LuauReturnTypeTokenConfusion", true}; ParseResult result = tryParse(R"( type Custom = { x: A, y: B, z: C } type Packed = { x: (A...) -> () } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 387e07c..52a29bc 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -96,6 +96,37 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") //clang-format on } +TEST_CASE_FIXTURE(Fixture, "metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable})}; + CHECK_EQ("{ @metatable { }, { } }", toString(&mtv)); +} + +TEST_CASE_FIXTURE(Fixture, "named_metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable, "NamedMetatable"})}; + CHECK_EQ("NamedMetatable", toString(&mtv)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "named_metatable_toStringNamedFunction") +{ + CheckResult result = check(R"( + local function createTbl(): NamedMetatable + return setmetatable({}, {}) + end + type NamedMetatable = typeof(createTbl()) + )"); + + TypeId ty = requireType("createTbl"); + const FunctionTypeVar* ftv = get(follow(ty)); + REQUIRE(ftv); + CHECK_EQ("createTbl(): NamedMetatable", toStringNamedFunction("createTbl", *ftv)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( @@ -468,7 +499,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target"))); + CHECK_EQ("(nil) -> ()", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index b02a52b..d2ed9ae 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -583,7 +583,7 @@ TEST_CASE_FIXTURE(Fixture, "transpile_error_expr") auto names = AstNameTable{allocator}; ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); - CHECK_EQ("local a = (error-expr: f.%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); + CHECK_EQ("local a = (error-expr: f:%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); } TEST_CASE_FIXTURE(Fixture, "transpile_error_stat") diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index bc55940..4c5309e 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -94,7 +94,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("*unknown*", toString(requireType("a"))); + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -110,7 +110,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("*unknown*", toString(requireType("a"))); + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -225,7 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK_EQ("*unknown*", toString(requireType("a"))); + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -234,7 +234,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - CHECK_EQ("*unknown*", toString(requireType("a"))); + CHECK_EQ("", toString(requireType("a"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 2f0266e..7d1bd6b 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -925,7 +925,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "assert_returns_false_and_string_iff_it_knows )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(nil) -> nil", toString(requireType("f"))); + CHECK_EQ("(nil) -> (never, ...never)", toString(requireType("f"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") @@ -952,7 +952,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("", toString(requireType("d"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") @@ -965,8 +965,8 @@ a:b() a:b({}) )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 0}, {2, 5}}, CountMismatch{2, 0}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{3, 0}, {3, 5}}, CountMismatch{2, 1}})); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function expects 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function expects 2 arguments, but only 1 is specified"); } TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") @@ -1008,4 +1008,139 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = string.gmatch("This is a string", "(.()(%a+))")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types2") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = ("This is a string"):gmatch("(.()(%a+))")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); + CHECK_EQ(toString(requireType("c")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_default_capture") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c, d = string.gmatch("T(his)() is a string", ".")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("a")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_balanced_escaped_parens") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c, d = string.gmatch("T(his) is a string", "((.)%b()())")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 3); + CHECK_EQ(acm->actual, 4); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "string"); + CHECK_EQ(toString(requireType("c")), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_parens_in_sets_are_ignored") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b, c = string.gmatch("T(his)() is a string", "(T[()])()")() + )END"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 3); + + CHECK_EQ(toString(requireType("a")), "string"); + CHECK_EQ(toString(requireType("b")), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_set_containing_lbracket") +{ + ScopedFastFlag sffs{"LuauDeduceGmatchReturnTypes", true}; + CheckResult result = check(R"END( + local a, b = string.gmatch("[[[", "()([[])")() + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "number"); + CHECK_EQ(toString(requireType("b")), "string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_leading_end_bracket_is_part_of_set") +{ + CheckResult result = check(R"END( + -- An immediate right-bracket following a left-bracket is included within the set; + -- thus, '[]]'' is the set containing ']', and '[]' is an invalid set missing an enclosing + -- right-bracket. We detect an invalid set in this case and fall back to to default gmatch + -- typing. + local foo = string.gmatch("T[hi%]s]]]() is a string", "([]s)") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallback_to_builtin") +{ + CheckResult result = check(R"END( + local foo = string.gmatch("T(his)() is a string", ")") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallback_to_builtin2") +{ + CheckResult result = check(R"END( + local foo = string.gmatch("T(his)() is a string", "[") + )END"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 401a6c6..6e6549d 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -916,13 +916,13 @@ TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") REQUIRE(tm1); CHECK_EQ("(string) -> number", toString(tm1->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); + CHECK_EQ("(string, ) -> number", toString(tm1->givenType)); auto tm2 = get(result.errors[1]); REQUIRE(tm2); CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); + CHECK_EQ("(string, ) -> number", toString(tm2->givenType)); } TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") @@ -1535,7 +1535,7 @@ function t:b() return 2 end -- not OK )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '(*unknown*) -> number' could not be converted into '() -> number' + CHECK_EQ(R"(Type '() -> number' could not be converted into '() -> number' caused by: Argument count mismatch. Function expects 1 argument, but none are specified)", toString(result.errors[0])); @@ -1692,4 +1692,52 @@ TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantifie // TODO: check the normalized type of f } +TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_unknown") +{ + CheckResult result = check(R"( + local function foo(f: (unknown) -> (), x) + f(x) + end + )"); + + CHECK_EQ("((unknown) -> (), a) -> ()", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_infer_parameter_types_for_functions_from_their_call_site") +{ + CheckResult result = check(R"( + local t = {} + + function t.f(x) + return x + end + + t.__index = t + + function g(s) + local q = s.p and s.p.q or nil + return q and t.f(q) or nil + end + + local f = t.f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(a) -> a", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_calling_with_self") +{ + CheckResult result = check(R"( + local t = {} + function t:m(x) end + function f(): never return 5 :: never end + t:m(f()) + t:m(f()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index e9e94cf..4625807 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,6 +9,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauCheckGenericHOFTypes) + using namespace Luau; TEST_SUITE_BEGIN("GenericsTests"); @@ -1001,7 +1003,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*unknown*", toString(t0->type)); + CHECK_EQ("", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -1095,10 +1097,18 @@ local b = sumrec(sum) -- ok local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred )"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); + if (FFlag::LuauCheckGenericHOFTypes) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ( + "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") @@ -1185,4 +1195,23 @@ TEST_CASE_FIXTURE(Fixture, "quantify_functions_even_if_they_have_an_explicit_gen CHECK("((X) -> (a...), X) -> (a...)" == toString(requireType("foo"))); } +TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types") +{ + ScopedFastFlag sff[] = { + {"LuauMaybeGenericIntersectionTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type Array = { [number]: T } + + type Array_Statics = { + new: () -> Array, + } + + local _Arr : Array & Array_Statics = {} :: Array_Statics + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 1c6fe1d..56b807f 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -142,7 +142,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") CHECK_EQ(2, result.errors.size()); TypeId p = requireType("p"); - CHECK_EQ("*unknown*", toString(p)); + CHECK_EQ("", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index a0f670f..2343a7f 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -143,7 +143,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - CHECK_EQ("*unknown*", toString(hootyType)); + CHECK_EQ("", toString(hootyType)); } TEST_CASE_FIXTURE(BuiltinsFixture, "warn_if_you_try_to_require_a_non_modulescript") @@ -244,7 +244,7 @@ local ModuleA = require(game.A) LUAU_REQUIRE_NO_ERRORS(result); std::optional oty = requireType("ModuleA"); - CHECK_EQ("*unknown*", toString(*oty)); + CHECK_EQ("", toString(*oty)); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types") @@ -302,6 +302,30 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_4") +{ + fileResolver.source["game/A"] = R"( +export type Array = {T} +local arrayops = {} +function arrayops.foo(x: Array) end +return arrayops + )"; + + CheckResult result = check(R"( +local arrayops = require(game.A) + +local tbl = {} +tbl.a = 2 +function tbl:foo(b: number, c: number) + -- introduce BoundTypeVar to imported type + arrayops.foo(self._regions) +end +type Table = typeof(tbl) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict") { fileResolver.source["game/A"] = R"( @@ -363,4 +387,21 @@ caused by: Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_anyification_clone_immutable_types") +{ + ScopedFastFlag luauAnyificationMustClone{"LuauAnyificationMustClone", true}; + + fileResolver.source["game/A"] = R"( +return function(...) end + )"; + + fileResolver.source["game/B"] = R"( +local l0 = require(game.A) +return l0 + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index e6174df..c90f0a4 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -871,4 +871,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "equality_operations_succeed_if_any_union_bra CHECK(toString(result2.errors[0]) == "Types Foo and Bar cannot be compared with == because they do not have the same metatable"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_and") +{ + ScopedFastFlag sff{"LuauBinaryNeedsExpectedTypesToo", true}; + + CheckResult result = check(R"( + local x: "a" | "b" | boolean = math.random() > 0.5 and "a" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "expected_types_through_binary_or") +{ + ScopedFastFlag sff{"LuauBinaryNeedsExpectedTypesToo", true}; + + CheckResult result = check(R"( + local x: "a" | "b" | boolean = math.random() > 0.5 or "b" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index e1684df..9e8e250 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -47,7 +47,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - CHECK_EQ("*unknown*", toString(requireType("t"))); + CHECK_EQ("", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "string_method") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 059aed2..dc68689 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -225,7 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); } -TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) +TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; @@ -499,6 +499,17 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "free_is_not_bound_to_any") +{ + CheckResult result = check(R"( + local function foo(f: (any) -> (), x) + f(x) + end + )"); + + CHECK_EQ("((any) -> (), any) -> ()", toString(requireType("foo"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns") { ScopedFastFlag sff{"DebugLuauSharedSelf", true}; @@ -518,7 +529,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_f )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Not all codepaths in this function return '{ @metatable T, {| |} }, a...'.", toString(result.errors[0])); + CHECK_EQ("Not all codepaths in this function return 'self, a...'.", toString(result.errors[0])); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 3f5dad3..cc8cdee 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -272,8 +272,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); + CHECK_EQ("never", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("never", toString(requireTypeAtPosition({9, 38}))); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") @@ -651,7 +651,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_narrowed_into_nothingness") { CheckResult result = check(R"( local function f(t: {x: number}) @@ -666,7 +666,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_warns_on_no_overlapping_types_onl LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); } TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") @@ -1074,7 +1074,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" + CHECK_EQ("never", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } @@ -1206,6 +1206,24 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") +{ + CheckResult result = check(R"( + local function f(x: unknown) + if type(x) == "string" then + local foo = x + else + local bar = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown", toString(requireTypeAtPosition({5, 28}))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") { ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true}; @@ -1227,4 +1245,19 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_ni CHECK_EQ("number", toString(requireTypeAtPosition({6, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "what_nonsensical_condition") +{ + CheckResult result = check(R"( + local function f(x) + if type(x) == "string" and type(x) == "number" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index eead5b3..3e830f2 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -3070,4 +3070,18 @@ TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all") CHECK_EQ("{| m: ({+ x: a, y: b +}) -> a, n: ({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "leaking_bad_metatable_errors") +{ + ScopedFastFlag luauIndexSilenceErrors{"LuauIndexSilenceErrors", true}; + + CheckResult result = check(R"( +local a = setmetatable({}, 1) +local b = a.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Metatable was not a table", toString(result.errors[0])); + CHECK_EQ("Type 'a' does not have key 'x'", toString(result.errors[1])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index efdfe0b..7d1fb56 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -238,10 +238,10 @@ TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") // TODO: Should we assert anything about these tests when DCR is being used? if (!FFlag::DebugLuauDeferredConstraintResolution) { - CHECK_EQ("*unknown*", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); - CHECK_EQ("*unknown*", toString(requireType("e"))); - CHECK_EQ("*unknown*", toString(requireType("f"))); + CHECK_EQ("", toString(requireType("c"))); + CHECK_EQ("", toString(requireType("d"))); + CHECK_EQ("", toString(requireType("e"))); + CHECK_EQ("", toString(requireType("f"))); } } @@ -622,7 +622,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK_EQ("*unknown*", toString(t0->type)); + CHECK_EQ("", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -1003,4 +1003,27 @@ TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_ )"); } +TEST_CASE_FIXTURE(Fixture, "types stored in astResolvedTypes") +{ + CheckResult result = check(R"( +type alias = typeof("hello") +local function foo(param: alias) +end + )"); + + auto node = findNodeAtPosition(*getMainSourceModule(), {2, 16}); + auto ty = lookupType("alias"); + REQUIRE(node); + REQUIRE(node->is()); + REQUIRE(ty); + + auto func = node->as(); + REQUIRE(func->args.size == 1); + + auto arg = *func->args.begin(); + auto annotation = arg->annotation; + + CHECK_EQ(*getMainModule()->astResolvedTypes.find(annotation), *ty); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 49deae7..d51a38f 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -121,7 +121,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("", toString(requireType("b"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") @@ -136,7 +136,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_con LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("a", toString(requireType("a"))); - CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("", toString(requireType("b"))); CHECK_EQ("number", toString(requireType("c"))); } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 2b48133..9491869 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -199,7 +199,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - CHECK_EQ("*unknown*", toString(requireType("r"))); + CHECK_EQ("", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp new file mode 100644 index 0000000..bc742b0 --- /dev/null +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -0,0 +1,280 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferUnknownNever"); + +TEST_CASE_FIXTURE(Fixture, "string_subtype_and_unknown_supertype") +{ + CheckResult result = check(R"( + local function f(x: string) + local foo: unknown = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_subtype_and_string_supertype") +{ + CheckResult result = check(R"( + local function f(x: unknown) + local foo: string = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_is_reflexive") +{ + CheckResult result = check(R"( + local function f(x: unknown) + local foo: unknown = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_subtype_and_never_supertype") +{ + CheckResult result = check(R"( + local function f(x: string) + local foo: never = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "never_subtype_and_string_supertype") +{ + CheckResult result = check(R"( + local function f(x: never) + local foo: string = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "never_is_reflexive") +{ + CheckResult result = check(R"( + local function f(x: never) + local foo: never = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_is_optional_because_it_too_encompasses_nil") +{ + CheckResult result = check(R"( + local t: {x: unknown} = {} + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_with_prop_of_type_never_is_uninhabitable") +{ + CheckResult result = check(R"( + local t: {x: never} = {} + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "table_with_prop_of_type_never_is_also_reflexive") +{ + CheckResult result = check(R"( + local t: {x: never} = {x = 5 :: never} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "array_like_table_of_never_is_inhabitable") +{ + CheckResult result = check(R"( + local t: {never} = {} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable") +{ + CheckResult result = check(R"( + local function f() return "foo", 5 :: never end + + local x, y, z = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "type_packs_containing_never_is_itself_uninhabitable2") +{ + CheckResult result = check(R"( + local function f(): (string, never) return "", 5 :: never end + local function g(): (never, string) return 5 :: never, "" end + + local x1, x2 = f() + local y1, y2 = g() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x1"))); + CHECK_EQ("never", toString(requireType("x2"))); + CHECK_EQ("never", toString(requireType("y1"))); + CHECK_EQ("never", toString(requireType("y2"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_never") +{ + CheckResult result = check(R"( + local x: never = 5 :: never + local z = x.y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_never") +{ + CheckResult result = check(R"( + local f: never = 5 :: never + local x, y, z = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); + CHECK_EQ("never", toString(requireType("y"))); + CHECK_EQ("never", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_local_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t = 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_global_which_is_never") +{ + CheckResult result = check(R"( + --!nonstrict + t = 5 :: never + t = "" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_prop_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t.x = 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +{ + CheckResult result = check(R"( + local t: never + t[5] = 7 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_to_subscript_which_is_never") +{ + CheckResult result = check(R"( + for i, v in (5 :: never) do + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "pick_never_from_variadic_type_pack") +{ + CheckResult result = check(R"( + local function f(...: never) + local x, y = (...) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_never") +{ + CheckResult result = check(R"( + type Disjoint = {foo: never, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} + local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} + local foo = disjoint.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_union_of_tables_for_properties_that_is_sorta_never") +{ + CheckResult result = check(R"( + type Disjoint = {foo: string, bar: unknown, tag: "ok"} | {foo: never, baz: unknown, tag: "err"} + local disjoint: Disjoint = {foo = 5 :: never, bar = true, tag = "ok"} + local foo = disjoint.foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "unary_minus_of_never") +{ + CheckResult result = check(R"( + local x = -(5 :: never) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "length_of_never") +{ + CheckResult result = check(R"( + local x = #({} :: never) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("never", toString(requireType("x"))); +} + +TEST_SUITE_END(); diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 8a5a65f..35852c0 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -199,8 +199,6 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") TEST_CASE("content_reassignment") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; TypeArena arena; diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 4f8fc50..f467004 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -418,8 +418,6 @@ TEST_CASE("proof_that_isBoolean_uses_all_of") TEST_CASE("content_reassignment") { - ScopedFastFlag luauNonCopyableTypeVarFields{"LuauNonCopyableTypeVarFields", true}; - TypeVar myAny{AnyTypeVar{}, /*presistent*/ true}; myAny.normal = true; myAny.documentationSymbol = "@global/any"; diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 22d6adf..6164e92 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -101,4 +101,20 @@ if vector_size == 4 then assert(vector(1, 2, 3, 4).W == 4) end +-- negative zero should hash the same as zero +-- note: our earlier test only really checks the low hash bit, so in absence of perfect avalanche it's insufficient +do + local larget = {} + for i = 1, 2^14 do + larget[vector(0, 0, i)] = true + end + + larget[vector(0, 0, 0)] = 42 + + assert(larget[vector(0, 0, 0)] == 42) + assert(larget[vector(0, 0, -0)] == 42) + assert(larget[vector(0, -0, 0)] == 42) + assert(larget[vector(-0, 0, 0)] == 42) +end + return 'OK' diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index ccc7e39..cb2f355 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -137,16 +137,28 @@ {data,s} + + + + + + + + + empty + none + + + {proto()->source->data,sb}:{line()} function {proto()->debugname->data,sb}() + {proto()->source->data,sb}:{line()} function() + + + =[C] function {cl().c.debugname,sb}() {cl().c.f,na} + =[C] {cl().c.f,na} + + - - {ci->func->value.gc->cl.c.f,na} - - - {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} {ci->func->value.gc->cl.l.p->debugname->data,sb} - - - {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} - + {ci,na} thread @@ -156,7 +168,7 @@ ci-base_ci - base_ci[ci-base_ci - $i].func->value.gc->cl,view(short) + base_ci[ci-base_ci - $i]