diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 8f17fff..57a1907 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -1,7 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "TypeInfer.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" namespace Luau { diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 946bc92..ac6f13e 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -120,6 +120,7 @@ struct IncorrectGenericParameterCount Name name; TypeFun typeFun; size_t actualParameters; + size_t actualPackParameters; bool operator==(const IncorrectGenericParameterCount& rhs) const; }; diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 71f9464..a05ec5e 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -25,51 +25,39 @@ struct SourceCode Type type; }; +struct ModuleInfo +{ + ModuleName name; + bool optional = false; +}; + struct FileResolver { virtual ~FileResolver() {} - /** Fetch the source code associated with the provided ModuleName. - * - * FIXME: This requires a string copy! - * - * @returns The actual Lua code on success. - * @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error. - */ virtual std::optional readSource(const ModuleName& name) = 0; - /** Does the module exist? - * - * Saves a string copy over reading the source and throwing it away. - */ - virtual bool moduleExists(const ModuleName& name) const = 0; + virtual std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) + { + return std::nullopt; + } - virtual std::optional fromAstFragment(AstExpr* expr) const = 0; - - /** Given a valid module name and a string of arbitrary data, figure out the concatenation. - */ - virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; - - /** Goes "up" a level in the hierarchy that the ModuleName represents. - * - * For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last - * element of the path. Other ModuleName representations may have other ways of doing this. - * - * @returns The parent ModuleName, if one exists. - * @returns std::nullopt if there is no parent for this module name. - */ - virtual std::optional getParentModuleName(const ModuleName& name) const = 0; - - virtual std::optional getHumanReadableModuleName_(const ModuleName& name) const + virtual std::string getHumanReadableModuleName(const ModuleName& name) const { return name; } - virtual std::optional getEnvironmentForModule(const ModuleName& name) const = 0; + virtual std::optional getEnvironmentForModule(const ModuleName& name) const + { + return std::nullopt; + } - /** LanguageService only: - * std::optional fromInstance(Instance* inst) - */ + // DEPRECATED APIS + // These are going to be removed with LuauNewRequireTracer + virtual bool moduleExists(const ModuleName& name) const = 0; + virtual std::optional fromAstFragment(AstExpr* expr) const = 0; + virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; + virtual std::optional getParentModuleName(const ModuleName& name) const = 0; }; struct NullFileResolver : FileResolver @@ -94,10 +82,6 @@ struct NullFileResolver : FileResolver { return std::nullopt; } - std::optional getEnvironmentForModule(const ModuleName& name) const override - { - return std::nullopt; - } }; } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 413b68f..d084483 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -90,10 +90,12 @@ struct Module TypeArena internalTypes; std::vector> scopes; // never empty - std::unordered_map astTypes; - std::unordered_map astExpectedTypes; - std::unordered_map astOriginalCallTypes; - std::unordered_map astOverloadResolvedTypes; + + DenseHashMap astTypes{nullptr}; + DenseHashMap astExpectedTypes{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; + DenseHashMap astOverloadResolvedTypes{nullptr}; + std::unordered_map declaredGlobals; ErrorVec errors; Mode mode; diff --git a/Analysis/include/Luau/ModuleResolver.h b/Analysis/include/Luau/ModuleResolver.h index a394a21..d892ccd 100644 --- a/Analysis/include/Luau/ModuleResolver.h +++ b/Analysis/include/Luau/ModuleResolver.h @@ -15,12 +15,6 @@ struct Module; using ModulePtr = std::shared_ptr; -struct ModuleInfo -{ - ModuleName name; - bool optional = false; -}; - struct ModuleResolver { virtual ~ModuleResolver() {} diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index e977887..c25545f 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -17,12 +17,11 @@ struct AstLocal; struct RequireTraceResult { - DenseHashMap exprs{0}; - DenseHashMap optional{0}; + DenseHashMap exprs{nullptr}; std::vector> requires; }; -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName); +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h new file mode 100644 index 0000000..4533840 --- /dev/null +++ b/Analysis/include/Luau/Scope.h @@ -0,0 +1,67 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/TypeVar.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope; + +using ScopePtr = std::shared_ptr; + +struct Binding +{ + TypeId typeId; + Location location; + bool deprecated = false; + std::string deprecatedSuggestion; + std::optional documentationSymbol; +}; + +struct Scope +{ + explicit Scope(TypePackId returnType); // root scope + explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. + + const ScopePtr parent; // null for the root + std::unordered_map bindings; + TypePackId returnType; + bool breakOk = false; + std::optional varargPack; + + TypeLevel level; + + std::unordered_map exportedTypeBindings; + std::unordered_map privateTypeBindings; + std::unordered_map typeAliasLocations; + + std::unordered_map> importedTypeBindings; + + std::optional lookup(const Symbol& name); + + std::optional lookupType(const Name& name); + std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + + std::unordered_map privateTypePackBindings; + std::optional lookupPack(const Name& name); + + // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) + std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); + + RefinementMap refinements; + + // For mutually recursive type aliases, it's important that + // they use the same types for the same names. + // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` + // we need that the generic type `T` in both cases is the same, so we use a cache. + std::unordered_map typeAliasTypeParameters; + std::unordered_map typeAliasTypePackParameters; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 6ac868f..80a14e8 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -52,8 +52,6 @@ // `T`, and the type of `f` are in the same SCC, which is why `f` gets // replaced. -LUAU_FASTFLAG(DebugLuauTrackOwningArena) - namespace Luau { @@ -188,20 +186,12 @@ struct Substitution : FindDirty template TypeId addType(const T& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(tv); } template TypePackId addTypePack(const T& tp) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(TypePackVar{tp}); } }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index ec2a1a2..d701eb2 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -86,7 +86,10 @@ struct ApplyTypeFunction : Substitution { TypeLevel level; bool encounteredForwardedType; - std::unordered_map arguments; + std::unordered_map typeArguments; + std::unordered_map typePackArguments; + bool ignoreChildren(TypeId ty) override; + bool ignoreChildren(TypePackId tp) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; TypeId clean(TypeId ty) override; @@ -328,7 +331,8 @@ private: TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); - TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location); + TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. std::pair, std::vector> createGenericTypes( @@ -398,54 +402,6 @@ private: int recursionCount = 0; }; -struct Binding -{ - TypeId typeId; - Location location; - bool deprecated = false; - std::string deprecatedSuggestion; - std::optional documentationSymbol; -}; - -struct Scope -{ - explicit Scope(TypePackId returnType); // root scope - explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. - - const ScopePtr parent; // null for the root - std::unordered_map bindings; - TypePackId returnType; - bool breakOk = false; - std::optional varargPack; - - TypeLevel level; - - std::unordered_map exportedTypeBindings; - std::unordered_map privateTypeBindings; - std::unordered_map typeAliasLocations; - - std::unordered_map> importedTypeBindings; - - std::optional lookup(const Symbol& name); - - std::optional lookupType(const Name& name); - std::optional lookupImportedType(const Name& moduleAlias, const Name& name); - - std::unordered_map privateTypePackBindings; - std::optional lookupPack(const Name& name); - - // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) - std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); - - RefinementMap refinements; - - // For mutually recursive type aliases, it's important that - // they use the same types for the same names. - // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` - // we need that the generic type `T` in both cases is the same, so we use a cache. - std::unordered_map typeAliasParameters; -}; - // Unit test hook void setPrintLine(void (*pl)(const std::string& s)); void resetPrintLine(); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 0d0adce..d987d46 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -117,7 +117,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); -size_t size(const TypePackId tp); +size_t size(TypePackId tp); +bool finite(TypePackId tp); size_t size(const TypePack& tp); std::optional first(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 2e028df..e04aa2c 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -228,6 +228,7 @@ struct TableTypeVar std::map methodDefinitionLocations; std::vector instantiatedTypeParams; + std::vector instantiatedTypePackParams; ModuleName definitionModuleName; std::optional boundTo; @@ -284,8 +285,9 @@ struct ClassTypeVar struct TypeFun { - /// These should all be generic + // These should all be generic std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -293,6 +295,20 @@ struct TypeFun * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. */ TypeId type; + + TypeFun() = default; + TypeFun(std::vector typeParams, TypeId type) + : typeParams(std::move(typeParams)) + , type(type) + { + } + + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + : typeParams(std::move(typeParams)) + , typePackParams(std::move(typePackParams)) + , type(type) + { + } }; // Anything! All static checking is off. @@ -524,8 +540,4 @@ UnionTypeVarIterator end(const UnionTypeVar* utv); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); -// TEMP: Clip this prototype with FFlag::LuauStringMetatable -std::optional> magicFunctionFormat( - struct TypeChecker& typechecker, const std::shared_ptr& scope, const AstExprCall& expr, ExprResult exprResult); - } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 0ddc3cc..522914b 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -36,12 +36,17 @@ struct Unifier Variance variance = Covariant; CountMismatch::Context ctx = CountMismatch::Arg; - std::shared_ptr counters; + UnifierCounters* counters; + UnifierCounters countersData; + + std::shared_ptr counters_DEPRECATED; + InternalErrorReporter* iceHandler; Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters = nullptr); + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED = nullptr, + UnifierCounters* counters = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId superTy, TypeId subTy); @@ -58,11 +63,13 @@ private: void tryUnifyPrimitives(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); + void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void tryUnifyFreeTable(TypeId free, TypeId other); void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); public: void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); @@ -80,9 +87,9 @@ private: public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); - void occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack); + void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack); void occursCheck(TypePackId needle, TypePackId haystack); - void occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack); + void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); @@ -93,6 +100,9 @@ private: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; }; } // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index d3de175..0aed34c 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/ToString.h" @@ -143,8 +144,8 @@ std::optional findTypeAtPosition(const Module& module, const SourceModul { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - return it->second; + if (auto it = module.astTypes.find(expr)) + return *it; } return std::nullopt; @@ -154,8 +155,8 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end()) - return it->second; + if (auto it = module.astExpectedTypes.find(expr)) + return *it; } return std::nullopt; @@ -322,9 +323,9 @@ std::optional getDocumentationSymbolAtPosition(const Source TypeId matchingOverload = nullptr; if (parentExpr && parentExpr->is()) { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end()) + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) { - matchingOverload = it->second; + matchingOverload = *it; } } @@ -345,9 +346,9 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (AstExprIndexName* indexName = targetExpr->as()) { - if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end()) + if (auto it = module.astTypes.find(indexName->expr)) { - TypeId parentTy = follow(it->second); + TypeId parentTy = follow(*it); if (const TableTypeVar* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1cfa90d..b31d85a 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -210,10 +210,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ return TypeCorrectKind::None; auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return TypeCorrectKind::None; - TypeId expectedType = follow(it->second); + TypeId expectedType = follow(*it); if (canUnify(expectedType, ty)) return TypeCorrectKind::Correct; @@ -682,10 +682,10 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n return std::nullopt; auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return std::nullopt; - TypeId expectedType = follow(it->second); + TypeId expectedType = follow(*it); if (const FunctionTypeVar* ftv = get(expectedType)) return true; @@ -784,9 +784,9 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (AstExprCall* exprCall = expr->as()) { - if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end()) + if (auto it = module.astTypes.find(exprCall->func)) { - if (const FunctionTypeVar* ftv = get(follow(it->second))) + if (const FunctionTypeVar* ftv = get(follow(*it))) { if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) inferredType = *ty; @@ -798,8 +798,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (tailPos != 0) break; - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - inferredType = it->second; + if (auto it = module.astTypes.find(expr)) + inferredType = *it; } if (inferredType) @@ -815,10 +815,10 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* { auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return nullptr; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); if (const FunctionTypeVar* ftv = get(ty)) return ftv; @@ -1129,9 +1129,8 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (node->is()) { - auto it = module.astTypes.find(node->asExpr()); - if (it != module.astTypes.end()) - autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result); + if (auto it = module.astTypes.find(node->asExpr())) + autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); } else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) return; @@ -1203,13 +1202,13 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } - auto parentIter = module->astTypes.find(parentExpr); - if (parentIter == module->astTypes.end()) + auto parentIt = module->astTypes.find(parentExpr); + if (!parentIt) { return std::nullopt; } - Luau::TypeId parentType = Luau::follow(parentIter->second); + Luau::TypeId parentType = Luau::follow(*parentIt); if (auto parentClass = Luau::get(parentType)) { @@ -1250,8 +1249,8 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - auto iter = module->astTypes.find(candidate->func); - if (iter == module->astTypes.end()) + auto it = module->astTypes.find(candidate->func); + if (!it) { return std::nullopt; } @@ -1267,7 +1266,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; }; - auto followedId = Luau::follow(iter->second); + auto followedId = Luau::follow(*it); if (auto functionType = Luau::get(followedId)) { return performCallback(functionType); @@ -1316,10 +1315,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto indexName = node->as()) { auto it = module->astTypes.find(indexName->expr); - if (it == module->astTypes.end()) + if (!it) return {}; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (isString(ty)) @@ -1447,9 +1446,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If item doesn't have a key, maybe the value is actually the key if (key ? key == node : node->is() && value == node) { - if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end()) + if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry); + auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1485,9 +1484,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { - if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end()) + if (auto it = module->astTypes.find(idxExpr->expr)) { - return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry}; + return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry}; } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 68ad5ac..3b0c216 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -11,7 +11,7 @@ LUAU_FASTFLAG(LuauParseGenericFunctions) LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) +LUAU_FASTFLAG(LuauNewRequireTrace) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -218,7 +218,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypePackId anyTypePack = typeChecker.anyTypePack; TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}}); TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ @@ -255,85 +254,18 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); - if (FFlag::LuauStringMetatable) + std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + LUAU_ASSERT(stringMetatableTy); + const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + LUAU_ASSERT(stringMetatableTable); + + auto it = stringMetatableTable->props.find("__index"); + LUAU_ASSERT(it != stringMetatableTable->props.end()); + + addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); + + if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions) { - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); - LUAU_ASSERT(stringMetatableTy); - const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); - - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); - - TypeId stringLib = it->second.type; - addGlobalBinding(typeChecker, "string", stringLib, "@luau"); - } - - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - if (!FFlag::LuauStringMetatable) - { - TypeId stringLibTy = getGlobalBinding(typeChecker, "string"); - TableTypeVar* stringLib = getMutable(stringLibTy); - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub"); - } - } - else - { - if (!FFlag::LuauStringMetatable) - { - TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType}); - - TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})}); - - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}); - - TableTypeVar::Props stringLib = { - // FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied - {"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}}, - // FIXME char takes a variadic pack of numbers - {"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}}, - {"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(arena, stringType, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}}, - {"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(arena, stringType, {stringType, optionalString}, - {arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}}, - {"pack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(arena, stringType, {}, {numberType})}}, - {"unpack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - addGlobalBinding(typeChecker, "string", - arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - } - TableTypeVar::Props debugLib{ {"info", {makeIntersection(arena, { @@ -601,9 +533,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) auto tableLib = getMutable(getGlobalBinding(typeChecker, "table")); attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack); - auto stringLib = getMutable(getGlobalBinding(typeChecker, "string")); - attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat); - attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } @@ -791,11 +720,11 @@ static std::optional> magicFunctionRequire( return std::nullopt; } - AstExpr* require = expr.args.data[0]; - - if (!checkRequirePath(typechecker, require)) + if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; + const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0]; + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 61a63f0..1e91561 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -206,27 +206,6 @@ std::string getBuiltinDefinitionSource() graphemes: (string, number?, number?) -> (() -> (number, number)), } - declare string: { - byte: (string, number?, number?) -> ...number, - char: (number, ...number) -> string, - find: (string, string, number?, boolean?) -> (number?, number?), - -- `string.format` has a magic function attached that will provide more type information for literal format strings. - format: (string, A...) -> string, - gmatch: (string, string) -> () -> (...string), - -- gsub is defined in C++ because we don't have syntax for describing a generic table. - len: (string) -> number, - lower: (string) -> string, - match: (string, string, number?) -> string?, - rep: (string, number) -> string, - reverse: (string) -> string, - sub: (string, number, number?) -> string, - upper: (string) -> string, - split: (string, string, string?) -> {string}, - pack: (string, A...) -> string, - packsize: (string) -> number, - unpack: (string, string, number?) -> R..., - } - -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V )"; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index a622f8a..04d9144 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,9 +7,9 @@ #include -LUAU_FASTFLAG(LuauFasterStringifier) +LUAU_FASTFLAG(LuauTypeAliasPacks) -static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) +static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) { std::string s = "expects " + std::to_string(expectedCount) + " "; @@ -41,6 +41,52 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo return s; } +static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) +{ + std::string s; + + if (FFlag::LuauTypeAliasPacks) + { + s = "expects "; + + if (isVariadic) + s += "at least "; + + s += std::to_string(expectedCount) + " "; + } + else + { + s = "expects " + std::to_string(expectedCount) + " "; + } + + if (argPrefix) + s += std::string(argPrefix) + " "; + + s += "argument"; + if (expectedCount != 1) + s += "s"; + + s += ", but "; + + if (actualCount == 0) + { + s += "none"; + } + else + { + if (actualCount < expectedCount) + s += "only "; + + s += std::to_string(actualCount); + } + + s += (actualCount == 1) ? " is" : " are"; + + s += " specified"; + + return s; +} + namespace Luau { @@ -127,7 +173,10 @@ struct ErrorConverter return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + if (FFlag::LuauTypeAliasPacks) + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + else + return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual); } LUAU_ASSERT(!"Unknown context"); @@ -159,13 +208,16 @@ struct ErrorConverter std::string operator()(const Luau::UnknownRequire& e) const { - return "Unknown require: " + e.modulePath; + if (e.modulePath.empty()) + return "Unknown require: unsupported path"; + else + return "Unknown require: " + e.modulePath; } std::string operator()(const Luau::IncorrectGenericParameterCount& e) const { std::string name = e.name; - if (!e.typeFun.typeParams.empty()) + if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty())) { name += "<"; bool first = true; @@ -178,10 +230,37 @@ struct ErrorConverter name += toString(t); } + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId t : e.typeFun.typePackParams) + { + if (first) + first = false; + else + name += ", "; + + name += toString(t); + } + } + name += ">"; } - return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + if (FFlag::LuauTypeAliasPacks) + { + if (e.typeFun.typeParams.size() != e.actualParameters) + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); + } + else + { + return "Generic type '" + name + "' " + + wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + } } std::string operator()(const Luau::SyntaxError& e) const @@ -470,9 +549,26 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) return false; + if (FFlag::LuauTypeAliasPacks) + { + if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) + return false; + } + for (size_t i = 0; i < typeFun.typeParams.size(); ++i) + { if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) return false; + } + + if (FFlag::LuauTypeAliasPacks) + { + for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) + { + if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + return false; + } + } return true; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 4d385ec..b252984 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,9 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" +#include "Luau/Scope.h" #include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" #include "Luau/Common.h" @@ -19,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) +LUAU_FASTFLAG(LuauNewRequireTrace) namespace Luau { @@ -69,6 +73,8 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) { + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + Luau::Allocator allocator; Luau::AstNameTable names(allocator); @@ -350,6 +356,9 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) CheckResult Frontend::check(const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto it = sourceNodes.find(name); @@ -479,6 +488,9 @@ CheckResult Frontend::check(const ModuleName& name) bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root) { + LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); + // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search enum Mark { @@ -597,6 +609,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name); @@ -608,6 +623,8 @@ LintResult Frontend::lint(const ModuleName& name, std::optional Frontend::lintFragment(std::string_view source, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend"); + const Config& config = configResolver->getConfig(""); SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); @@ -627,6 +644,9 @@ std::pair Frontend::lintFragment(std::string_view sour CheckResult Frontend::check(const SourceModule& module) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); Mode mode = module.mode.value_or(config.mode); @@ -648,6 +668,9 @@ CheckResult Frontend::check(const SourceModule& module) LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); LintOptions options = enabledLintWarnings.value_or(config.enabledLint); @@ -746,6 +769,9 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + auto it = sourceNodes.find(name); if (it != sourceNodes.end() && !it->second.dirty) { @@ -815,6 +841,9 @@ std::pair Frontend::getSourceNode(CheckResult& check */ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions) { + LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + SourceModule sourceModule; double timestamp = getTimestamp(); @@ -864,20 +893,11 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const auto& exprs = it->second.exprs; - const ModuleName* relativeName = exprs.find(&pathExpr); - if (!relativeName || relativeName->empty()) + const ModuleInfo* info = exprs.find(&pathExpr); + if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty())) return std::nullopt; - if (FFlag::LuauTraceRequireLookupChild) - { - const bool* optional = it->second.optional.find(&pathExpr); - - return {{*relativeName, optional ? *optional : false}}; - } - else - { - return {{*relativeName, false}}; - } + return *info; } const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const @@ -891,12 +911,15 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - return frontend->fileResolver->moduleExists(moduleName); + if (FFlag::LuauNewRequireTrace) + return frontend->sourceNodes.count(moduleName) != 0; + else + return frontend->fileResolver->moduleExists(moduleName); } std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const { - return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName); + return frontend->fileResolver->getHumanReadableModuleName(moduleName); } ScopePtr Frontend::addEnvironment(const std::string& environmentName) diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 84e9b77..3b26712 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -2,6 +2,8 @@ #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" +LUAU_FASTFLAG(LuauTypeAliasPacks) + namespace Luau { @@ -92,7 +94,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "IncorrectGenericParameterCount { name = " << error.name; - if (!error.typeFun.typeParams.empty()) + if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty())) { stream << "<"; bool first = true; @@ -105,6 +107,20 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo stream << toString(t); } + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId t : error.typeFun.typePackParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(t); + } + } + stream << ">"; } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index a101829..064accb 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -3,6 +3,9 @@ #include "Luau/Ast.h" #include "Luau/StringUtils.h" +#include "Luau/Common.h" + +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -612,6 +615,12 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatTypeAlias", [&]() { PROP(name); PROP(generics); + + if (FFlag::LuauTypeAliasPacks) + { + PROP(genericPacks); + } + PROP(type); PROP(exported); }); @@ -664,13 +673,21 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(struct AstTypeOrPack node) + { + if (node.type) + write(node.type); + else + write(node.typePack); + } + void write(class AstTypeReference* node) { writeNode(node, "AstTypeReference", [&]() { if (node->hasPrefix) PROP(prefix); PROP(name); - PROP(generics); + PROP(parameters); }); } @@ -734,6 +751,13 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(class AstTypePackExplicit* node) + { + writeNode(node, "AstTypePackExplicit", [&]() { + PROP(typeList); + }); + } + void write(class AstTypePackVariadic* node) { writeNode(node, "AstTypePackVariadic", [&]() { @@ -1018,6 +1042,12 @@ struct AstJsonEncoder : public AstVisitor return false; } + bool visit(class AstTypePackExplicit* node) override + { + write(node); + return false; + } + bool visit(class AstTypePackVariadic* node) override { write(node); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index f97f6a4..bff947a 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/StringUtils.h" #include "Luau/Common.h" @@ -12,6 +13,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) +LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false) namespace Luau { @@ -85,10 +87,10 @@ struct LintContext return std::nullopt; auto it = module->astTypes.find(expr); - if (it == module->astTypes.end()) + if (!it) return std::nullopt; - return it->second; + return *it; } }; @@ -2144,6 +2146,19 @@ private: "wrap it in parentheses to silence"); } + if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4) + { + // table.move(t, 0, _, _) + if (isConstant(args[1], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + + // table.move(t, _, _, 0) + else if (isConstant(args[3], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + } + return true; } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index f1d975f..df6be76 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -13,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -188,7 +190,7 @@ struct TypePackCloner template void defaultClone(const T& t) { - TypePackId cloned = dest.typePacks.allocate(t); + TypePackId cloned = dest.addTypePack(TypePackVar{t}); seenTypePacks[typePackId] = cloned; } @@ -197,7 +199,7 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}}); + seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); } void operator()(const Unifiable::Generic& t) @@ -219,13 +221,13 @@ struct TypePackCloner void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}}); seenTypePacks[typePackId] = cloned; } void operator()(const TypePack& t) { - TypePackId cloned = dest.typePacks.allocate(TypePack{}); + TypePackId cloned = dest.addTypePack(TypePack{}); TypePack* destTp = getMutable(cloned); LUAU_ASSERT(destTp != nullptr); seenTypePacks[typePackId] = cloned; @@ -241,7 +243,7 @@ struct TypePackCloner template void TypeCloner::defaultClone(const T& t) { - TypeId cloned = dest.typeVars.allocate(t); + TypeId cloned = dest.addType(t); seenTypes[typeId] = cloned; } @@ -250,7 +252,7 @@ void TypeCloner::operator()(const Unifiable::Free& t) if (encounteredFreeType) *encounteredFreeType = true; - seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{}); + seenTypes[typeId] = dest.addType(ErrorTypeVar{}); } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -275,7 +277,7 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const FunctionTypeVar& t) { - TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); FunctionTypeVar* ftv = getMutable(result); LUAU_ASSERT(ftv != nullptr); @@ -297,7 +299,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const TableTypeVar& t) { - TypeId result = dest.typeVars.allocate(TableTypeVar{}); + TypeId result = dest.addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); @@ -323,7 +325,13 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); for (TypeId& arg : ttv->instantiatedTypeParams) - arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType)); + arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + } if (ttv->state == TableState::Free) { @@ -343,7 +351,7 @@ void TypeCloner::operator()(const TableTypeVar& t) void TypeCloner::operator()(const MetatableTypeVar& t) { - TypeId result = dest.typeVars.allocate(MetatableTypeVar{}); + TypeId result = dest.addType(MetatableTypeVar{}); MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; @@ -353,7 +361,7 @@ void TypeCloner::operator()(const MetatableTypeVar& t) void TypeCloner::operator()(const ClassTypeVar& t) { - TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); ClassTypeVar* ctv = getMutable(result); seenTypes[typeId] = result; @@ -378,7 +386,7 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - TypeId result = dest.typeVars.allocate(UnionTypeVar{}); + TypeId result = dest.addType(UnionTypeVar{}); seenTypes[typeId] = result; UnionTypeVar* option = getMutable(result); @@ -390,7 +398,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) void TypeCloner::operator()(const IntersectionTypeVar& t) { - TypeId result = dest.typeVars.allocate(IntersectionTypeVar{}); + TypeId result = dest.addType(IntersectionTypeVar{}); seenTypes[typeId] = result; IntersectionTypeVar* option = getMutable(result); @@ -451,8 +459,14 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) { TypeFun result; - for (TypeId param : typeFun.typeParams) - result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType)); + for (TypeId ty : typeFun.typeParams) + result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId tp : typeFun.typePackParams) + result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType)); + } result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index c9e15cb..4634eff 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -5,6 +5,7 @@ #include "Luau/Module.h" LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) +LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false) namespace Luau { @@ -12,17 +13,18 @@ namespace Luau namespace { -struct RequireTracer : AstVisitor +struct RequireTracerOld : AstVisitor { - explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName) + explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName) : fileResolver(fileResolver) - , currentModuleName(std::move(currentModuleName)) + , currentModuleName(currentModuleName) { + LUAU_ASSERT(!FFlag::LuauNewRequireTrace); } FileResolver* const fileResolver; ModuleName currentModuleName; - DenseHashMap locals{0}; + DenseHashMap locals{nullptr}; RequireTraceResult result; std::optional fromAstFragment(AstExpr* expr) @@ -50,9 +52,9 @@ struct RequireTracer : AstVisitor AstExpr* expr = stat->values.data[i]; expr->visit(this); - const ModuleName* name = result.exprs.find(expr); - if (name) - locals[local] = *name; + const ModuleInfo* info = result.exprs.find(expr); + if (info) + locals[local] = info->name; } } @@ -63,7 +65,7 @@ struct RequireTracer : AstVisitor { std::optional name = fromAstFragment(global); if (name) - result.exprs[global] = *name; + result.exprs[global] = {*name}; return false; } @@ -72,7 +74,7 @@ struct RequireTracer : AstVisitor { const ModuleName* name = locals.find(local->local); if (name) - result.exprs[local] = *name; + result.exprs[local] = {*name}; return false; } @@ -81,16 +83,16 @@ struct RequireTracer : AstVisitor { indexName->expr->visit(this); - const ModuleName* name = result.exprs.find(indexName->expr); - if (name) + const ModuleInfo* info = result.exprs.find(indexName->expr); + if (info) { if (indexName->index == "parent" || indexName->index == "Parent") { - if (auto parent = fileResolver->getParentModuleName(*name)) - result.exprs[indexName] = *parent; + if (auto parent = fileResolver->getParentModuleName(info->name)) + result.exprs[indexName] = {*parent}; } else - result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value); + result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)}; } return false; @@ -100,11 +102,11 @@ struct RequireTracer : AstVisitor { indexExpr->expr->visit(this); - const ModuleName* name = result.exprs.find(indexExpr->expr); + const ModuleInfo* info = result.exprs.find(indexExpr->expr); const AstExprConstantString* str = indexExpr->index->as(); - if (name && str) + if (info && str) { - result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size)); + result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))}; } indexExpr->index->visit(this); @@ -129,8 +131,8 @@ struct RequireTracer : AstVisitor AstExprGlobal* globalName = call->func->as(); if (globalName && globalName->name == "require" && call->args.size >= 1) { - if (const ModuleName* moduleName = result.exprs.find(call->args.data[0])) - result.requires.push_back({*moduleName, call->location}); + if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0])) + result.requires.push_back({moduleInfo->name, call->location}); return false; } @@ -143,8 +145,8 @@ struct RequireTracer : AstVisitor if (FFlag::LuauTraceRequireLookupChild && !rootName) { - if (const ModuleName* moduleName = result.exprs.find(indexName->expr)) - rootName = *moduleName; + if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr)) + rootName = moduleInfo->name; } if (!rootName) @@ -167,24 +169,183 @@ struct RequireTracer : AstVisitor if (v.end() != std::find(v.begin(), v.end(), '/')) return false; - result.exprs[call] = fileResolver->concat(*rootName, v); + result.exprs[call] = {fileResolver->concat(*rootName, v)}; // 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime // If we fail to find such module, we will not report an UnknownRequire error if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") - result.optional[call] = true; + result.exprs[call].optional = true; return false; } }; +struct RequireTracer : AstVisitor +{ + RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName) + : result(result) + , fileResolver(fileResolver) + , currentModuleName(currentModuleName) + , locals(nullptr) + { + LUAU_ASSERT(FFlag::LuauNewRequireTrace); + } + + bool visit(AstExprTypeAssertion* expr) override + { + // suppress `require() :: any` + return false; + } + + bool visit(AstExprCall* expr) override + { + AstExprGlobal* global = expr->func->as(); + + if (global && global->name == "require" && expr->args.size >= 1) + requires.push_back(expr); + + return true; + } + + bool visit(AstStatLocal* stat) override + { + for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i) + { + AstLocal* local = stat->vars.data[i]; + AstExpr* expr = stat->values.data[i]; + + // track initializing expression to be able to trace modules through locals + locals[local] = expr; + } + + return true; + } + + bool visit(AstStatAssign* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + { + // locals that are assigned don't have a known expression + if (AstExprLocal* expr = stat->vars.data[i]->as()) + locals[expr->local] = nullptr; + } + + return true; + } + + bool visit(AstType* node) override + { + // allow resolving require inside `typeof` annotations + return true; + } + + AstExpr* getDependent(AstExpr* node) + { + if (AstExprLocal* expr = node->as()) + return locals[expr->local]; + else if (AstExprIndexName* expr = node->as()) + return expr->expr; + else if (AstExprIndexExpr* expr = node->as()) + return expr->expr; + else if (AstExprCall* expr = node->as(); expr && expr->self) + return expr->func->as()->expr; + else + return nullptr; + } + + void process() + { + ModuleInfo moduleContext{currentModuleName}; + + // seed worklist with require arguments + work.reserve(requires.size()); + + for (AstExprCall* require: requires) + work.push_back(require->args.data[0]); + + // push all dependent expressions to the work stack; note that the vector is modified during traversal + for (size_t i = 0; i < work.size(); ++i) + if (AstExpr* dep = getDependent(work[i])) + work.push_back(dep); + + // resolve all expressions to a module info + for (size_t i = work.size(); i > 0; --i) + { + AstExpr* expr = work[i - 1]; + + // when multiple expressions depend on the same one we push it to work queue multiple times + if (result.exprs.contains(expr)) + continue; + + std::optional info; + + if (AstExpr* dep = getDependent(expr)) + { + const ModuleInfo* context = result.exprs.find(dep); + + // locals just inherit their dependent context, no resolution required + if (expr->is()) + info = context ? std::optional(*context) : std::nullopt; + else + info = fileResolver->resolveModule(context, expr); + } + else + { + info = fileResolver->resolveModule(&moduleContext, expr); + } + + if (info) + result.exprs[expr] = std::move(*info); + } + + // resolve all requires according to their argument + result.requires.reserve(requires.size()); + + for (AstExprCall* require : requires) + { + AstExpr* arg = require->args.data[0]; + + if (const ModuleInfo* info = result.exprs.find(arg)) + { + result.requires.push_back({info->name, require->location}); + + ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! + result.exprs[require] = std::move(infoCopy); + } + else + { + result.exprs[require] = {}; // mark require as unresolved + } + } + } + + RequireTraceResult& result; + FileResolver* fileResolver; + ModuleName currentModuleName; + + DenseHashMap locals; + std::vector work; + std::vector requires; +}; + } // anonymous namespace -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName) +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - RequireTracer tracer{fileResolver, std::move(currentModuleName)}; - root->visit(&tracer); - return tracer.result; + if (FFlag::LuauNewRequireTrace) + { + RequireTraceResult result; + RequireTracer tracer{result, fileResolver, currentModuleName}; + root->visit(&tracer); + tracer.process(); + return result; + } + else + { + RequireTracerOld tracer{fileResolver, currentModuleName}; + root->visit(&tracer); + return tracer.result; + } } } // namespace Luau diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp new file mode 100644 index 0000000..c30db9c --- /dev/null +++ b/Analysis/src/Scope.cpp @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" + +namespace Luau +{ + +Scope::Scope(TypePackId returnType) + : parent(nullptr) + , returnType(returnType) + , level(TypeLevel()) +{ +} + +Scope::Scope(const ScopePtr& parent, int subLevel) + : parent(parent) + , returnType(parent->returnType) + , level(parent->level.incr()) +{ + level.subLevel = subLevel; +} + +std::optional Scope::lookup(const Symbol& name) +{ + Scope* scope = this; + + while (scope) + { + auto it = scope->bindings.find(name); + if (it != scope->bindings.end()) + return it->second.typeId; + + scope = scope->parent.get(); + } + + return std::nullopt; +} + +std::optional Scope::lookupType(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->exportedTypeBindings.find(name); + if (it != scope->exportedTypeBindings.end()) + return it->second; + + it = scope->privateTypeBindings.find(name); + if (it != scope->privateTypeBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +{ + const Scope* scope = this; + while (scope) + { + auto it = scope->importedTypeBindings.find(moduleAlias); + if (it == scope->importedTypeBindings.end()) + { + scope = scope->parent.get(); + continue; + } + + auto it2 = it->second.find(name); + if (it2 == it->second.end()) + { + scope = scope->parent.get(); + continue; + } + + return it2->second; + } + + return std::nullopt; +} + +std::optional Scope::lookupPack(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->privateTypePackBindings.find(name); + if (it != scope->privateTypePackBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) +{ + Scope* scope = this; + + while (scope) + { + for (const auto& [n, binding] : scope->bindings) + { + if (n.local && n.local->name == name.c_str()) + return binding; + else if (n.global.value && n.global == name.c_str()) + return binding; + } + + scope = scope->parent.get(); + + if (!traverseScopeChain) + break; + } + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 7223998..d861eb3 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -6,9 +6,11 @@ #include #include -LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0) +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -35,8 +37,15 @@ void Tarjan::visitChildren(TypeId ty, int index) visitChild(ttv->indexer->indexType); visitChild(ttv->indexer->indexResultType); } + for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp); + } } else if (const MetatableTypeVar* mtv = get(ty)) { @@ -332,9 +341,11 @@ std::optional Substitution::substitute(TypeId ty) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + replaceChildren(newTp); TypeId newTy = replace(ty); return newTy; } @@ -350,9 +361,11 @@ std::optional Substitution::substitute(TypePackId tp) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + replaceChildren(newTp); TypePackId newTp = replace(tp); return newTp; } @@ -382,6 +395,10 @@ TypeId Substitution::clone(TypeId ty) clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + + if (FFlag::LuauTypeAliasPacks) + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + if (FFlag::LuauSecondTypecheckKnowsTheDataModel) clone.tags = ttv->tags; result = addType(std::move(clone)); @@ -487,8 +504,15 @@ void Substitution::replaceChildren(TypeId ty) ttv->indexer->indexType = replace(ttv->indexer->indexType); ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType); } + for (TypeId& itp : ttv->instantiatedTypeParams) itp = replace(itp); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId& itp : ttv->instantiatedTypePackParams) + itp = replace(itp); + } } else if (MetatableTypeVar* mtv = getMutable(ty)) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 9d2f47b..5651af7 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ToString.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -9,10 +10,10 @@ #include #include -LUAU_FASTFLAG(LuauToStringFollowsBoundTo) LUAU_FASTFLAG(LuauExtraNilRecovery) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -59,6 +60,13 @@ struct FindCyclicTypes { for (TypeId itp : ttv.instantiatedTypeParams) visitTypeVar(itp, *this, seen); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv.instantiatedTypePackParams) + visitTypeVar(itp, *this, seen); + } + return exhaustive; } @@ -258,23 +266,60 @@ struct TypeVarStringifier void stringify(TypePackId tp); void stringify(TypePackId tpid, const std::vector>& names); - void stringify(const std::vector& types) + void stringify(const std::vector& types, const std::vector& typePacks) { - if (types.size() == 0) + if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0)) return; - if (types.size()) + if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) state.emit("<"); - for (size_t i = 0; i < types.size(); ++i) + if (FFlag::LuauTypeAliasPacks) { - if (i > 0) - state.emit(", "); + bool first = true; - stringify(types[i]); + for (TypeId ty : types) + { + if (!first) + state.emit(", "); + first = false; + + stringify(ty); + } + + bool singleTp = typePacks.size() == 1; + + for (TypePackId tp : typePacks) + { + if (isEmpty(tp) && singleTp) + continue; + + if (!first) + state.emit(", "); + else + first = false; + + if (!singleTp) + state.emit("("); + + stringify(tp); + + if (!singleTp) + state.emit(")"); + } + } + else + { + for (size_t i = 0; i < types.size(); ++i) + { + if (i > 0) + state.emit(", "); + + stringify(types[i]); + } } - if (types.size()) + if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) state.emit(">"); } @@ -388,7 +433,7 @@ struct TypeVarStringifier void operator()(TypeId, const TableTypeVar& ttv) { - if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo) + if (ttv.boundTo) return stringify(*ttv.boundTo); if (!state.exhaustive) @@ -411,14 +456,14 @@ struct TypeVarStringifier } state.emit(*ttv.name); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } if (ttv.syntheticName) { state.result.invalid = true; state.emit(*ttv.syntheticName); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } } @@ -900,13 +945,26 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) + if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) return result; std::vector params; for (TypeId tp : ttv->instantiatedTypeParams) params.push_back(toString(tp)); + if (FFlag::LuauTypeAliasPacks) + { + // Doesn't preserve grouping of multiple type packs + // But this is under a parent block of code that is being removed later + for (TypePackId tp : ttv->instantiatedTypePackParams) + { + std::string content = toString(tp); + + if (!content.empty()) + params.push_back(std::move(content)); + } + } + result.name += "<" + join(params, ", ") + ">"; return result; } @@ -950,30 +1008,37 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) - return result; - - result.name += "<"; - - bool first = true; - for (TypeId ty : ttv->instantiatedTypeParams) + if (FFlag::LuauTypeAliasPacks) { - if (!first) - result.name += ", "; - else - first = false; - - tvs.stringify(ty); - } - - if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - { - result.truncated = true; - result.name += "... "; + tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); } else { - result.name += ">"; + if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) + return result; + + result.name += "<"; + + bool first = true; + for (TypeId ty : ttv->instantiatedTypeParams) + { + if (!first) + result.name += ", "; + else + first = false; + + tvs.stringify(ty); + } + + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + { + result.truncated = true; + result.name += "... "; + } + else + { + result.name += ">"; + } } return result; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 462c70f..1b83ccd 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { @@ -280,10 +281,19 @@ struct Printer void visualizeTypePackAnnotation(const AstTypePack& annotation) { - if (const AstTypePackVariadic* variadic = annotation.as()) + if (const AstTypePackVariadic* variadicTp = annotation.as()) { writer.symbol("..."); - visualizeTypeAnnotation(*variadic->variadicType); + visualizeTypeAnnotation(*variadicTp->variadicType); + } + else if (const AstTypePackGeneric* genericTp = annotation.as()) + { + writer.symbol(genericTp->genericName.value); + writer.symbol("..."); + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + visualizeTypeList(explicitTp->typeList, true); } else { @@ -807,7 +817,7 @@ struct Printer writer.keyword("type"); writer.identifier(a->name.value); - if (a->generics.size > 0) + if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0)) { writer.symbol("<"); CommaSeparatorInserter comma(writer); @@ -817,6 +827,17 @@ struct Printer comma(); writer.identifier(o.value); } + + if (FFlag::LuauTypeAliasPacks) + { + for (auto o : a->genericPacks) + { + comma(); + writer.identifier(o.value); + writer.symbol("..."); + } + } + writer.symbol(">"); } writer.maybeSpace(a->type->location.begin, 2); @@ -960,15 +981,20 @@ struct Printer if (const auto& a = typeAnnotation.as()) { writer.write(a->name.value); - if (a->generics.size > 0) + if (a->parameters.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); - for (auto o : a->generics) + for (auto o : a->parameters) { comma(); - visualizeTypeAnnotation(*o); + + if (o.type) + visualizeTypeAnnotation(*o.type); + else + visualizeTypePackAnnotation(*o.typePack); } + writer.symbol(">"); } } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 17c57c8..266c198 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -5,6 +5,7 @@ #include "Luau/Module.h" #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -12,6 +13,7 @@ #include LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauTypeAliasPacks) static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { @@ -33,7 +35,6 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data namespace Luau { - class TypeRehydrationVisitor { mutable std::map seen; @@ -57,6 +58,8 @@ public: { } + AstTypePack* rehydrate(TypePackId tp) const; + AstType* operator()(const PrimitiveTypeVar& ptv) const { switch (ptv.type) @@ -85,16 +88,24 @@ public: if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end()) { - AstArray generics; - generics.size = ttv.instantiatedTypeParams.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstType*) * generics.size)); + AstArray parameters; + parameters.size = ttv.instantiatedTypeParams.size(); + parameters.data = static_cast(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size)); for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i) { - generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty); + parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}}; } - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), generics); + if (FFlag::LuauTypeAliasPacks) + { + for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) + { + parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; + } + } + + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); } if (hasSeen(&ttv)) @@ -222,10 +233,17 @@ public: AstTypePack* argTailAnnotation = nullptr; if (argTail) { - TypePackId tail = *argTail; - if (const VariadicTypePack* vtp = get(tail)) + if (FFlag::LuauTypeAliasPacks) { - argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + argTailAnnotation = rehydrate(*argTail); + } + else + { + TypePackId tail = *argTail; + if (const VariadicTypePack* vtp = get(tail)) + { + argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } } } @@ -255,10 +273,17 @@ public: AstTypePack* retTailAnnotation = nullptr; if (retTail) { - TypePackId tail = *retTail; - if (const VariadicTypePack* vtp = get(tail)) + if (FFlag::LuauTypeAliasPacks) { - retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + retTailAnnotation = rehydrate(*retTail); + } + else + { + TypePackId tail = *retTail; + if (const VariadicTypePack* vtp = get(tail)) + { + retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } } } @@ -313,6 +338,68 @@ private: const TypeRehydrationOptions& options; }; +class TypePackRehydrationVisitor +{ +public: + TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor) + : allocator(allocator) + , typeVisitor(typeVisitor) + { + } + + AstTypePack* operator()(const BoundTypePack& btp) const + { + return Luau::visit(*this, btp.boundTo->ty); + } + + AstTypePack* operator()(const TypePack& tp) const + { + AstArray head; + head.size = tp.head.size(); + head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); + + for (size_t i = 0; i < tp.head.size(); i++) + head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty); + + AstTypePack* tail = nullptr; + + if (tp.tail) + tail = Luau::visit(*this, (*tp.tail)->ty); + + return allocator->alloc(Location(), AstTypeList{head, tail}); + } + + AstTypePack* operator()(const VariadicTypePack& vtp) const + { + return allocator->alloc(Location(), Luau::visit(typeVisitor, vtp.ty->ty)); + } + + AstTypePack* operator()(const GenericTypePack& gtp) const + { + return allocator->alloc(Location(), AstName(gtp.name.c_str())); + } + + AstTypePack* operator()(const FreeTypePack& gtp) const + { + return allocator->alloc(Location(), AstName("free")); + } + + AstTypePack* operator()(const Unifiable::Error&) const + { + return allocator->alloc(Location(), AstName("Unifiable")); + } + +private: + Allocator* allocator; + const TypeRehydrationVisitor& typeVisitor; +}; + +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const +{ + TypePackRehydrationVisitor tprv(allocator, *this); + return Luau::visit(tprv, tp->ty); +} + class TypeAttacher : public AstVisitor { public: @@ -406,9 +493,16 @@ public: if (tail) { - TypePackId tailPack = *tail; - if (const VariadicTypePack* vtp = get(tailPack)) - variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + if (FFlag::LuauTypeAliasPacks) + { + variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail); + } + else + { + TypePackId tailPack = *tail; + if (const VariadicTypePack* vtp = get(tailPack)) + variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + } } fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 5c5217e..da93c8e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5,21 +5,22 @@ #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" +#include "Luau/TimeTrace.h" #include #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 0) -LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 0) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauIndexTablesWithIndexers, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) LUAU_FASTFLAG(LuauKnowsTheDataModel3) @@ -27,14 +28,11 @@ LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauImprovedTypeGuardPredicate2, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) -LUAU_FASTFLAG(DebugLuauTrackOwningArena) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauFixTableTypeAliasClone, false) LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) @@ -45,6 +43,10 @@ LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) +LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) +LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -216,9 +218,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , nilType(singletonTypes.nilType) , numberType(singletonTypes.numberType) , stringType(singletonTypes.stringType) - , booleanType( - FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.booleanType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Boolean))) - , threadType(FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.threadType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Thread))) + , booleanType(singletonTypes.booleanType) + , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) , errorType(singletonTypes.errorType) , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) @@ -237,6 +238,9 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) { + LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + currentModule.reset(new Module()); currentModule->type = module.type; @@ -1177,44 +1181,61 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { Location location = scope->typeAliasLocations[name]; reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + + if (FFlag::LuauTypeAliasPacks) + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + else + bindingsMap[name] = TypeFun{binding->typeParams, errorType}; } else { ScopePtr aliasScope = childScope(scope, typealias.location); - std::vector generics; - for (AstName generic : typealias.generics) + if (FFlag::LuauTypeAliasPacks) { - Name n = generic.value; + auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks); - // These generics are the only thing that will ever be added to aliasScope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. - if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) - { - // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. - reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); - } - - TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypeId& cached = scope->typeAliasParameters[n]; - if (!cached) - cached = addType(GenericTypeVar{aliasScope->level, n}); - g = cached; - } - else - g = addType(GenericTypeVar{aliasScope->level, n}); - generics.push_back(g); - aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; } + else + { + std::vector generics; + for (AstName generic : typealias.generics) + { + Name n = generic.value; - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), ty}; + // These generics are the only thing that will ever be added to aliasScope, so we can be certain that + // a collision can only occur when two generic typevars have the same name. + if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) + { + // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. + reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); + } + + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction) + { + TypeId& cached = scope->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{aliasScope->level, n}); + g = cached; + } + else + g = addType(GenericTypeVar{aliasScope->level, n}); + generics.push_back(g); + aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; + } + + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), ty}; + } } } else @@ -1231,6 +1252,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; } + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId tp : binding->typePackParams) + { + auto generic = get(tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = tp; + } + } + TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); if (auto ttv = getMutable(follow(ty))) { @@ -1238,7 +1269,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (ttv->name) { // Copy can be skipped if this is an identical alias - if (!FFlag::LuauFixTableTypeAliasClone || ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams) + if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || + (FFlag::LuauTypeAliasPacks && ttv->instantiatedTypePackParams != binding->typePackParams)) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1249,6 +1281,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias clone.name = name; clone.instantiatedTypeParams = binding->typeParams; + if (FFlag::LuauTypeAliasPacks) + clone.instantiatedTypePackParams = binding->typePackParams; + ty = addType(std::move(clone)); } } @@ -1256,6 +1291,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ttv->name = name; ttv->instantiatedTypeParams = binding->typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = binding->typePackParams; } } else if (auto mtv = getMutable(follow(ty))) @@ -1280,7 +1318,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0); + LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); superTy = lookupType->type; if (FFlag::LuauAddMissingFollow) @@ -1465,7 +1503,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FFlag::LuauStoreMatchingOverloadFnType) { - currentModule->astTypes.try_emplace(&expr, result.type); + if (!currentModule->astTypes.find(&expr)) + currentModule->astTypes[&expr] = result.type; } else { @@ -2193,7 +2232,7 @@ TypeId TypeChecker::checkRelationalOperation( * have a better, more descriptive error teed up. */ Unifier state = mkUnifier(expr.location); - if (!FFlag::LuauEqConstraint || !isEquality) + if (!isEquality) state.tryUnify(lhsType, rhsType); bool needsMetamethod = !isEquality; @@ -2262,7 +2301,7 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && (!FFlag::LuauEqConstraint || !isEquality)) + if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); @@ -2276,18 +2315,6 @@ TypeId TypeChecker::checkRelationalOperation( return errorType; } - if (!FFlag::LuauEqConstraint) - { - if (isEquality) - { - ErrorVec errVec = tryUnify(rhsType, lhsType, expr.location); - if (!state.errors.empty() && !errVec.empty()) - reportError(expr.location, TypeMismatch{lhsType, rhsType}); - } - else - reportErrors(state.errors); - } - return booleanType; } @@ -2443,7 +2470,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } - else if (FFlag::LuauEqConstraint && (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)) + else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; @@ -2466,14 +2493,6 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - // Once we have EqPredicate, we should break this else branch into its' own branch. - // For now, fall through is intentional. - if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } - ExprResult lhs = checkExpr(scope, *expr.left); ExprResult rhs = checkExpr(scope, *expr.right); @@ -2755,12 +2774,6 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return std::pair(resultType, nullptr); } - else if (FFlag::LuauIndexTablesWithIndexers) - { - // We allow t[x] where x:string for tables without an indexer - unify(indexType, stringType, expr.location); - return std::pair(anyType, nullptr); - } else { TypeId resultType = freshType(scope); @@ -3076,6 +3089,13 @@ static Location getEndLocation(const AstExprFunction& function) void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstExprFunction& function) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkFunctionBody", "TypeChecker"); + + if (function.debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", function.debugname.value); + else + LUAU_TIMETRACE_ARGUMENT("line", std::to_string(function.location.begin.line).c_str()); + if (FunctionTypeVar* funTy = getMutable(ty)) { check(scope, *function.body); @@ -3885,6 +3905,20 @@ std::optional TypeChecker::matchRequire(const AstExprCall& call) TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); + + if (FFlag::LuauNewRequireTrace && moduleInfo.name.empty()) + { + if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) + { + reportError(TypeError{location, UnknownRequire{}}); + return errorType; + } + + return anyType; + } + ModulePtr module = resolver->getModule(moduleInfo.name); if (!module) { @@ -4472,7 +4506,7 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level))); + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) @@ -4482,11 +4516,7 @@ TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneri TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level, canBeGeneric))); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level, canBeGeneric))); } std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) @@ -4506,20 +4536,12 @@ TypeId TypeChecker::addType(const UnionTypeVar& utv) TypeId TypeChecker::addTV(TypeVar&& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePackVar&& tv) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePack&& tp) @@ -4578,7 +4600,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_print") { - if (lit->generics.size != 1) + if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); return addType(ErrorTypeVar{}); @@ -4588,7 +4610,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation opts.exhaustive = true; opts.maxTableLength = 0; - TypeId param = resolveType(scope, *lit->generics.data[0]); + TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); return param; } @@ -4614,18 +4636,86 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(ErrorTypeVar{}); } - if (lit->generics.size == 0 && tf->typeParams.empty()) - return tf->type; - else if (lit->generics.size != tf->typeParams.size()) + if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) { - reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->generics.size}}); + return tf->type; + } + else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) + { + reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); return addType(ErrorTypeVar{}); } + else if (FFlag::LuauTypeAliasPacks) + { + if (!lit->hasParameterList && !tf->typePackParams.empty()) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + return addType(ErrorTypeVar{}); + } + + std::vector typeParams; + std::vector extraTypes; + std::vector typePackParams; + + for (size_t i = 0; i < lit->parameters.size; ++i) + { + if (AstType* type = lit->parameters.data[i].type) + { + TypeId ty = resolveType(scope, *type); + + if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) + typeParams.push_back(ty); + else if (typePackParams.empty()) + extraTypes.push_back(ty); + else + reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); + } + else if (AstTypePack* typePack = lit->parameters.data[i].typePack) + { + TypePackId tp = resolveTypePack(scope, *typePack); + + // If we have collected an implicit type pack, materialize it + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we need more regular types, we can use single element type packs to fill those in + if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + typeParams.push_back(*first(tp)); + else + typePackParams.push_back(tp); + } + } + + // If we still haven't meterialized an implicit type pack, do it now + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) + typePackParams.push_back(addTypePack({})); + + if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + { + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + return addType(ErrorTypeVar{}); + } + + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + { + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + return tf->type; + } + + return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); + } else { std::vector typeParams; - for (AstType* paramAnnot : lit->generics) - typeParams.push_back(resolveType(scope, *paramAnnot)); + + for (const auto& param : lit->parameters) + typeParams.push_back(resolveType(scope, *param.type)); if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { @@ -4634,7 +4724,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return tf->type; } - return instantiateTypeFun(scope, *tf, typeParams, annotation.location); + return instantiateTypeFun(scope, *tf, typeParams, {}, annotation.location); } } else if (const auto& table = annotation.as()) @@ -4765,6 +4855,18 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack return *genericTy; } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + std::vector types; + + for (auto type : explicitTp->typeList.types) + types.push_back(resolveType(scope, *type)); + + if (auto tailType = explicitTp->typeList.tailType) + return addTypePack(types, resolveTypePack(scope, *tailType)); + + return addTypePack(types); + } else { ice("Unknown AstTypePack kind"); @@ -4799,12 +4901,28 @@ bool ApplyTypeFunction::isDirty(TypePackId tp) return false; } +bool ApplyTypeFunction::ignoreChildren(TypeId ty) +{ + if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(ty)) + return true; + else + return false; +} + +bool ApplyTypeFunction::ignoreChildren(TypePackId tp) +{ + if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(tp)) + return true; + else + return false; +} + TypeId ApplyTypeFunction::clean(TypeId ty) { // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - TypeId& arg = arguments[ty]; + TypeId& arg = typeArguments[ty]; if (arg) return arg; else @@ -4816,17 +4934,37 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - return addTypePack(FreeTypePack{level}); + if (FFlag::LuauTypeAliasPacks) + { + TypePackId& arg = typePackArguments[tp]; + if (arg) + return arg; + else + return addTypePack(FreeTypePack{level}); + } + else + { + return addTypePack(FreeTypePack{level}); + } } -TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location) +TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location) { - if (tf.typeParams.empty()) + if (tf.typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf.typePackParams.empty())) return tf.type; - applyTypeFunction.arguments.clear(); + applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.arguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; + + if (FFlag::LuauTypeAliasPacks) + { + applyTypeFunction.typePackArguments.clear(); + for (size_t i = 0; i < tf.typePackParams.size(); ++i) + applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; + } + applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; applyTypeFunction.encounteredForwardedType = false; @@ -4875,6 +5013,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (ttv) { ttv->instantiatedTypeParams = typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = typePackParams; } } else @@ -4890,6 +5031,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } ttv->instantiatedTypeParams = typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = typePackParams; } } @@ -4899,6 +5043,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, std::pair, std::vector> TypeChecker::createGenericTypes( const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { + LUAU_ASSERT(scope->parent); + std::vector generics; for (const AstName& generic : genericNames) { @@ -4912,7 +5058,19 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypeId g = addType(Unifiable::Generic{scope->level, n}); + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + { + TypeId& cached = scope->parent->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{scope->level, n}); + g = cached; + } + else + { + g = addType(Unifiable::Generic{scope->level, n}); + } + generics.push_back(g); scope->privateTypeBindings[n] = TypeFun{{}, g}; } @@ -4930,7 +5088,19 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypePackId g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + TypePackId g; + if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + { + TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; + if (!cached) + cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + g = cached; + } + else + { + g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + } + genericPacks.push_back(g); scope->privateTypePackBindings[n] = g; } @@ -5013,13 +5183,8 @@ void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, Refineme else if (auto isaP = get(predicate)) resolve(*isaP, errVec, refis, scope, sense); else if (auto typeguardP = get(predicate)) - { - if (FFlag::LuauImprovedTypeGuardPredicate2) - resolve(*typeguardP, errVec, refis, scope, sense); - else - DEPRECATED_resolve(*typeguardP, errVec, refis, scope, sense); - } - else if (auto eqP = get(predicate); eqP && FFlag::LuauEqConstraint) + resolve(*typeguardP, errVec, refis, scope, sense); + else if (auto eqP = get(predicate)) resolve(*eqP, errVec, refis, scope, sense); else ice("Unhandled predicate kind"); @@ -5145,7 +5310,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } } - else if (FFlag::LuauImprovedTypeGuardPredicate2) + else { auto lctv = get(option); auto rctv = get(isaP.ty); @@ -5159,19 +5324,6 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement if (canUnify(option, isaP.ty, isaP.location).empty() == sense) return isaP.ty; } - else - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - - if (lctv && rctv) - { - if (isSubclass(lctv, rctv) == sense) - return option; - else if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - } - } return std::nullopt; }; @@ -5266,7 +5418,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); auto typeFun = globalScope->lookupType(typeguardP.kind); - if (!typeFun || !typeFun->typeParams.empty()) + if (!typeFun || !typeFun->typeParams.empty() || (FFlag::LuauTypeAliasPacks && !typeFun->typePackParams.empty())) return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); TypeId type = follow(typeFun->type); @@ -5292,7 +5444,8 @@ void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, Error "userdata", // no op. Requires special handling. }; - if (auto typeFun = globalScope->lookupType(typeguardP.kind); typeFun && typeFun->typeParams.empty()) + if (auto typeFun = globalScope->lookupType(typeguardP.kind); + typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) { if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) addRefinement(refis, typeguardP.lvalue, typeFun->type); @@ -5319,38 +5472,41 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return; } - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; - - std::vector lhs = options(*ty); - std::vector rhs = options(eqP.type); - - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) + if (FFlag::LuauEqConstraint) { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + std::optional ty = resolveLValue(refis, scope, eqP.lvalue); + if (!ty) + return; - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) + std::vector lhs = options(*ty); + std::vector rhs = options(eqP.type); + + if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); + addRefinement(refis, eqP.lvalue, eqP.type); + return; } + else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + + std::unordered_set set; + for (TypeId left : lhs) + { + for (TypeId right : rhs) + { + // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. + if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) + set.insert(left); + } + } + + if (set.empty()) + return; + + std::vector viable(set.begin(), set.end()); + TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); + addRefinement(refis, eqP.lvalue, result); } - - if (set.empty()) - return; - - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); } bool TypeChecker::isNonstrictMode() const @@ -5379,119 +5535,4 @@ std::vector> TypeChecker::getScopes() const return currentModule->scopes; } -Scope::Scope(TypePackId returnType) - : parent(nullptr) - , returnType(returnType) - , level(TypeLevel()) -{ -} - -Scope::Scope(const ScopePtr& parent, int subLevel) - : parent(parent) - , returnType(parent->returnType) - , level(parent->level.incr()) -{ - level.subLevel = subLevel; -} - -std::optional Scope::lookup(const Symbol& name) -{ - Scope* scope = this; - - while (scope) - { - auto it = scope->bindings.find(name); - if (it != scope->bindings.end()) - return it->second.typeId; - - scope = scope->parent.get(); - } - - return std::nullopt; -} - -std::optional Scope::lookupType(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->exportedTypeBindings.find(name); - if (it != scope->exportedTypeBindings.end()) - return it->second; - - it = scope->privateTypeBindings.find(name); - if (it != scope->privateTypeBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) -{ - const Scope* scope = this; - while (scope) - { - auto it = scope->importedTypeBindings.find(moduleAlias); - if (it == scope->importedTypeBindings.end()) - { - scope = scope->parent.get(); - continue; - } - - auto it2 = it->second.find(name); - if (it2 == it->second.end()) - { - scope = scope->parent.get(); - continue; - } - - return it2->second; - } - - return std::nullopt; -} - -std::optional Scope::lookupPack(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->privateTypePackBindings.find(name); - if (it != scope->privateTypePackBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) -{ - Scope* scope = this; - - while (scope) - { - for (const auto& [n, binding] : scope->bindings) - { - if (n.local && n.local->name == name.c_str()) - return binding; - else if (n.global.value && n.global == name.c_str()) - return binding; - } - - scope = scope->parent.get(); - - if (!traverseScopeChain) - break; - } - - return std::nullopt; -} - } // namespace Luau diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5970f30..68a16ef 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -209,6 +209,19 @@ size_t size(TypePackId tp) return 0; } +bool finite(TypePackId tp) +{ + tp = follow(tp); + + if (auto pack = get(tp)) + return pack->tail ? finite(*pack->tail) : true; + + if (auto pack = get(tp)) + return false; + + return true; +} + size_t size(const TypePack& tp) { size_t result = tp.head.size(); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index b9f5097..0d9d91e 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -1,11 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeUtils.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAG(LuauStringMetatable) - namespace Luau { @@ -13,21 +12,6 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto it = globalScope->bindings.find(AstName{"string"}); - if (it != globalScope->bindings.end()) - return it->second.typeId; - else - return std::nullopt; - } - } - std::optional metatable = getMetatable(type); if (!metatable) return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 111f4f5..e963fc7 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,11 +19,9 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) -LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false) LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -193,27 +191,11 @@ bool isOptional(TypeId ty) bool isTableIntersection(TypeId ty) { - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - if (!get(follow(ty))) - return false; - - std::vector parts = flattenIntersection(ty); - return std::all_of(parts.begin(), parts.end(), getTableType); - } - else - { - if (const IntersectionTypeVar* itv = get(ty)) - { - for (TypeId part : itv->parts) - { - if (getTableType(follow(part))) - return true; - } - } - + if (!get(follow(ty))) return false; - } + + std::vector parts = flattenIntersection(ty); + return std::all_of(parts.begin(), parts.end(), getTableType); } bool isOverloadedFunction(TypeId ty) @@ -236,7 +218,7 @@ std::optional getMetatable(TypeId type) else if (const ClassTypeVar* classType = get(type)) return classType->metatable; else if (const PrimitiveTypeVar* primitiveType = get(type); - FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable) + primitiveType && primitiveType->metatable) { LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); return primitiveType->metatable; @@ -871,6 +853,12 @@ void StateDot::visitChildren(TypeId ty, int index) } for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp, index, "typeParam"); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } } else if (const MetatableTypeVar* mtv = get(ty)) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 9679599..56f2515 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -3,23 +3,25 @@ #include "Luau/Common.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/TimeTrace.h" #include LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 0); -LUAU_FASTFLAGVARIABLE(LuauLogTableTypeVarBoundTo, false) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) namespace Luau { @@ -43,21 +45,23 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , globalScope(std::move(globalScope)) , location(location) , variance(variance) - , counters(std::make_shared()) + , counters(&countersData) + , counters_DEPRECATED(std::make_shared()) , iceHandler(iceHandler) { LUAU_ASSERT(iceHandler); } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters) + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) : types(types) , mode(mode) , globalScope(std::move(globalScope)) , log(seen) , location(location) , variance(variance) - , counters(counters ? counters : std::make_shared()) + , counters(counters ? counters : &countersData) + , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , iceHandler(iceHandler) { LUAU_ASSERT(iceHandler); @@ -65,16 +69,26 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::v void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - counters->iterationCount = 0; + if (FFlag::LuauTypecheckOpts) + counters->iterationCount = 0; + else + counters_DEPRECATED->iterationCount = 0; + return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (FFlag::LuauTypecheckOpts) + ++counters->iterationCount; + else + ++counters_DEPRECATED->iterationCount; + + if (FInt::LuauTypeInferIterationLimit > 0 && + FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -440,7 +454,11 @@ ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunction void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - counters->iterationCount = 0; + if (FFlag::LuauTypecheckOpts) + counters->iterationCount = 0; + else + counters_DEPRECATED->iterationCount = 0; + return tryUnify_(superTp, subTp, isFunctionCall); } @@ -450,10 +468,16 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (FFlag::LuauTypecheckOpts) + ++counters->iterationCount; + else + ++counters_DEPRECATED->iterationCount; + + if (FInt::LuauTypeInferIterationLimit > 0 && + FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -762,9 +786,210 @@ struct Resetter void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - std::unique_ptr resetter; + if (!FFlag::LuauTableSubtypingVariance) + return DEPRECATED_tryUnifyTables(left, right, isIntersection); - resetter.reset(new Resetter{&variance}); + TableTypeVar* lt = getMutable(left); + TableTypeVar* rt = getMutable(right); + if (!lt || !rt) + ice("passed non-table types to unifyTables"); + + std::vector missingProperties; + std::vector extraProperties; + + // Reminder: left is the supertype, right is the subtype. + // Width subtyping: any property in the supertype must be in the subtype, + // and the types must agree. + for (const auto& [name, prop] : lt->props) + { + const auto& r = rt->props.find(name); + if (r != rt->props.end()) + { + // TODO: read-only properties don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, r->second.type); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (rt->indexer && isString(rt->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, rt->indexer->indexResultType); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (isOptional(prop.type) || get(follow(prop.type))) + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + {} + else if (rt->state == TableState::Free) + { + log(rt); + rt->props[name] = prop; + } + else + missingProperties.push_back(name); + } + + for (const auto& [name, prop] : rt->props) + { + if (lt->props.count(name)) + { + // If both lt and rt contain the property, then + // we're done since we already unified them above + } + else if (lt->indexer && isString(lt->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, lt->indexer->indexResultType); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (lt->state == TableState::Unsealed) + { + // TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck. + // TODO: file a JIRA + // TODO: hopefully readonly/writeonly properties will fix this. + Property clone = prop; + clone.type = deeplyOptional(clone.type); + log(lt); + lt->props[name] = clone; + } + else if (variance == Covariant) + {} + else if (isOptional(prop.type) || get(follow(prop.type))) + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + {} + else if (lt->state == TableState::Free) + { + log(lt); + lt->props[name] = prop; + } + else + extraProperties.push_back(name); + } + + // Unify indexers + if (lt->indexer && rt->indexer) + { + // TODO: read-only indexers don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(*lt->indexer, *rt->indexer); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (lt->indexer) + { + if (rt->state == TableState::Unsealed || rt->state == TableState::Free) + { + // passing/assigning a table without an indexer to something that has one + // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. + // TODO: we only need to do this if the supertype's indexer is read/write + // since that can add indexed elements. + log(rt); + rt->indexer = lt->indexer; + } + } + else if (rt->indexer && variance == Invariant) + { + // Symmetric if we are invariant + if (lt->state == TableState::Unsealed || lt->state == TableState::Free) + { + log(lt); + lt->indexer = rt->indexer; + } + } + + if (!missingProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + return; + } + + if (!extraProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + return; + } + + /* + * TypeVars are commonly cyclic, so it is entirely possible + * for unifying a property of a table to change the table itself! + * We need to check for this and start over if we notice this occurring. + * + * I believe this is guaranteed to terminate eventually because this will + * only happen when a free table is bound to another table. + */ + if (lt->boundTo || rt->boundTo) + return tryUnify_(left, right); + + if (lt->state == TableState::Free) + { + log(lt); + lt->boundTo = right; + } + else if (rt->state == TableState::Free) + { + log(rt); + rt->boundTo = left; + } +} + +TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) +{ + ty = follow(ty); + if (get(ty)) + return ty; + else if (isOptional(ty)) + return ty; + else if (const TableTypeVar* ttv = get(ty)) + { + TypeId& result = seen[ty]; + if (result) + return result; + result = types->addType(*ttv); + TableTypeVar* resultTtv = getMutable(result); + for (auto& [name, prop] : resultTtv->props) + prop.type = deeplyOptional(prop.type, seen); + return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});; + } + else + return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }}); +} + +void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +{ + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance); + Resetter resetter{&variance}; variance = Invariant; TableTypeVar* lt = getMutable(left); @@ -894,10 +1119,7 @@ void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) if (!freeTable->boundTo && otherTable->state != TableState::Free) { - if (FFlag::LuauLogTableTypeVarBoundTo) - log(freeTable); - else - log(freeTypeId); + log(freeTable); freeTable->boundTo = otherTypeId; } } @@ -1196,9 +1418,11 @@ void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& sub tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); } -static void queueTypePack( +static void queueTypePack_DEPRECATED( std::vector& queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { + LUAU_ASSERT(!FFlag::LuauTypecheckOpts); + while (true) { if (FFlag::LuauAddMissingFollow) @@ -1244,6 +1468,55 @@ static void queueTypePack( } } +static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOpts); + + while (true) + { + if (FFlag::LuauAddMissingFollow) + a = follow(a); + + if (seenTypePacks.find(a)) + break; + seenTypePacks.insert(a); + + if (FFlag::LuauAddMissingFollow) + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + else if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + else + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + + if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + } +} + void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) { const VariadicTypePack* lv = get(superTp); @@ -1297,9 +1570,11 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever } } -static void tryUnifyWithAny( +static void tryUnifyWithAny_DEPRECATED( std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) { + LUAU_ASSERT(!FFlag::LuauTypecheckOpts); + std::unordered_set seen; while (!queue.empty()) @@ -1310,6 +1585,59 @@ static void tryUnifyWithAny( continue; seen.insert(ty); + if (get(ty)) + { + state.log(ty); + *asMutable(ty) = BoundTypeVar{anyType}; + } + else if (auto fun = get(ty)) + { + queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = get(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = get(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (get(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = get(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = get(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. + } +} + +static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, + TypeId anyType, TypePackId anyTypePack) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOpts); + + while (!queue.empty()) + { + TypeId ty = follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); + if (get(ty)) { state.log(ty); @@ -1354,14 +1682,33 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { LUAU_ASSERT(get(any) || get(any)); + if (FFlag::LuauTypecheckOpts) + { + // These types are not visited in general loop below + if (get(ty) || get(ty) || get(ty)) + return; + } + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - std::unordered_set seenTypePacks; - std::vector queue = {ty}; + if (FFlag::LuauTypecheckOpts) + { + std::vector queue = {ty}; - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + tempSeenTy.clear(); + tempSeenTp.clear(); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP); + } + else + { + std::unordered_set seenTypePacks; + std::vector queue = {ty}; + + Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + } } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) @@ -1370,12 +1717,26 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) const TypeId anyTy = singletonTypes.errorType; - std::unordered_set seenTypePacks; - std::vector queue; + if (FFlag::LuauTypecheckOpts) + { + std::vector queue; - queueTypePack(queue, seenTypePacks, *this, ty, any); + tempSeenTy.clear(); + tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, anyTy, any); + queueTypePack(queue, tempSeenTp, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any); + } + else + { + std::unordered_set seenTypePacks; + std::vector queue; + + queueTypePack_DEPRECATED(queue, seenTypePacks, *this, ty, any); + + Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, anyTy, any); + } } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -1387,21 +1748,6 @@ std::optional Unifier::findMetatableEntry(TypeId type, std::string entry { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto found = globalScope->bindings.find(AstName{"string"}); - if (found == globalScope->bindings.end()) - return std::nullopt; - else - return found->second.typeId; - } - } - std::optional metatable = getMetatable(type); if (!metatable) return std::nullopt; @@ -1427,21 +1773,36 @@ std::optional Unifier::findMetatableEntry(TypeId type, std::string entry void Unifier::occursCheck(TypeId needle, TypeId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + std::unordered_set seen_DEPRECATED; + + if (FFlag::LuauTypecheckOpts) + tempSeenTy.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack) +void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); needle = follow(needle); haystack = follow(haystack); - if (seen.end() != seen.find(haystack)) - return; + if (FFlag::LuauTypecheckOpts) + { + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); + } + else + { + if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) + return; + + seen_DEPRECATED.insert(haystack); + } if (get(needle)) return; @@ -1458,7 +1819,7 @@ void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeI } auto check = [&](TypeId tv) { - occursCheck(seen, needle, tv); + occursCheck(seen_DEPRECATED, seen, needle, tv); }; if (get(haystack)) @@ -1488,19 +1849,33 @@ void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeI void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + std::unordered_set seen_DEPRECATED; + + if (FFlag::LuauTypecheckOpts) + tempSeenTp.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack) +void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) { needle = follow(needle); haystack = follow(haystack); - if (seen.find(haystack) != seen.end()) - return; + if (FFlag::LuauTypecheckOpts) + { + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); + } + else + { + if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) + return; + + seen_DEPRECATED.insert(haystack); + } if (get(needle)) return; @@ -1508,7 +1883,8 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl if (!get(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); while (!get(haystack)) { @@ -1528,8 +1904,8 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl { if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); + occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); + occursCheck(seen_DEPRECATED, seen, needle, f->retType); } } } @@ -1546,7 +1922,7 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters}; + return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index df38cfe..a2189f7 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -264,6 +264,10 @@ public: { return false; } + virtual bool visit(class AstTypePackExplicit* node) + { + return visit((class AstTypePack*)node); + } virtual bool visit(class AstTypePackVariadic* node) { return visit((class AstTypePack*)node); @@ -930,12 +934,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, + AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -1007,19 +1013,28 @@ public: } }; +// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use +struct AstTypeOrPack +{ + AstType* type = nullptr; + AstTypePack* typePack = nullptr; +}; + class AstTypeReference : public AstType { public: LUAU_RTTI(AstTypeReference) - AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics = {}); + AstTypeReference(const Location& location, std::optional prefix, AstName name, bool hasParameterList = false, + const AstArray& parameters = {}); void visit(AstVisitor* visitor) override; bool hasPrefix; + bool hasParameterList; AstName prefix; AstName name; - AstArray generics; + AstArray parameters; }; struct AstTableProp @@ -1152,6 +1167,18 @@ public: } }; +class AstTypePackExplicit : public AstTypePack +{ +public: + LUAU_RTTI(AstTypePackExplicit) + + AstTypePackExplicit(const Location& location, AstTypeList typeList); + + void visit(AstVisitor* visitor) override; + + AstTypeList typeList; +}; + class AstTypePackVariadic : public AstTypePack { public: diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 02924e8..a7b2515 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -136,7 +136,10 @@ public: const Key& key = ItemInterface::getKey(data[i]); if (!eq(key, empty_key)) - *newtable.insert_unsafe(key) = data[i]; + { + Item* item = newtable.insert_unsafe(key); + *item = std::move(data[i]); + } } LUAU_ASSERT(count == newtable.count); diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e6ebd50..42c64dc 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -218,13 +218,14 @@ private: AstTableIndexer* parseTableIndexerAnnotation(); - AstType* parseFunctionTypeAnnotation(); + AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); - AstType* parseSimpleTypeAnnotation(); + AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstTypeOrPack parseTypeOrPackAnnotation(); AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); AstType* parseTypeAnnotation(); @@ -284,7 +285,7 @@ private: std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); // `<' typeAnnotation[, ...] `>' - AstArray parseTypeParams(); + AstArray parseTypeParams(); AstExpr* parseString(); @@ -413,6 +414,7 @@ private: std::vector scratchLocal; std::vector scratchTableTypeProps; std::vector scratchAnnotation; + std::vector scratchTypeOrPackAnnotation; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h new file mode 100644 index 0000000..641dfd3 --- /dev/null +++ b/Ast/include/Luau/TimeTrace.h @@ -0,0 +1,223 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Common.h" + +#include + +#include + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +uint32_t getClockMicroseconds(); + +struct Token +{ + const char* name; + const char* category; +}; + +enum class EventType : uint8_t +{ + Enter, + Leave, + + ArgName, + ArgValue, +}; + +struct Event +{ + EventType type; + uint16_t token; + + union + { + uint32_t microsec; // 1 hour trace limit + uint32_t dataPos; + } data; +}; + +struct GlobalContext; +struct ThreadContext; + +GlobalContext& getGlobalContext(); + +uint16_t createToken(GlobalContext& context, const char* name, const char* category); +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext); +void releaseThread(GlobalContext& context, ThreadContext* threadContext); +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data); + +struct ThreadContext +{ + ThreadContext() + : globalContext(getGlobalContext()) + { + threadId = createThread(globalContext, this); + } + + ~ThreadContext() + { + if (!events.empty()) + flushEvents(); + + releaseThread(globalContext, this); + } + + void flushEvents() + { + static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace"); + + events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}}); + + TimeTrace::flushEvents(globalContext, threadId, events, data); + + events.clear(); + data.clear(); + + events.push_back({EventType::Leave, 0, {getClockMicroseconds()}}); + } + + void eventEnter(uint16_t token) + { + eventEnter(token, getClockMicroseconds()); + } + + void eventEnter(uint16_t token, uint32_t microsec) + { + events.push_back({EventType::Enter, token, {microsec}}); + } + + void eventLeave() + { + eventLeave(getClockMicroseconds()); + } + + void eventLeave(uint32_t microsec) + { + events.push_back({EventType::Leave, 0, {microsec}}); + + if (events.size() > kEventFlushLimit) + flushEvents(); + } + + void eventArgument(const char* name, const char* value) + { + uint32_t pos = uint32_t(data.size()); + data.insert(data.end(), name, name + strlen(name) + 1); + events.push_back({EventType::ArgName, 0, {pos}}); + + pos = uint32_t(data.size()); + data.insert(data.end(), value, value + strlen(value) + 1); + events.push_back({EventType::ArgValue, 0, {pos}}); + } + + GlobalContext& globalContext; + uint32_t threadId; + std::vector events; + std::vector data; + + static constexpr size_t kEventFlushLimit = 8192; +}; + +ThreadContext& getThreadContext(); + +struct Scope +{ + explicit Scope(ThreadContext& context, uint16_t token) + : context(context) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventEnter(token); + } + + ~Scope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventLeave(); + } + + ThreadContext& context; +}; + +struct OptionalTailScope +{ + explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold) + : context(context) + , token(token) + , threshold(threshold) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + pos = uint32_t(context.events.size()); + microsec = getClockMicroseconds(); + } + + ~OptionalTailScope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + if (pos == context.events.size()) + { + uint32_t curr = getClockMicroseconds(); + + if (curr - microsec > threshold) + { + context.eventEnter(token, microsec); + context.eventLeave(curr); + } + } + } + + ThreadContext& context; + uint16_t token; + uint32_t threshold; + uint32_t microsec; + uint32_t pos; +}; + +LUAU_NOINLINE std::pair createScopeData(const char* name, const char* category); + +} // namespace TimeTrace +} // namespace Luau + +// Regular scope +#define LUAU_TIMETRACE_SCOPE(name, category) \ + static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) + +// A scope without nested scopes that may be skipped if the time it took is less than the threshold +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ + static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) + +// Extra key/value data can be added to regular scopes +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ + lttScopeStatic.second.eventArgument(name, value); \ + } while (false) + +#else + +#define LUAU_TIMETRACE_SCOPE(name, category) +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + } while (false) + +#endif diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index fff1537..b1209fa 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -641,10 +641,12 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) + , genericPacks(genericPacks) , type(type) , exported(exported) { @@ -729,12 +731,14 @@ void AstStatError::visit(AstVisitor* visitor) } } -AstTypeReference::AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics) +AstTypeReference::AstTypeReference( + const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) : AstType(ClassIndex(), location) , hasPrefix(bool(prefix)) + , hasParameterList(hasParameterList) , prefix(prefix ? *prefix : AstName()) , name(name) - , generics(generics) + , parameters(parameters) { } @@ -742,8 +746,13 @@ void AstTypeReference::visit(AstVisitor* visitor) { if (visitor->visit(this)) { - for (AstType* generic : generics) - generic->visit(visitor); + for (const AstTypeOrPack& param : parameters) + { + if (param.type) + param.type->visit(visitor); + else + param.typePack->visit(visitor); + } } } @@ -849,6 +858,24 @@ void AstTypeError::visit(AstVisitor* visitor) } } +AstTypePackExplicit::AstTypePackExplicit(const Location& location, AstTypeList typeList) + : AstTypePack(ClassIndex(), location) + , typeList(typeList) +{ +} + +void AstTypePackExplicit::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* type : typeList.types) + type->visit(visitor); + + if (typeList.tailType) + typeList.tailType->visit(visitor); + } +} + AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType) : AstTypePack(ClassIndex(), location) , variadicType(variadicType) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 9794a03..846bc0b 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/TimeTrace.h" + #include // Warning: If you are introducing new syntax, ensure that it is behind a separate @@ -13,6 +15,8 @@ LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) +LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) namespace Luau { @@ -148,6 +152,8 @@ static bool shouldParseTypePackAnnotation(Lexer& lexer) ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options) { + LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); + Parser p(buffer, bufferSize, names, allocator); try @@ -769,14 +775,14 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - // TODO: support generic type pack parameters in type aliases CLI-39907 auto [generics, genericPacks] = parseGenericTypeList(); expectAndConsume('=', "type alias"); AstType* type = parseTypeAnnotation(); - return allocator.alloc(Location(start, type->location), name->name, generics, type, exported); + return allocator.alloc( + Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray{}, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -1333,7 +1339,7 @@ AstType* Parser::parseTableTypeAnnotation() // ReturnType ::= TypeAnnotation | `(' TypeList `)' // FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseFunctionTypeAnnotation() +AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1364,14 +1370,23 @@ AstType* Parser::parseFunctionTypeAnnotation() matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; - // Not a function at all. Just a parenthesized type. - if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) - return params[0]; - AstArray paramTypes = copy(params); + + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element + if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) + { + if (allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; + else + return {params[0], {}}; + } + + if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; + AstArray> paramNames = copy(names); - return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation); + return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, @@ -1421,7 +1436,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); isUnion = true; } else if (c == '?') @@ -1434,7 +1449,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); isIntersection = true; } else @@ -1462,6 +1477,30 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } +AstTypeOrPack Parser::parseTypeOrPackAnnotation() +{ + unsigned int oldRecursionCount = recursionCounter; + incrementRecursionCounter("type annotation"); + + Location begin = lexer.current().location; + + TempVector parts(scratchAnnotation); + + auto [type, typePack] = parseSimpleTypeAnnotation(true); + + if (typePack) + { + LUAU_ASSERT(!type); + return {{}, typePack}; + } + + parts.push_back(type); + + recursionCounter = oldRecursionCount; + + return {parseTypeAnnotation(parts, begin), {}}; +} + AstType* Parser::parseTypeAnnotation() { unsigned int oldRecursionCount = recursionCounter; @@ -1470,7 +1509,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); recursionCounter = oldRecursionCount; @@ -1479,7 +1518,7 @@ AstType* Parser::parseTypeAnnotation() // typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseSimpleTypeAnnotation() +AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1488,7 +1527,7 @@ AstType* Parser::parseSimpleTypeAnnotation() if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); - return allocator.alloc(begin, std::nullopt, nameNil); + return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } else if (lexer.current().type == Lexeme::Name) { @@ -1514,22 +1553,41 @@ AstType* Parser::parseSimpleTypeAnnotation() expectMatchAndConsume(')', typeofBegin); - return allocator.alloc(Location(begin, end), expr); + return {allocator.alloc(Location(begin, end), expr), {}}; } - AstArray generics = parseTypeParams(); + if (FFlag::LuauParseTypePackTypeParameters) + { + bool hasParameters = false; + AstArray parameters{}; - Location end = lexer.previousLocation(); + if (lexer.current().type == '<') + { + hasParameters = true; + parameters = parseTypeParams(); + } - return allocator.alloc(Location(begin, end), prefix, name.name, generics); + Location end = lexer.previousLocation(); + + return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; + } + else + { + AstArray generics = parseTypeParams(); + + Location end = lexer.previousLocation(); + + // false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks + return {allocator.alloc(Location(begin, end), prefix, name.name, false, generics), {}}; + } } else if (lexer.current().type == '{') { - return parseTableTypeAnnotation(); + return {parseTableTypeAnnotation(), {}}; } else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) { - return parseFunctionTypeAnnotation(); + return parseFunctionTypeAnnotation(allowPack); } else { @@ -1538,7 +1596,7 @@ AstType* Parser::parseSimpleTypeAnnotation() // For a missing type annotation, capture 'space' between last token and the next one location = Location(lexer.previousLocation().end, lexer.current().location.begin); - return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } @@ -2312,18 +2370,59 @@ std::pair, AstArray> Parser::parseGenericTypeList() return {generics, genericPacks}; } -AstArray Parser::parseTypeParams() +AstArray Parser::parseTypeParams() { - TempVector result{scratchAnnotation}; + TempVector parameters{scratchTypeOrPackAnnotation}; if (lexer.current().type == '<') { Lexeme begin = lexer.current(); nextLexeme(); + bool seenPack = false; while (true) { - result.push_back(parseTypeAnnotation()); + if (FFlag::LuauParseTypePackTypeParameters) + { + if (shouldParseTypePackAnnotation(lexer)) + { + seenPack = true; + + auto typePack = parseTypePackAnnotation(); + + if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them + parameters.push_back({{}, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (typePack) + { + seenPack = true; + + if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them + parameters.push_back({{}, typePack}); + } + else + { + parameters.push_back({type, {}}); + } + } + else if (lexer.current().type == '>' && parameters.empty()) + { + break; + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + if (lexer.current().type == ',') nextLexeme(); else @@ -2333,7 +2432,7 @@ AstArray Parser::parseTypeParams() expectMatchAndConsume('>', begin); } - return copy(result); + return copy(parameters); } AstExpr* Parser::parseString() diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp new file mode 100644 index 0000000..e6aab20 --- /dev/null +++ b/Ast/src/TimeTrace.cpp @@ -0,0 +1,248 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TimeTrace.h" + +#include "Luau/StringUtils.h" + +#include +#include + +#include + +#ifdef _WIN32 +#include +#endif + +#ifdef __APPLE__ +#include +#include +#endif + +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +static double getClockPeriod() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceFrequency(&result); + return 1.0 / double(result.QuadPart); +#elif defined(__APPLE__) + mach_timebase_info_data_t result = {}; + mach_timebase_info(&result); + return double(result.numer) / double(result.denom) * 1e-9; +#elif defined(__linux__) + return 1e-9; +#else + return 1.0 / double(CLOCKS_PER_SEC); +#endif +} + +static double getClockTimestamp() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceCounter(&result); + return double(result.QuadPart); +#elif defined(__APPLE__) + return double(mach_absolute_time()); +#elif defined(__linux__) + timespec now; + clock_gettime(CLOCK_MONOTONIC, &now); + return now.tv_sec * 1e9 + now.tv_nsec; +#else + return double(clock()); +#endif +} + +uint32_t getClockMicroseconds() +{ + static double period = getClockPeriod() * 1e6; + static double start = getClockTimestamp(); + + return uint32_t((getClockTimestamp() - start) * period); +} + +struct GlobalContext +{ + GlobalContext() = default; + ~GlobalContext() + { + // Ideally we would want all ThreadContext destructors to run + // But in VS, not all thread_local object instances are destroyed + for (ThreadContext* context : threads) + context->flushEvents(); + + if (traceFile) + fclose(traceFile); + } + + std::mutex mutex; + std::vector threads; + uint32_t nextThreadId = 0; + std::vector tokens; + FILE* traceFile = nullptr; +}; + +GlobalContext& getGlobalContext() +{ + static GlobalContext context; + return context; +} + +uint16_t createToken(GlobalContext& context, const char* name, const char* category) +{ + std::scoped_lock lock(context.mutex); + + LUAU_ASSERT(context.tokens.size() < 64 * 1024); + + context.tokens.push_back({name, category}); + return uint16_t(context.tokens.size() - 1); +} + +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + context.threads.push_back(threadContext); + + return ++context.nextThreadId; +} + +void releaseThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + if (auto it = std::find(context.threads.begin(), context.threads.end(), threadContext); it != context.threads.end()) + context.threads.erase(it); +} + +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data) +{ + std::scoped_lock lock(context.mutex); + + if (!context.traceFile) + { + context.traceFile = fopen("trace.json", "w"); + + if (!context.traceFile) + return; + + fprintf(context.traceFile, "[\n"); + } + + std::string temp; + const unsigned tempReserve = 64 * 1024; + temp.reserve(tempReserve); + + const char* rawData = data.data(); + + // Formatting state + bool unfinishedEnter = false; + bool unfinishedArgs = false; + + for (const Event& ev : events) + { + switch (ev.type) + { + case EventType::Enter: + { + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + Token& token = context.tokens[ev.token]; + + formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category, + ev.data.microsec, threadId); + unfinishedEnter = true; + } + break; + case EventType::Leave: + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + formatAppend(temp, + R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)" + "\n", + ev.data.microsec, threadId); + break; + case EventType::ArgName: + LUAU_ASSERT(unfinishedEnter); + + if (!unfinishedArgs) + { + formatAppend(temp, R"(, "args": { "%s": )", rawData + ev.data.dataPos); + unfinishedArgs = true; + } + else + { + formatAppend(temp, R"(, "%s": )", rawData + ev.data.dataPos); + } + break; + case EventType::ArgValue: + LUAU_ASSERT(unfinishedArgs); + formatAppend(temp, R"("%s")", rawData + ev.data.dataPos); + break; + } + + // Don't want to hit the string capacity and reallocate + if (temp.size() > tempReserve - 1024) + { + fwrite(temp.data(), 1, temp.size(), context.traceFile); + temp.clear(); + } + } + + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + fwrite(temp.data(), 1, temp.size(), context.traceFile); + fflush(context.traceFile); +} + +ThreadContext& getThreadContext() +{ + thread_local ThreadContext context; + return context; +} + +std::pair createScopeData(const char* name, const char* category) +{ + uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category); + return {token, Luau::TimeTrace::getThreadContext()}; +} +} // namespace TimeTrace +} // namespace Luau + +#endif diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 920502b..ed0552d 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -111,11 +111,24 @@ struct CliFileResolver : Luau::FileResolver return Luau::SourceCode{*source, Luau::SourceCode::Module}; } + std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override + { + if (Luau::AstExprConstantString* expr = node->as()) + { + Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua"; + + return {{name}}; + } + + return std::nullopt; + } + bool moduleExists(const Luau::ModuleName& name) const override { return !!readFile(name); } + std::optional fromAstFragment(Luau::AstExpr* expr) const override { return std::nullopt; @@ -130,11 +143,6 @@ struct CliFileResolver : Luau::FileResolver { return std::nullopt; } - - std::optional getEnvironmentForModule(const Luau::ModuleName& name) const override - { - return std::nullopt; - } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 2156424..7750a1d 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -4,6 +4,7 @@ #include "Luau/Parser.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/TimeTrace.h" #include #include @@ -137,6 +138,11 @@ struct Compiler uint32_t compileFunction(AstExprFunction* func) { + LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); + + if (func->debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", func->debugname.value); + LUAU_ASSERT(!functions.contains(func)); LUAU_ASSERT(regTop == 0 && stackSize == 0 && localStack.empty() && upvals.empty()); @@ -3686,6 +3692,8 @@ struct Compiler void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) { + LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); + Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block table imports @@ -3748,6 +3756,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder) { + LUAU_TIMETRACE_SCOPE("compile", "Compiler"); + Allocator allocator; AstNameTable names(allocator); ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions); diff --git a/Sources.cmake b/Sources.cmake index 6f96f6a..83ed523 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -9,6 +9,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/ParseOptions.h Ast/include/Luau/Parser.h Ast/include/Luau/StringUtils.h + Ast/include/Luau/TimeTrace.h Ast/src/Ast.cpp Ast/src/Confusables.cpp @@ -16,6 +17,7 @@ target_sources(Luau.Ast PRIVATE Ast/src/Location.cpp Ast/src/Parser.cpp Ast/src/StringUtils.cpp + Ast/src/TimeTrace.cpp ) # Luau.Compiler Sources @@ -46,6 +48,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Predicate.h Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h + Analysis/include/Luau/Scope.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h Analysis/include/Luau/TopoSortStatements.h @@ -75,6 +78,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Module.cpp Analysis/src/Predicate.cpp Analysis/src/RequireTracer.cpp + Analysis/src/Scope.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp Analysis/src/TopoSortStatements.cpp @@ -188,6 +192,7 @@ if(TARGET Luau.UnitTest) tests/TopoSort.test.cpp tests/ToString.test.cpp tests/Transpiler.test.cpp + tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp tests/TypeInfer.builtins.test.cpp tests/TypeInfer.classes.test.cpp diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 7736671..fb4978d 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -8,9 +8,13 @@ #include "lmem.h" #include "lvm.h" -#include - +#if LUA_USE_LONGJMP #include +#include +#else +#include +#endif + #include LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) @@ -51,8 +55,8 @@ l_noret luaD_throw(lua_State* L, int errcode) longjmp(jb->buf, 1); } - if (L->global->panic) - L->global->panic(L, errcode); + if (L->global->cb.panic) + L->global->cb.panic(L, errcode); abort(); } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 85af403..5de5d49 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) +LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) + LUAU_FASTFLAG(LuauArrayBoundary) #define GC_SWEEPMAX 40 @@ -810,6 +812,133 @@ static size_t singlestep(lua_State* L) return cost; } +static size_t gcstep(lua_State* L, size_t limit) +{ + size_t cost = 0; + global_State* g = L->global; + switch (g->gcstate) + { + case GCSpause: + { + markroot(L); /* start a new collection */ + break; + } + case GCSpropagate: + { + if (FFlag::LuauRescanGrayAgain) + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) + { + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; + + g->gcstate = GCSpropagateagain; + } + } + else + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + } + break; + } + case GCSpropagateagain: + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + break; + } + case GCSsweepstring: + { + while (g->sweepstrgc < g->strt.size && cost < limit) + { + size_t traversedcount = 0; + sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPCOST; + } + + // nothing more to sweep? + if (g->sweepstrgc >= g->strt.size) + { + // sweep string buffer list and preserve used string count + uint32_t nuse = L->global->strt.nuse; + + size_t traversedcount = 0; + sweepwholelist(L, &g->strbufgc, &traversedcount); + + L->global->strt.nuse = nuse; + + g->gcstats.currcycle.sweepitems += traversedcount; + g->gcstate = GCSsweep; // end sweep-string phase + } + break; + } + case GCSsweep: + { + while (*g->sweepgc && cost < limit) + { + size_t traversedcount = 0; + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPMAX * GC_SWEEPCOST; + } + + if (*g->sweepgc == NULL) + { /* nothing more to sweep? */ + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } + break; + } + default: + LUAU_ASSERT(0); + } + return cost; +} + static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats) { // adjust for error using Proportional-Integral controller @@ -878,33 +1007,40 @@ void luaC_step(lua_State* L, bool assist) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (assist) - g->gcstats.currcycle.assistwork += lim; - else - g->gcstats.currcycle.explicitwork += lim; - int lastgcstate = g->gcstate; double lasttimestamp = lua_clock(); - // always perform at least one single step - do + if (FFlag::LuauConsolidatedStep) { - lim -= singlestep(L); + size_t work = gcstep(L, lim); - // if we have switched to a different state, capture the duration of last stage - // this way we reduce the number of timer calls we make - if (lastgcstate != g->gcstate) + if (assist) + g->gcstats.currcycle.assistwork += work; + else + g->gcstats.currcycle.explicitwork += work; + } + else + { + // always perform at least one single step + do { - GC_INTERRUPT(lastgcstate); + lim -= singlestep(L); - double now = lua_clock(); + // if we have switched to a different state, capture the duration of last stage + // this way we reduce the number of timer calls we make + if (lastgcstate != g->gcstate) + { + GC_INTERRUPT(lastgcstate); - recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist); + double now = lua_clock(); - lasttimestamp = now; - lastgcstate = g->gcstate; - } - } while (lim > 0 && g->gcstate != GCSpause); + recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist); + + lasttimestamp = now; + lastgcstate = g->gcstate; + } + } while (lim > 0 && g->gcstate != GCSpause); + } recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); @@ -931,7 +1067,14 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold -= debt; } - GC_INTERRUPT(g->gcstate); + if (FFlag::LuauConsolidatedStep) + { + GC_INTERRUPT(lastgcstate); + } + else + { + GC_INTERRUPT(g->gcstate); + } } void luaC_fullgc(lua_State* L) @@ -957,7 +1100,10 @@ void luaC_fullgc(lua_State* L) while (g->gcstate != GCSpause) { LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - singlestep(L); + if (FFlag::LuauConsolidatedStep) + gcstep(L, SIZE_MAX); + else + singlestep(L); } finishGcCycleStats(g); @@ -968,7 +1114,10 @@ void luaC_fullgc(lua_State* L) markroot(L); while (g->gcstate != GCSpause) { - singlestep(L); + if (FFlag::LuauConsolidatedStep) + gcstep(L, SIZE_MAX); + else + singlestep(L); } /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ shrinkbuffersfull(L); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 090e183..de5788e 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -9,14 +9,8 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry, false) - LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false) -bool lua_telemetry_table_move_oob_src_from = false; -bool lua_telemetry_table_move_oob_src_to = false; -bool lua_telemetry_table_move_oob_dst = false; - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -202,22 +196,6 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - if (DFFlag::LuauTableMoveTelemetry) - { - int nf = lua_objlen(L, 1); - int nt = lua_objlen(L, tt); - - // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) - if (!(f == 1 || (f >= 1 && f <= nf))) - lua_telemetry_table_move_oob_src_from = true; - if (!(e == nf || (e >= 1 && e <= nf))) - lua_telemetry_table_move_oob_src_to = true; - - // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) - if (!(t == nt + 1 || (t >= 1 && t <= nt + 1))) - lua_telemetry_table_move_oob_dst = true; - } - if (e >= f) { /* otherwise, nothing to move */ luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 5f0ee92..eed2862 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauLoopUseSafeenv, false) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -292,10 +290,6 @@ inline bool luau_skipstep(uint8_t op) return op == LOP_PREPVARARGS || op == LOP_BREAK; } -// declared in lbaselib.cpp, needed to support cases when pairs/ipairs have been replaced via setfenv -LUAI_FUNC int luaB_inext(lua_State* L); -LUAI_FUNC int luaB_next(lua_State* L); - template static void luau_execute(lua_State* L) { @@ -2223,8 +2217,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: ipairs/inext - bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_inext; - if (safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) + if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } @@ -2304,8 +2297,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: pairs/next - bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_next; - if (safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) + if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 0a23234..b932a85 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -12,7 +12,32 @@ #include -#include +// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens +template +struct TempBuffer +{ + lua_State* L; + T* data; + size_t count; + + TempBuffer(lua_State* L, size_t count) + : L(L) + , data(luaM_newarray(L, count, T, 0)) + , count(count) + { + } + + ~TempBuffer() + { + luaM_freearray(L, data, count, T, 0); + } + + T& operator[](size_t index) + { + LUAU_ASSERT(index < count); + return data[index]; + } +}; void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) { @@ -67,7 +92,7 @@ static unsigned int readVarInt(const char* data, size_t size, size_t& offset) return result; } -static TString* readString(std::vector& strings, const char* data, size_t size, size_t& offset) +static TString* readString(TempBuffer& strings, const char* data, size_t size, size_t& offset) { unsigned int id = readVarInt(data, size, offset); @@ -133,6 +158,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size } // pause GC for the duration of deserialization - some objects we're creating aren't rooted + // TODO: if an allocation error happens mid-load, we do not unpause GC! size_t GCthreshold = L->global->GCthreshold; L->global->GCthreshold = SIZE_MAX; @@ -144,7 +170,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // string table unsigned int stringCount = readVarInt(data, size, offset); - std::vector strings(stringCount); + TempBuffer strings(L, stringCount); for (unsigned int i = 0; i < stringCount; ++i) { @@ -156,7 +182,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // proto table unsigned int protoCount = readVarInt(data, size, offset); - std::vector protos(protoCount); + TempBuffer protos(L, protoCount); for (unsigned int i = 0; i < protoCount; ++i) { diff --git a/bench/tests/deltablue.lua b/bench/tests/deltablue.lua deleted file mode 100644 index ad18c23..0000000 --- a/bench/tests/deltablue.lua +++ /dev/null @@ -1,934 +0,0 @@ -local bench = script and require(script.Parent.bench_support) or require("bench_support") - --- Copyright 2008 the V8 project authors. All rights reserved. --- Copyright 1996 John Maloney and Mario Wolczko. - --- This program is free software; you can redistribute it and/or modify --- it under the terms of the GNU General Public License as published by --- the Free Software Foundation; either version 2 of the License, or --- (at your option) any later version. --- --- This program is distributed in the hope that it will be useful, --- but WITHOUT ANY WARRANTY; without even the implied warranty of --- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the --- GNU General Public License for more details. --- --- You should have received a copy of the GNU General Public License --- along with this program; if not, write to the Free Software --- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - - --- This implementation of the DeltaBlue benchmark is derived --- from the Smalltalk implementation by John Maloney and Mario --- Wolczko. Some parts have been translated directly, whereas --- others have been modified more aggressively to make it feel --- more like a JavaScript program. - - --- --- A JavaScript implementation of the DeltaBlue constraint-solving --- algorithm, as described in: --- --- "The DeltaBlue Algorithm: An Incremental Constraint Hierarchy Solver" --- Bjorn N. Freeman-Benson and John Maloney --- January 1990 Communications of the ACM, --- also available as University of Washington TR 89-08-06. --- --- Beware: this benchmark is written in a grotesque style where --- the constraint model is built by side-effects from constructors. --- I've kept it this way to avoid deviating too much from the original --- implementation. --- - -function class(base) - local T = {} - T.__index = T - - if base then - T.super = base - setmetatable(T, base) - end - - function T.new(...) - local O = {} - setmetatable(O, T) - O:constructor(...) - return O - end - - return T -end - -local planner - ---- O b j e c t M o d e l --- - -local function alert (...) print(...) end - -local OrderedCollection = class() - -function OrderedCollection:constructor() - self.elms = {} -end - -function OrderedCollection:add(elm) - self.elms[#self.elms + 1] = elm -end - -function OrderedCollection:at (index) - return self.elms[index] -end - -function OrderedCollection:size () - return #self.elms -end - -function OrderedCollection:removeFirst () - local e = self.elms[#self.elms] - self.elms[#self.elms] = nil - return e -end - -function OrderedCollection:remove (elm) - local index = 0 - local skipped = 0 - - for i = 1, #self.elms do - local value = self.elms[i] - if value ~= elm then - self.elms[index] = value - index = index + 1 - else - skipped = skipped + 1 - end - end - - local l = #self.elms - for i = 1, skipped do self.elms[l - i + 1] = nil end -end - --- --- S t r e n g t h --- - --- --- Strengths are used to measure the relative importance of constraints. --- New strengths may be inserted in the strength hierarchy without --- disrupting current constraints. Strengths cannot be created outside --- this class, so pointer comparison can be used for value comparison. --- - -local Strength = class() - -function Strength:constructor(strengthValue, name) - self.strengthValue = strengthValue - self.name = name -end - -function Strength.stronger (s1, s2) - return s1.strengthValue < s2.strengthValue -end - -function Strength.weaker (s1, s2) - return s1.strengthValue > s2.strengthValue -end - -function Strength.weakestOf (s1, s2) - return Strength.weaker(s1, s2) and s1 or s2 -end - -function Strength.strongest (s1, s2) - return Strength.stronger(s1, s2) and s1 or s2 -end - -function Strength:nextWeaker () - local v = self.strengthValue - if v == 0 then return Strength.WEAKEST - elseif v == 1 then return Strength.WEAK_DEFAULT - elseif v == 2 then return Strength.NORMAL - elseif v == 3 then return Strength.STRONG_DEFAULT - elseif v == 4 then return Strength.PREFERRED - elseif v == 5 then return Strength.REQUIRED - end -end - --- Strength constants. -Strength.REQUIRED = Strength.new(0, "required"); -Strength.STONG_PREFERRED = Strength.new(1, "strongPreferred"); -Strength.PREFERRED = Strength.new(2, "preferred"); -Strength.STRONG_DEFAULT = Strength.new(3, "strongDefault"); -Strength.NORMAL = Strength.new(4, "normal"); -Strength.WEAK_DEFAULT = Strength.new(5, "weakDefault"); -Strength.WEAKEST = Strength.new(6, "weakest"); - --- --- C o n s t r a i n t --- - --- --- An abstract class representing a system-maintainable relationship --- (or "constraint") between a set of variables. A constraint supplies --- a strength instance variable; concrete subclasses provide a means --- of storing the constrained variables and other information required --- to represent a constraint. --- - -local Constraint = class () - -function Constraint:constructor(strength) - self.strength = strength -end - --- --- Activate this constraint and attempt to satisfy it. --- -function Constraint:addConstraint () - self:addToGraph() - planner:incrementalAdd(self) -end - --- --- Attempt to find a way to enforce this constraint. If successful, --- record the solution, perhaps modifying the current dataflow --- graph. Answer the constraint that this constraint overrides, if --- there is one, or nil, if there isn't. --- Assume: I am not already satisfied. --- -function Constraint:satisfy (mark) - self:chooseMethod(mark) - if not self:isSatisfied() then - if self.strength == Strength.REQUIRED then - alert("Could not satisfy a required constraint!") - end - return nil - end - self:markInputs(mark) - local out = self:output() - local overridden = out.determinedBy - if overridden ~= nil then overridden:markUnsatisfied() end - out.determinedBy = self - if not planner:addPropagate(self, mark) then alert("Cycle encountered") end - out.mark = mark - return overridden -end - -function Constraint:destroyConstraint () - if self:isSatisfied() - then planner:incrementalRemove(self) - else self:removeFromGraph() - end -end - --- --- Normal constraints are not input constraints. An input constraint --- is one that depends on external state, such as the mouse, the --- keyboard, a clock, or some arbitrary piece of imperative code. --- -function Constraint:isInput () - return false -end - - --- --- U n a r y C o n s t r a i n t --- - --- --- Abstract superclass for constraints having a single possible output --- variable. --- - -local UnaryConstraint = class(Constraint) - -function UnaryConstraint:constructor (v, strength) - UnaryConstraint.super.constructor(self, strength) - self.myOutput = v - self.satisfied = false - self:addConstraint() -end - --- --- Adds this constraint to the constraint graph --- -function UnaryConstraint:addToGraph () - self.myOutput:addConstraint(self) - self.satisfied = false -end - --- --- Decides if this constraint can be satisfied and records that --- decision. --- -function UnaryConstraint:chooseMethod (mark) - self.satisfied = (self.myOutput.mark ~= mark) - and Strength.stronger(self.strength, self.myOutput.walkStrength); -end - --- --- Returns true if this constraint is satisfied in the current solution. --- -function UnaryConstraint:isSatisfied () - return self.satisfied; -end - -function UnaryConstraint:markInputs (mark) - -- has no inputs -end - --- --- Returns the current output variable. --- -function UnaryConstraint:output () - return self.myOutput -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this constraint. Assume --- this constraint is satisfied. --- -function UnaryConstraint:recalculate () - self.myOutput.walkStrength = self.strength - self.myOutput.stay = not self:isInput() - if self.myOutput.stay then - self:execute() -- Stay optimization - end -end - --- --- Records that this constraint is unsatisfied --- -function UnaryConstraint:markUnsatisfied () - self.satisfied = false -end - -function UnaryConstraint:inputsKnown () - return true -end - -function UnaryConstraint:removeFromGraph () - if self.myOutput ~= nil then - self.myOutput:removeConstraint(self) - end - self.satisfied = false -end - --- --- S t a y C o n s t r a i n t --- - --- --- Variables that should, with some level of preference, stay the same. --- Planners may exploit the fact that instances, if satisfied, will not --- change their output during plan execution. This is called "stay --- optimization". --- - -local StayConstraint = class(UnaryConstraint) - -function StayConstraint:constructor(v, str) - StayConstraint.super.constructor(self, v, str) -end - -function StayConstraint:execute () - -- Stay constraints do nothing -end - --- --- E d i t C o n s t r a i n t --- - --- --- A unary input constraint used to mark a variable that the client --- wishes to change. --- - -local EditConstraint = class (UnaryConstraint) - -function EditConstraint:constructor(v, str) - EditConstraint.super.constructor(self, v, str) -end - --- --- Edits indicate that a variable is to be changed by imperative code. --- -function EditConstraint:isInput () - return true -end - -function EditConstraint:execute () - -- Edit constraints do nothing -end - --- --- B i n a r y C o n s t r a i n t --- - -local Direction = {} -Direction.NONE = 0 -Direction.FORWARD = 1 -Direction.BACKWARD = -1 - --- --- Abstract superclass for constraints having two possible output --- variables. --- - -local BinaryConstraint = class(Constraint) - -function BinaryConstraint:constructor(var1, var2, strength) - BinaryConstraint.super.constructor(self, strength); - self.v1 = var1 - self.v2 = var2 - self.direction = Direction.NONE - self:addConstraint() -end - - --- --- Decides if this constraint can be satisfied and which way it --- should flow based on the relative strength of the variables related, --- and record that decision. --- -function BinaryConstraint:chooseMethod (mark) - if self.v1.mark == mark then - self.direction = (self.v2.mark ~= mark and Strength.stronger(self.strength, self.v2.walkStrength)) and Direction.FORWARD or Direction.NONE - end - if self.v2.mark == mark then - self.direction = (self.v1.mark ~= mark and Strength.stronger(self.strength, self.v1.walkStrength)) and Direction.BACKWARD or Direction.NONE - end - if Strength.weaker(self.v1.walkStrength, self.v2.walkStrength) then - self.direction = Strength.stronger(self.strength, self.v1.walkStrength) and Direction.BACKWARD or Direction.NONE - else - self.direction = Strength.stronger(self.strength, self.v2.walkStrength) and Direction.FORWARD or Direction.BACKWARD - end -end - --- --- Add this constraint to the constraint graph --- -function BinaryConstraint:addToGraph () - self.v1:addConstraint(self) - self.v2:addConstraint(self) - self.direction = Direction.NONE -end - --- --- Answer true if this constraint is satisfied in the current solution. --- -function BinaryConstraint:isSatisfied () - return self.direction ~= Direction.NONE -end - --- --- Mark the input variable with the given mark. --- -function BinaryConstraint:markInputs (mark) - self:input().mark = mark -end - --- --- Returns the current input variable --- -function BinaryConstraint:input () - return (self.direction == Direction.FORWARD) and self.v1 or self.v2 -end - --- --- Returns the current output variable --- -function BinaryConstraint:output () - return (self.direction == Direction.FORWARD) and self.v2 or self.v1 -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this --- constraint. Assume this constraint is satisfied. --- -function BinaryConstraint:recalculate () - local ihn = self:input() - local out = self:output() - out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength); - out.stay = ihn.stay - if out.stay then self:execute() end -end - --- --- Record the fact that self constraint is unsatisfied. --- -function BinaryConstraint:markUnsatisfied () - self.direction = Direction.NONE -end - -function BinaryConstraint:inputsKnown (mark) - local i = self:input() - return i.mark == mark or i.stay or i.determinedBy == nil -end - -function BinaryConstraint:removeFromGraph () - if (self.v1 ~= nil) then self.v1:removeConstraint(self) end - if (self.v2 ~= nil) then self.v2:removeConstraint(self) end - self.direction = Direction.NONE -end - --- --- S c a l e C o n s t r a i n t --- - --- --- Relates two variables by the linear scaling relationship: "v2 = --- (v1 * scale) + offset". Either v1 or v2 may be changed to maintain --- this relationship but the scale factor and offset are considered --- read-only. --- - -local ScaleConstraint = class (BinaryConstraint) - -function ScaleConstraint:constructor(src, scale, offset, dest, strength) - self.direction = Direction.NONE - self.scale = scale - self.offset = offset - ScaleConstraint.super.constructor(self, src, dest, strength) -end - - --- --- Adds this constraint to the constraint graph. --- -function ScaleConstraint:addToGraph () - ScaleConstraint.super.addToGraph(self) - self.scale:addConstraint(self) - self.offset:addConstraint(self) -end - -function ScaleConstraint:removeFromGraph () - ScaleConstraint.super.removeFromGraph(self) - if (self.scale ~= nil) then self.scale:removeConstraint(self) end - if (self.offset ~= nil) then self.offset:removeConstraint(self) end -end - -function ScaleConstraint:markInputs (mark) - ScaleConstraint.super.markInputs(self, mark); - self.offset.mark = mark - self.scale.mark = mark -end - --- --- Enforce this constraint. Assume that it is satisfied. --- -function ScaleConstraint:execute () - if self.direction == Direction.FORWARD then - self.v2.value = self.v1.value * self.scale.value + self.offset.value - else - self.v1.value = (self.v2.value - self.offset.value) / self.scale.value - end -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this constraint. Assume --- this constraint is satisfied. --- -function ScaleConstraint:recalculate () - local ihn = self:input() - local out = self:output() - out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength) - out.stay = ihn.stay and self.scale.stay and self.offset.stay - if out.stay then self:execute() end -end - --- --- E q u a l i t y C o n s t r a i n t --- - --- --- Constrains two variables to have the same value. --- - -local EqualityConstraint = class (BinaryConstraint) - -function EqualityConstraint:constructor(var1, var2, strength) - EqualityConstraint.super.constructor(self, var1, var2, strength) -end - - --- --- Enforce this constraint. Assume that it is satisfied. --- -function EqualityConstraint:execute () - self:output().value = self:input().value -end - --- --- V a r i a b l e --- - --- --- A constrained variable. In addition to its value, it maintain the --- structure of the constraint graph, the current dataflow graph, and --- various parameters of interest to the DeltaBlue incremental --- constraint solver. --- -local Variable = class () - -function Variable:constructor(name, initialValue) - self.value = initialValue or 0 - self.constraints = OrderedCollection.new() - self.determinedBy = nil - self.mark = 0 - self.walkStrength = Strength.WEAKEST - self.stay = true - self.name = name -end - --- --- Add the given constraint to the set of all constraints that refer --- this variable. --- -function Variable:addConstraint (c) - self.constraints:add(c) -end - --- --- Removes all traces of c from this variable. --- -function Variable:removeConstraint (c) - self.constraints:remove(c) - if self.determinedBy == c then - self.determinedBy = nil - end -end - --- --- P l a n n e r --- - --- --- The DeltaBlue planner --- -local Planner = class() -function Planner:constructor() - self.currentMark = 0 -end - --- --- Attempt to satisfy the given constraint and, if successful, --- incrementally update the dataflow graph. Details: If satisfying --- the constraint is successful, it may override a weaker constraint --- on its output. The algorithm attempts to resatisfy that --- constraint using some other method. This process is repeated --- until either a) it reaches a variable that was not previously --- determined by any constraint or b) it reaches a constraint that --- is too weak to be satisfied using any of its methods. The --- variables of constraints that have been processed are marked with --- a unique mark value so that we know where we've been. This allows --- the algorithm to avoid getting into an infinite loop even if the --- constraint graph has an inadvertent cycle. --- -function Planner:incrementalAdd (c) - local mark = self:newMark() - local overridden = c:satisfy(mark) - while overridden ~= nil do - overridden = overridden:satisfy(mark) - end -end - --- --- Entry point for retracting a constraint. Remove the given --- constraint and incrementally update the dataflow graph. --- Details: Retracting the given constraint may allow some currently --- unsatisfiable downstream constraint to be satisfied. We therefore collect --- a list of unsatisfied downstream constraints and attempt to --- satisfy each one in turn. This list is traversed by constraint --- strength, strongest first, as a heuristic for avoiding --- unnecessarily adding and then overriding weak constraints. --- Assume: c is satisfied. --- -function Planner:incrementalRemove (c) - local out = c:output() - c:markUnsatisfied() - c:removeFromGraph() - local unsatisfied = self:removePropagateFrom(out) - local strength = Strength.REQUIRED - repeat - for i = 1, unsatisfied:size() do - local u = unsatisfied:at(i) - if u.strength == strength then - self:incrementalAdd(u) - end - end - strength = strength:nextWeaker() - until strength == Strength.WEAKEST -end - --- --- Select a previously unused mark value. --- -function Planner:newMark () - self.currentMark = self.currentMark + 1 - return self.currentMark -end - --- --- Extract a plan for resatisfaction starting from the given source --- constraints, usually a set of input constraints. This method --- assumes that stay optimization is desired; the plan will contain --- only constraints whose output variables are not stay. Constraints --- that do no computation, such as stay and edit constraints, are --- not included in the plan. --- Details: The outputs of a constraint are marked when it is added --- to the plan under construction. A constraint may be appended to --- the plan when all its input variables are known. A variable is --- known if either a) the variable is marked (indicating that has --- been computed by a constraint appearing earlier in the plan), b) --- the variable is 'stay' (i.e. it is a constant at plan execution --- time), or c) the variable is not determined by any --- constraint. The last provision is for past states of history --- variables, which are not stay but which are also not computed by --- any constraint. --- Assume: sources are all satisfied. --- -local Plan -- FORWARD DECLARATION -function Planner:makePlan (sources) - local mark = self:newMark() - local plan = Plan.new() - local todo = sources - while todo:size() > 0 do - local c = todo:removeFirst() - if c:output().mark ~= mark and c:inputsKnown(mark) then - plan:addConstraint(c) - c:output().mark = mark - self:addConstraintsConsumingTo(c:output(), todo) - end - end - return plan -end - --- --- Extract a plan for resatisfying starting from the output of the --- given constraints, usually a set of input constraints. --- -function Planner:extractPlanFromConstraints (constraints) - local sources = OrderedCollection.new() - for i = 1, constraints:size() do - local c = constraints:at(i) - if c:isInput() and c:isSatisfied() then - -- not in plan already and eligible for inclusion - sources:add(c) - end - end - return self:makePlan(sources) -end - --- --- Recompute the walkabout strengths and stay flags of all variables --- downstream of the given constraint and recompute the actual --- values of all variables whose stay flag is true. If a cycle is --- detected, remove the given constraint and answer --- false. Otherwise, answer true. --- Details: Cycles are detected when a marked variable is --- encountered downstream of the given constraint. The sender is --- assumed to have marked the inputs of the given constraint with --- the given mark. Thus, encountering a marked node downstream of --- the output constraint means that there is a path from the --- constraint's output to one of its inputs. --- -function Planner:addPropagate (c, mark) - local todo = OrderedCollection.new() - todo:add(c) - while todo:size() > 0 do - local d = todo:removeFirst() - if d:output().mark == mark then - self:incrementalRemove(c) - return false - end - d:recalculate() - self:addConstraintsConsumingTo(d:output(), todo) - end - return true -end - - --- --- Update the walkabout strengths and stay flags of all variables --- downstream of the given constraint. Answer a collection of --- unsatisfied constraints sorted in order of decreasing strength. --- -function Planner:removePropagateFrom (out) - out.determinedBy = nil - out.walkStrength = Strength.WEAKEST - out.stay = true - local unsatisfied = OrderedCollection.new() - local todo = OrderedCollection.new() - todo:add(out) - while todo:size() > 0 do - local v = todo:removeFirst() - for i = 1, v.constraints:size() do - local c = v.constraints:at(i) - if not c:isSatisfied() then unsatisfied:add(c) end - end - local determining = v.determinedBy - for i = 1, v.constraints:size() do - local next = v.constraints:at(i); - if next ~= determining and next:isSatisfied() then - next:recalculate() - todo:add(next:output()) - end - end - end - return unsatisfied -end - -function Planner:addConstraintsConsumingTo (v, coll) - local determining = v.determinedBy - local cc = v.constraints - for i = 1, cc:size() do - local c = cc:at(i) - if c ~= determining and c:isSatisfied() then - coll:add(c) - end - end -end - --- --- P l a n --- - --- --- A Plan is an ordered list of constraints to be executed in sequence --- to resatisfy all currently satisfiable constraints in the face of --- one or more changing inputs. --- -Plan = class() -function Plan:constructor() - self.v = OrderedCollection.new() -end - -function Plan:addConstraint (c) - self.v:add(c) -end - -function Plan:size () - return self.v:size() -end - -function Plan:constraintAt (index) - return self.v:at(index) -end - -function Plan:execute () - for i = 1, self:size() do - local c = self:constraintAt(i) - c:execute() - end -end - --- --- M a i n --- - --- --- This is the standard DeltaBlue benchmark. A long chain of equality --- constraints is constructed with a stay constraint on one end. An --- edit constraint is then added to the opposite end and the time is --- measured for adding and removing this constraint, and extracting --- and executing a constraint satisfaction plan. There are two cases. --- In case 1, the added constraint is stronger than the stay --- constraint and values must propagate down the entire length of the --- chain. In case 2, the added constraint is weaker than the stay --- constraint so it cannot be accommodated. The cost in this case is, --- of course, very low. Typical situations lie somewhere between these --- two extremes. --- -local function chainTest(n) - planner = Planner.new() - local prev = nil - local first = nil - local last = nil - - -- Build chain of n equality constraints - for i = 0, n do - local name = "v" .. i; - local v = Variable.new(name) - if prev ~= nil then EqualityConstraint.new(prev, v, Strength.REQUIRED) end - if i == 0 then first = v end - if i == n then last = v end - prev = v - end - - StayConstraint.new(last, Strength.STRONG_DEFAULT) - local edit = EditConstraint.new(first, Strength.PREFERRED) - local edits = OrderedCollection.new() - edits:add(edit) - local plan = planner:extractPlanFromConstraints(edits) - for i = 0, 99 do - first.value = i - plan:execute() - if last.value ~= i then - alert("Chain test failed.") - end - end -end - -local function change(v, newValue) - local edit = EditConstraint.new(v, Strength.PREFERRED) - local edits = OrderedCollection.new() - edits:add(edit) - local plan = planner:extractPlanFromConstraints(edits) - for i = 1, 10 do - v.value = newValue - plan:execute() - end - edit:destroyConstraint() -end - --- --- This test constructs a two sets of variables related to each --- other by a simple linear transformation (scale and offset). The --- time is measured to change a variable on either side of the --- mapping and to change the scale and offset factors. --- -local function projectionTest(n) - planner = Planner.new(); - local scale = Variable.new("scale", 10); - local offset = Variable.new("offset", 1000); - local src = nil - local dst = nil; - - local dests = OrderedCollection.new(); - for i = 0, n - 1 do - src = Variable.new("src" .. i, i); - dst = Variable.new("dst" .. i, i); - dests:add(dst); - StayConstraint.new(src, Strength.NORMAL); - ScaleConstraint.new(src, scale, offset, dst, Strength.REQUIRED); - end - - change(src, 17) - if dst.value ~= 1170 then alert("Projection 1 failed") end - change(dst, 1050) - if src.value ~= 5 then alert("Projection 2 failed") end - change(scale, 5) - for i = 0, n - 2 do - if dests:at(i + 1).value ~= i * 5 + 1000 then - alert("Projection 3 failed") - end - end - change(offset, 2000) - for i = 0, n - 2 do - if dests:at(i + 1).value ~= i * 5 + 2000 then - alert("Projection 4 failed") - end - end -end - -function test() - local t0 = os.clock() - chainTest(1000); - projectionTest(1000); - local t1 = os.clock() - return t1-t0 -end - -bench.runCode(test, "deltablue") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 9cd642c..07910a0 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -23,19 +23,17 @@ static std::optional nullCallback(std::string tag, std::op return std::nullopt; } -struct ACFixture : Fixture +template +struct ACFixtureImpl : BaseType { AutocompleteResult autocomplete(unsigned row, unsigned column) { - return Luau::autocomplete(frontend, "MainModule", Position{row, column}, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker) { - auto i = markerPosition.find(marker); - LUAU_ASSERT(i != markerPosition.end()); - const Position& pos = i->second; - return Luau::autocomplete(frontend, "MainModule", pos, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), nullCallback); } CheckResult check(const std::string& source) @@ -45,16 +43,18 @@ struct ACFixture : Fixture filteredSource.reserve(source.size()); Position curPos(0, 0); + char prevChar{}; for (char c : source) { - if (c == '@' && !filteredSource.empty()) + if (prevChar == '@') { - char prevChar = filteredSource.back(); - filteredSource.pop_back(); - curPos.column--; // Adjust column position since we removed a character from the output - LUAU_ASSERT("Illegal marker character" && prevChar >= '0' && prevChar <= '9'); - LUAU_ASSERT("Duplicate marker found" && markerPosition.count(prevChar) == 0); - markerPosition.insert(std::pair{prevChar, curPos}); + LUAU_ASSERT("Illegal marker character" && c >= '0' && c <= '9'); + LUAU_ASSERT("Duplicate marker found" && markerPosition.count(c) == 0); + markerPosition.insert(std::pair{c, curPos}); + } + else if (c == '@') + { + // skip the '@' character } else { @@ -69,22 +69,39 @@ struct ACFixture : Fixture curPos.column++; } } + prevChar = c; } + LUAU_ASSERT("Digit expected after @ symbol" && prevChar != '@'); return Fixture::check(filteredSource); } + const Position& getPosition(char marker) const + { + auto i = markerPosition.find(marker); + LUAU_ASSERT(i != markerPosition.end()); + return i->second; + } + // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; +struct ACFixture : ACFixtureImpl +{ +}; + +struct UnfrozenACFixture : ACFixtureImpl +{ +}; + TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") { - check(" "); + check(" @1"); - auto ac = autocomplete(0, 1); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("table")); @@ -93,26 +110,26 @@ TEST_CASE_FIXTURE(ACFixture, "empty_program") TEST_CASE_FIXTURE(ACFixture, "local_initializer") { - check("local a = "); + check("local a = @1"); - auto ac = autocomplete(0, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); CHECK(ac.entryMap.count("math")); } TEST_CASE_FIXTURE(ACFixture, "leave_numbers_alone") { - check("local a = 3.1"); + check("local a = 3.@11"); - auto ac = autocomplete(0, 12); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "user_defined_globals") { - check("local myLocal = 4; "); + check("local myLocal = 4; @1"); - auto ac = autocomplete(0, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("table")); @@ -124,20 +141,20 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") check(R"( local myLocal = 4 function abc() - local myInnerLocal = 1 - +@1 local myInnerLocal = 1 +@2 end - )"); +@3 )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); - ac = autocomplete(4, 0); + ac = autocomplete('2'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("myInnerLocal")); - ac = autocomplete(6, 0); + ac = autocomplete('3'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); } @@ -146,10 +163,10 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function") { check(R"( function foo() - end +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("foo")); } @@ -158,11 +175,11 @@ TEST_CASE_FIXTURE(ACFixture, "nested_recursive_function") check(R"( local function outer() local function inner() - end +@1 end end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("inner")); CHECK(ac.entryMap.count("outer")); } @@ -171,11 +188,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") { check(R"( local function abc() - +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); CHECK(ac.entryMap.count("table")); @@ -183,11 +200,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") check(R"( local abc = function() - +@1 end )"); - ac = autocomplete(2, 0); + ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); // FIXME: This is actually incorrect! CHECK(ac.entryMap.count("table")); @@ -202,9 +219,9 @@ TEST_CASE_FIXTURE(ACFixture, "global_functions_are_not_scoped_lexically") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("abc")); @@ -220,9 +237,9 @@ TEST_CASE_FIXTURE(ACFixture, "local_functions_fall_out_of_scope") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(!ac.entryMap.count("abc")); @@ -233,10 +250,10 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") check(R"( function abc(test) - end +@1 end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("test")); } @@ -244,11 +261,10 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = table. -- Line 1 - -- | Column 23 + local a = table.@1 )"); - auto ac = autocomplete(1, 24); + auto ac = autocomplete('1'); CHECK_EQ(16, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); @@ -260,10 +276,10 @@ TEST_CASE_FIXTURE(ACFixture, "nested_member_completions") { check(R"( local tbl = { abc = { def = 1234, egh = false } } - tbl.abc. + tbl.abc. @1 )"); - auto ac = autocomplete(2, 17); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("def")); CHECK(ac.entryMap.count("egh")); @@ -274,10 +290,10 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table") check(R"( local tbl = {} tbl.prop = 5 - tbl. + tbl.@1 )"); - auto ac = autocomplete(3, 12); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); } @@ -288,10 +304,10 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table_2") local tbl = {} local inner = { prop = 5 } tbl.inner = inner - tbl.inner. + tbl.inner. @1 )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); } @@ -302,10 +318,10 @@ TEST_CASE_FIXTURE(ACFixture, "cyclic_table") local abc = {} local def = { abc = abc } abc.def = def - abc.def. + abc.def. @1 )"); - auto ac = autocomplete(4, 17); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); } @@ -315,11 +331,11 @@ TEST_CASE_FIXTURE(ACFixture, "table_union") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 | t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("b2")); } @@ -330,11 +346,11 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 & t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(3, ac.entryMap.size()); CHECK(ac.entryMap.count("a1")); CHECK(ac.entryMap.count("b2")); @@ -344,20 +360,19 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") TEST_CASE_FIXTURE(ACFixture, "get_string_completions") { check(R"( - local a = ("foo"): -- Line 1 - -- | Column 26 + local a = ("foo"):@1 )"); - auto ac = autocomplete(1, 26); + auto ac = autocomplete('1'); CHECK_EQ(17, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") { - check(""); + check("@1"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -366,12 +381,12 @@ TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_the_very_start_of_the_script") { - check(R"( + check(R"(@1 function aaa() end )"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } @@ -382,11 +397,11 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") local game = { GetService=function(s) return 'hello' end } function a() - game: + game: @1 end )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -396,10 +411,10 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") { check(R"( - if table: + if table: @1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("concat")); @@ -411,12 +426,12 @@ TEST_CASE_FIXTURE(ACFixture, "statement_between_two_statements") check(R"( function getmyscripts() end - g + g@1 getmyscripts() )"); - auto ac = autocomplete(3, 9); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -431,11 +446,11 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") function B() local A = {two=2} - A + A @1 end )"); - auto ac = autocomplete(6, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("A")); @@ -448,12 +463,12 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") TEST_CASE_FIXTURE(ACFixture, "recommend_statement_starting_keywords") { - check(""); - auto ac = autocomplete(0, 0); + check("@1"); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("local")); - check("local i = "); - auto ac2 = autocomplete(0, 10); + check("local i = @1"); + auto ac2 = autocomplete('1'); CHECK(!ac2.entryMap.count("local")); } @@ -464,9 +479,9 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_overwrite_context_sensitive_kws") end - )"); +@1 )"); - auto ac = autocomplete(5, 0); + auto ac = autocomplete('1'); AutocompleteEntry entry = ac.entryMap["continue"]; CHECK(entry.kind == AutocompleteEntryKind::Binding); @@ -480,11 +495,11 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") function foo:bar() end --[[ - foo: + foo:@1 ]] )"); - auto ac = autocomplete(6, 16); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -492,10 +507,10 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comment") { check(R"( - --!strict + --!strict@1 )"); - auto ac = autocomplete(1, 17); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -505,10 +520,10 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; check(R"( - --[[ + --[[ @1 )"); - auto ac = autocomplete(1, 13); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -517,129 +532,129 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co { ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check("--[["); + check("--[[@1"); - auto ac = autocomplete(0, 4); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") { check(R"( - for x = + for x @1= )"); - auto ac1 = autocomplete(1, 14); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - for x = 1 + for x =@1 1 )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("do"), 0); CHECK_EQ(ac2.entryMap.count("end"), 0); check(R"( - for x = 1, 2 + for x = 1,@1 2 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("do"), 1); check(R"( - for x = 1, 2, + for x = 1, @12, )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("do"), 0); CHECK_EQ(ac4.entryMap.count("end"), 0); check(R"( - for x = 1, 2, 5 + for x = 1, 2, @15 )"); - auto ac5 = autocomplete(1, 22); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("do"), 1); CHECK_EQ(ac5.entryMap.count("end"), 0); check(R"( - for x = 1, 2, 5 f + for x = 1, 2, 5 f@1 )"); - auto ac6 = autocomplete(1, 25); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.size(), 1); CHECK_EQ(ac6.entryMap.count("do"), 1); check(R"( - for x = 1, 2, 5 do + for x = 1, 2, 5 do @1 )"); - auto ac7 = autocomplete(1, 32); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("end"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") { check(R"( - for + for @1 )"); - auto ac1 = autocomplete(1, 12); + auto ac1 = autocomplete('1'); CHECK_EQ(0, ac1.entryMap.size()); check(R"( - for x + for x@1 @2 )"); - auto ac2 = autocomplete(1, 13); + auto ac2 = autocomplete('1'); CHECK_EQ(0, ac2.entryMap.size()); - auto ac2a = autocomplete(1, 14); + auto ac2a = autocomplete('2'); CHECK_EQ(1, ac2a.entryMap.size()); CHECK_EQ(1, ac2a.entryMap.count("in")); check(R"( - for x in y + for x in y@1 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("table"), 1); CHECK_EQ(ac3.entryMap.count("do"), 0); check(R"( - for x in y + for x in y @1 )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.size(), 1); CHECK_EQ(ac4.entryMap.count("do"), 1); check(R"( - for x in f f + for x in f f@1 )"); - auto ac5 = autocomplete(1, 20); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.size(), 1); CHECK_EQ(ac5.entryMap.count("do"), 1); check(R"( - for x in y do + for x in y do @1 )"); - auto ac6 = autocomplete(1, 23); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.count("in"), 0); CHECK_EQ(ac6.entryMap.count("table"), 1); CHECK_EQ(ac6.entryMap.count("end"), 1); CHECK_EQ(ac6.entryMap.count("function"), 1); check(R"( - for x in y do e + for x in y do e@1 )"); - auto ac7 = autocomplete(1, 23); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("in"), 0); CHECK_EQ(ac7.entryMap.count("table"), 1); CHECK_EQ(ac7.entryMap.count("end"), 1); @@ -649,33 +664,33 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") { check(R"( - while + while@1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - while true + while true @1 )"); - auto ac2 = autocomplete(1, 19); + auto ac2 = autocomplete('1'); CHECK_EQ(1, ac2.entryMap.size()); CHECK_EQ(ac2.entryMap.count("do"), 1); check(R"( - while true do + while true do @1 )"); - auto ac3 = autocomplete(1, 23); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("end"), 1); check(R"( - while true d + while true d@1 )"); - auto ac4 = autocomplete(1, 20); + auto ac4 = autocomplete('1'); CHECK_EQ(1, ac4.entryMap.size()); CHECK_EQ(ac4.entryMap.count("do"), 1); } @@ -683,10 +698,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") { check(R"( - if + if @1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("then"), 0); CHECK_EQ(ac1.entryMap.count("function"), 1); // FIXME: This is kind of dumb. It is technically syntactically valid but you can never do anything interesting with this. @@ -696,10 +711,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - if x + if x @1 )"); - auto ac2 = autocomplete(1, 14); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("then"), 1); CHECK_EQ(ac2.entryMap.count("function"), 0); CHECK_EQ(ac2.entryMap.count("else"), 0); @@ -707,20 +722,20 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac2.entryMap.count("end"), 0); check(R"( - if x t + if x t@1 )"); - auto ac3 = autocomplete(1, 14); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("then"), 1); check(R"( if x then - +@1 end )"); - auto ac4 = autocomplete(2, 0); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("then"), 0); CHECK_EQ(ac4.entryMap.count("else"), 1); CHECK_EQ(ac4.entryMap.count("function"), 1); @@ -729,11 +744,11 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") check(R"( if x then - t + t@1 end )"); - auto ac4a = autocomplete(2, 13); + auto ac4a = autocomplete('1'); CHECK_EQ(ac4a.entryMap.count("then"), 0); CHECK_EQ(ac4a.entryMap.count("table"), 1); CHECK_EQ(ac4a.entryMap.count("else"), 1); @@ -741,12 +756,12 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") check(R"( if x then - +@1 elseif x then end )"); - auto ac5 = autocomplete(2, 0); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("then"), 0); CHECK_EQ(ac5.entryMap.count("function"), 1); CHECK_EQ(ac5.entryMap.count("else"), 0); @@ -757,10 +772,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_in_repeat") { check(R"( - repeat + repeat @1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("until"), 1); } @@ -769,48 +784,48 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_expression") { check(R"( repeat - until + until @1 )"); - auto ac = autocomplete(2, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); } TEST_CASE_FIXTURE(ACFixture, "local_names") { check(R"( - local ab + local ab@1 )"); - auto ac1 = autocomplete(1, 16); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local ab, cd + local ab, cd@1 )"); - auto ac2 = autocomplete(1, 20); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_fn_exprs") { check(R"( - local function f() + local function f() @1 )"); - auto ac = autocomplete(1, 28); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_lambda") { check(R"( - local a = function() local bar = foo en + local a = function() local bar = foo en@1 )"); - auto ac = autocomplete(1, 47); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); } @@ -818,10 +833,10 @@ TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") { check(R"( repeat - for x + for x @1 )"); - auto ac1 = autocomplete(2, 18); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("in"), 1); CHECK_EQ(ac1.entryMap.count("until"), 0); } @@ -829,112 +844,112 @@ TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_repeat_middle_keyword") { check(R"( - repeat + repeat @1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); CHECK_EQ(ac1.entryMap.count("until"), 1); check(R"( - repeat f f + repeat f f@1 )"); - auto ac2 = autocomplete(1, 18); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("function"), 1); CHECK_EQ(ac2.entryMap.count("until"), 1); check(R"( repeat - u + u@1 until )"); - auto ac3 = autocomplete(2, 13); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("until"), 0); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local f + local f@1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local f, cd + local f@1, cd )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local function + local function @1 )"); - auto ac = autocomplete(1, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( - local function s + local function @1s@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 24); + ac = autocomplete('2'); CHECK(ac.entryMap.empty()); check(R"( - local function () + local function @1()@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 25); + ac = autocomplete('2'); CHECK(ac.entryMap.count("end")); check(R"( - local function something + local function something@1 )"); - ac = autocomplete(1, 32); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( local tbl = {} - function tbl.something() end + function tbl.something@1() end )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function_params") { check(R"( - local function abc(def) + local function @1a@2bc(@3d@4ef)@5 @6 )"); - CHECK(autocomplete(1, 23).entryMap.empty()); - CHECK(autocomplete(1, 24).entryMap.empty()); - CHECK(autocomplete(1, 27).entryMap.empty()); - CHECK(autocomplete(1, 28).entryMap.empty()); - CHECK(!autocomplete(1, 31).entryMap.empty()); + CHECK(autocomplete('1').entryMap.empty()); + CHECK(autocomplete('2').entryMap.empty()); + CHECK(autocomplete('3').entryMap.empty()); + CHECK(autocomplete('4').entryMap.empty()); + CHECK(!autocomplete('5').entryMap.empty()); - CHECK(!autocomplete(1, 32).entryMap.empty()); + CHECK(!autocomplete('6').entryMap.empty()); check(R"( local function abc(def) - end +@1 end )"); for (unsigned int i = 23; i < 31; ++i) @@ -943,16 +958,16 @@ TEST_CASE_FIXTURE(ACFixture, "local_function_params") } CHECK(!autocomplete(1, 32).entryMap.empty()); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); check(R"( - local function abc(def, ghi) + local function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 35); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); } @@ -981,48 +996,48 @@ TEST_CASE_FIXTURE(ACFixture, "global_function_params") check(R"( function abc(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); check(R"( - function abc(def, ghi) + function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 29); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "arguments_to_global_lambda") { check(R"( - abc = function(def, ghi) + abc = function(def, ghi@1) end )"); - auto ac = autocomplete(1, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { check(R"( - abc = function(def) + abc = function(def) @1 )"); for (unsigned int i = 20; i < 27; ++i) { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( - abc = function(def) + abc = function(def) @1 end )"); @@ -1030,25 +1045,25 @@ TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( abc = function(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("def"), 1); } TEST_CASE_FIXTURE(ACFixture, "local_initializer") { check(R"( - local a = t + local a = t@1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("true"), 1); } @@ -1056,20 +1071,20 @@ TEST_CASE_FIXTURE(ACFixture, "local_initializer") TEST_CASE_FIXTURE(ACFixture, "local_initializer_2") { check(R"( - local a= + local a=@1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = 12.3 + local a = 12.@13 )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -1083,21 +1098,21 @@ TEST_CASE_FIXTURE(ACFixture, "sometimes_the_metatable_is_an_error") return setmetatable({x=6}, X) -- oops! end local t = T.new() - t. + t. @1 )"); - autocomplete(8, 12); + autocomplete('1'); // Don't crash! } TEST_CASE_FIXTURE(ACFixture, "local_types_builtin") { check(R"( -local a: n +local a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1108,23 +1123,23 @@ TEST_CASE_FIXTURE(ACFixture, "private_types") check(R"( do type num = number - local a: nu - local b: num + local a: n@1u + local b: nu@2m end -local a: nu +local a: nu@3 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 15); + ac = autocomplete('2'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(6, 11); + ac = autocomplete('3'); CHECK(!ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); @@ -1136,11 +1151,11 @@ TEST_CASE_FIXTURE(ACFixture, "type_scoping_easy") type Table = { a: number, b: number } do type Table = { x: string, y: string } - local a: T + local a: T@1 end )"); - auto ac = autocomplete(4, 14); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("Table")); REQUIRE(ac.entryMap["Table"].type); @@ -1198,11 +1213,11 @@ local a: aaa. TEST_CASE_FIXTURE(ACFixture, "argument_types") { check(R"( -local function f(a: n +local function f(a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1211,11 +1226,11 @@ local b: string = "don't trip" TEST_CASE_FIXTURE(ACFixture, "return_types") { check(R"( -local function f(a: number): n +local function f(a: number): n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1225,10 +1240,10 @@ TEST_CASE_FIXTURE(ACFixture, "as_types") { check(R"( local a: any = 5 -local b: number = (a :: n +local b: number = (a :: n@1 )"); - auto ac = autocomplete(2, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1237,34 +1252,34 @@ local b: number = (a :: n TEST_CASE_FIXTURE(ACFixture, "function_type_types") { check(R"( -local a: (n -local b: (number, (n -local c: (number, (number) -> n -local d: (number, (number) -> (number, n -local e: (n: n +local a: (n@1 +local b: (number, (n@2 +local c: (number, (number) -> n@3 +local d: (number, (number) -> (number, n@4 +local e: (n: n@5 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(2, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(3, 31); + ac = autocomplete('3'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 40); + ac = autocomplete('4'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(5, 14); + ac = autocomplete('5'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1276,11 +1291,11 @@ TEST_CASE_FIXTURE(ACFixture, "generic_types") ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); check(R"( -function f(a: T +function f(a: T@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("Tee")); } @@ -1293,10 +1308,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(o +return target(o@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1307,10 +1322,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(one, t +return target(one, t@1 )"); - ac = autocomplete(5, 20); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1321,10 +1336,10 @@ return target(one, t local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1334,10 +1349,10 @@ return target(a. local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a.one, a. +return target(a.one, a.@1 )"); - ac = autocomplete(4, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1348,10 +1363,10 @@ return target(a.one, a. local function target(a: string?) return #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1363,10 +1378,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_table") check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { a = a. +local b: Foo = { a = a.@1 )"); - auto ac = autocomplete(3, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1375,10 +1390,10 @@ local b: Foo = { a = a. check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { b = a. +local b: Foo = { b = a.@1 )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1392,10 +1407,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1406,10 +1421,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end -return target(bar1, b +return target(bar1, b@1 )"); - ac = autocomplete(5, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar2")); CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1420,10 +1435,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number): (...number) return -a, a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1433,69 +1448,69 @@ return target(b TEST_CASE_FIXTURE(ACFixture, "type_correct_local_type_suggestion") { check(R"( -local b: s = "str" +local b: s@1 = "str" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return "str" end -local b: s = f() +local b: s@1 = f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: s, c: n = "str", 2 +local b: s@1, c: n@2 = "str", 2 )"); - ac = autocomplete(1, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return 1, "str", 3 end -local a: b, b: n, c: s, d: n = false, f() +local a: b@1, b: n@2, c: s@3, d: n@4 = false, f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 22); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 28); + ac = autocomplete('4'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f(): ...number return 1, 2, 3 end -local a: boolean, b: n = false, f() +local a: boolean, b: n@1 = false, f() )"); - ac = autocomplete(2, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1504,46 +1519,46 @@ local a: boolean, b: n = false, f() TEST_CASE_FIXTURE(ACFixture, "type_correct_function_type_suggestion") { check(R"( -local b: (n) -> number = function(a: number, b: string) return a + #b end +local b: (n@1) -> number = function(a: number, b: string) return a + #b end )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, s = function(a: number, b: string) return a + #b end +local b: (number, s@1 = function(a: number, b: string) return a + #b end )"); - ac = autocomplete(1, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, string) -> b = function(a: number, b: string): boolean return a + #b == 0 end +local b: (number, string) -> b@1 = function(a: number, b: string): boolean return a + #b == 0 end )"); - ac = autocomplete(1, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, ...s) = function(a: number, ...: string) return a end +local b: (number, ...s@1) = function(a: number, ...: string) return a end )"); - ac = autocomplete(1, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" end +local b: (number) -> ...s@1 = function(a: number): ...string return "a", "b", "c" end )"); - ac = autocomplete(1, 25); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1552,24 +1567,24 @@ local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" TEST_CASE_FIXTURE(ACFixture, "type_correct_full_type_suggestion") { check(R"( -local b: = "str" +local b:@1 @2= "str" )"); - auto ac = autocomplete(1, 8); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 9); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: = function(a: number) return -a end +local b: @1= function(a: number) return -a end )"); - ac = autocomplete(1, 9); + ac = autocomplete('1'); CHECK(ac.entryMap.count("(number) -> number")); CHECK(ac.entryMap["(number) -> number"].typeCorrect == TypeCorrectKind::Correct); @@ -1580,12 +1595,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: n, b) +local function d(a: n@1, b) return target(a, b) end )"); - auto ac = autocomplete(3, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1593,12 +1608,12 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: s) +local function d(a, b: s@1) return target(a, b) end )"); - ac = autocomplete(3, 24); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1606,17 +1621,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: , b) +local function d(a:@1 @2, b) return target(a, b) end )"); - ac = autocomplete(3, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1624,17 +1639,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: ): number +local function d(a, b: @1)@2: number return target(a, b) end )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 24); + ac = autocomplete('2'); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::None); } @@ -1644,10 +1659,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion") check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1655,10 +1670,10 @@ local x = target(function(a: check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n +local x = target(function(a: n@1 )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1666,17 +1681,17 @@ local x = target(function(a: n check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n, b: ) +local x = target(function(a: n@1, b: @2) return a + #b end) )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1684,12 +1699,12 @@ end) check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a: n) +local x = target(function(a: n@1) return a end )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1700,12 +1715,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(...:n) +local x = target(function(...:n@1) return a end )"); - auto ac = autocomplete(3, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1713,12 +1728,12 @@ end check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a:number, b:number, ...:) +local x = target(function(a:number, b:number, ...:@1) return a + b end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1729,12 +1744,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") check(R"( local function target(callback: () -> number) return callback() end -local x = target(function(): n +local x = target(function(): n@1 return 1 end )"); - auto ac = autocomplete(3, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1742,12 +1757,12 @@ end check(R"( local function target(callback: () -> (number, number)) return callback() end -local x = target(function(): (number, n +local x = target(function(): (number, n@1 return 1, 2 end )"); - ac = autocomplete(3, 39); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1758,12 +1773,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): ...n +local x = target(function(): ...n@1 return 1, 2, 3 end )"); - auto ac = autocomplete(3, 33); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1771,12 +1786,12 @@ end check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): (number, number, ...n +local x = target(function(): (number, number, ...n@1 return 1, 2, 3 end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1787,10 +1802,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion_opt check(R"( local function target(callback: nil | (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1803,21 +1818,21 @@ local t = {} t.x = 5 function t:target(callback: (a: number, b: string) -> number) return callback(self.x, "hello") end -local x = t:target(function(a: , b: ) end) -local y = t.target(t, function(a: number, b: ) end) +local x = t:target(function(a: @1, b:@2 ) end) +local y = t.target(t, function(a: number, b: @3) end) )"); - auto ac = autocomplete(5, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(5, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(6, 45); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1899,26 +1914,26 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_suggest_synthetic_table_name") { check(R"( local foo = { a = 1, b = 2 } -local bar: = foo +local bar: @1= foo )"); - auto ac = autocomplete(2, 11); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_function_no_parenthesis") +// CLI-45692: Remove UnfrozenACFixture here +TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") { check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(frontend, "MainModule", Position{5, 15}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1930,16 +1945,16 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") { check(R"( local function f(a: { x: number, y: number }) return a.x + a.y end -local fp: = f +local fp: @1= f )"); - auto ac = autocomplete(2, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_keywords") +// CLI-45692: Remove UnfrozenACFixture here +TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") { check(R"( local function a(x: boolean) end @@ -1951,33 +1966,33 @@ local function e(x: ((number) -> string) & ((boolean) -> number)) end local tru = {} local ni = false -local ac = a(t) -local bc = b(n) -local cc = c(f) -local dc = d(f) -local ec = e(f) +local ac = a(t@1) +local bc = b(n@2) +local cc = c(f@3) +local dc = d(f@4) +local ec = e(f@5) )"); - auto ac = autocomplete(frontend, "MainModule", Position{10, 14}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{11, 14}, nullCallback); + ac = autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{12, 14}, nullCallback); + ac = autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{13, 14}, nullCallback); + ac = autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{14, 14}, nullCallback); + ac = autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -1988,10 +2003,10 @@ local target: ((number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2002,10 +2017,10 @@ local target: ((number) -> string) & ((number) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2016,10 +2031,10 @@ local target: ((number, number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(1, o) +return target(1, o@1) )"); - ac = autocomplete(5, 18); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2032,10 +2047,10 @@ TEST_CASE_FIXTURE(ACFixture, "optional_members") local a = { x = 2, y = 3 } type A = typeof(a) local b: A? = a -return b. +return b.@1 )"); - auto ac = autocomplete(4, 9); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2045,10 +2060,10 @@ return b. local a = { x = 2, y = 3 } type A = typeof(a) local b: nil | A = a -return b. +return b.@1 )"); - ac = autocomplete(4, 9); + ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2056,10 +2071,10 @@ return b. check(R"( local b: nil | nil -return b. +return b.@1 )"); - ac = autocomplete(2, 9); + ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -2067,26 +2082,26 @@ return b. TEST_CASE_FIXTURE(ACFixture, "no_function_name_suggestions") { check(R"( -function na +function na@1 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function +local function @1 )"); - ac = autocomplete(1, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function na +local function na@1 )"); - ac = autocomplete(1, 17); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -2095,20 +2110,20 @@ TEST_CASE_FIXTURE(ACFixture, "skip_current_local") { check(R"( local other = 1 -local name = na +local name = na@1 )"); - auto ac = autocomplete(2, 15); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(ac.entryMap.count("other")); check(R"( local other = 1 -local name, test = na +local name, test = na@1 )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(!ac.entryMap.count("test")); @@ -2119,26 +2134,26 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_members") { check(R"( local a = { done = 1, forever = 2 } -local b = a.do -local c = a.for -local d = a. +local b = a.do@1 +local c = a.for@2 +local d = a.@3 do end )"); - auto ac = autocomplete(2, 14); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(3, 15); + ac = autocomplete('2'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(4, 12); + ac = autocomplete('3'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2150,10 +2165,10 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_methods") check(R"( local a = {} function a:done() end -local b = a:do +local b = a:do@1 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2247,29 +2262,29 @@ local elsewhere = false local doover = false local endurance = true -if 1 then -else +if 1 then@1 +else@2 end -while false do +while false do@3 end -repeat +repeat@4 until )"); - auto ac = autocomplete(6, 9); + auto ac = autocomplete('1'); CHECK(ac.entryMap.size() == 1); CHECK(ac.entryMap.count("then")); - ac = autocomplete(7, 4); + ac = autocomplete('2'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); - ac = autocomplete(10, 14); + ac = autocomplete('3'); CHECK(ac.entryMap.count("do")); - ac = autocomplete(13, 6); + ac = autocomplete('4'); CHECK(ac.entryMap.count("do")); // FIXME: ideally we want to handle start and end of all statements as well @@ -2284,11 +2299,11 @@ local elsewhere = false if true then return 1 -el +el@1 end )"); - auto ac = autocomplete(5, 2); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere") == 0); @@ -2300,11 +2315,11 @@ if true then return 1 else return 2 -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); CHECK(ac.entryMap.count("elsewhere")); @@ -2316,10 +2331,10 @@ if true then print("1") elif true then print("2") -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere")); @@ -2360,30 +2375,30 @@ TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") { check(R"( type Test = { first: number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - auto ac = autocomplete(2, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Intersection check(R"( type Test = { first: number } & { second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Union check(R"( type Test = { first: number, second: number } | { second: number, third: number } -local t: Test = { s } +local t: Test = { s@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("second")); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("third")); @@ -2391,60 +2406,60 @@ local t: Test = { s } // No parenthesis suggestion check(R"( type Test = { first: (number) -> number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap["first"].parens == ParenthesesRecommendation::None); // When key is changed check(R"( type Test = { first: number, second: number } -local t: Test = { f = 2 } +local t: Test = { f@1 = 2 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { ["f"] } +local t: Test = { ["f@1"] } )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Not an alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { "f" } +local t: Test = { "f@1" } )"); - ac = autocomplete(2, 20); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("second")); // Skip keys that are already defined check(R"( type Test = { first: number, second: number } -local t: Test = { first = 2, s } +local t: Test = { first = 2, s@1 } )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Don't skip active key check(R"( type Test = { first: number, second: number } -local t: Test = { first } +local t: Test = { first@1 } )"); - ac = autocomplete(2, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); @@ -2452,22 +2467,22 @@ local t: Test = { first } check(R"( local t = { { first = 5, second = 10 }, - { f } + { f@1 } } )"); - ac = autocomplete(3, 7); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); check(R"( local t = { [2] = { first = 5, second = 10 }, - [5] = { f } + [5] = { f@1 } } )"); - ac = autocomplete(3, 13); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); } @@ -2502,15 +2517,15 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") local temp = false local even = true; local a = true -a = if t1@emp then t -a = if temp t2@ -a = if temp then e3@ -a = if temp then even e4@ -a = if temp then even elseif t5@ -a = if temp then even elseif true t6@ -a = if temp then even elseif true then t7@ -a = if temp then even elseif true then temp e8@ -a = if temp then even elseif true then temp else e9@ +a = if t@1emp then t +a = if temp t@2 +a = if temp then e@3 +a = if temp then even e@4 +a = if temp then even elseif t@5 +a = if temp then even elseif true t@6 +a = if temp then even elseif true then t@7 +a = if temp then even elseif true then temp e@8 +a = if temp then even elseif true then temp else e@9 )"); auto ac = autocomplete('1'); @@ -2573,4 +2588,20 @@ a = if temp then even elseif true then temp else e9@ } } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") +{ + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + check(R"( +type A = () -> T... +local a: A<(number, s@1> + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 26bc77f..29c33f7 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -32,6 +32,55 @@ std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const return std::nullopt; } +std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) +{ + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "game") + return ModuleInfo{"game"}; + if (g->name == "workspace") + return ModuleInfo{"workspace"}; + if (g->name == "script") + return context ? std::optional(*context) : std::nullopt; + } + else if (AstExprIndexName* i = expr->as(); i && context) + { + if (i->index == "Parent") + { + std::string_view view = context->name; + size_t lastSeparatorIndex = view.find_last_of('/'); + + if (lastSeparatorIndex == std::string_view::npos) + return std::nullopt; + + return ModuleInfo{ModuleName(view.substr(0, lastSeparatorIndex)), context->optional}; + } + else + { + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + } + else if (AstExprIndexExpr* i = expr->as(); i && context) + { + if (AstExprConstantString* index = i->index->as()) + { + return ModuleInfo{context->name + '/' + std::string(index->value.data, index->value.size), context->optional}; + } + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } + + return std::nullopt; +} + ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const { return lhs + "/" + ModuleName(rhs); diff --git a/tests/Fixture.h b/tests/Fixture.h index c6294b0..1480a7f 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -65,6 +65,8 @@ struct TestFileResolver } std::optional fromAstFragment(AstExpr* expr) const override; + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; std::optional getParentModuleName(const ModuleName& name) const override; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3f33a5d..fbfec63 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -58,6 +58,35 @@ struct NaiveFileResolver : NullFileResolver return std::nullopt; } + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override + { + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "Modules") + return ModuleInfo{"Modules"}; + + if (g->name == "game") + return ModuleInfo{"game"}; + } + else if (AstExprIndexName* i = expr->as()) + { + if (context) + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } + + return std::nullopt; + } + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override { return lhs + "/" + ModuleName(rhs); @@ -528,7 +557,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "ignore_require_to_nonexistent_file") { fileResolver.source["Modules/A"] = R"( local Modules = script - local B = require(Modules.B :: any) + local B = require(Modules.B) :: any )"; CheckResult result = frontend.check("Modules/A"); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index c8eff39..a9ed139 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1400,6 +1400,8 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { + ScopedFastFlag sff("LuauLinterTableMoveZero", true); + LintResult result = lintTyped(R"( local t = {} local tt = {} @@ -1417,9 +1419,12 @@ table.remove(t, 0) table.remove(t, #t-1) table.insert(t, string.find("hello", "h")) + +table.move(t, 0, #t, 1, tt) +table.move(t, 1, #t, 0, tt) )"); - REQUIRE_EQ(result.warnings.size(), 6); + REQUIRE_EQ(result.warnings.size(), 8); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1429,6 +1434,8 @@ table.insert(t, string.find("hello", "h")) "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[5].text, "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); + CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 1b146ed..18f55d2 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Fixture.h" diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index f3c76d5..931a840 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index cb03a7b..a80718e 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2519,4 +2519,19 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") } } +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + AstStat* stat = parse(R"( +type Packed = () -> T... + +type A = Packed +type B = Packed<...number> +type C = Packed<(number, X...)> + )"); + REQUIRE(stat != nullptr); +} + TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp index cbd4af2..b9fd04d 100644 --- a/tests/RequireTracer.test.cpp +++ b/tests/RequireTracer.test.cpp @@ -57,6 +57,7 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") { AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz + require(m) )"); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -70,22 +71,22 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") AstExprIndexName* value = loc->values.data[0]->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo", result.exprs[value]); + CHECK_EQ("workspace/Foo", result.exprs[value].name); AstExprGlobal* workspace = value->expr->as(); REQUIRE(workspace); REQUIRE(result.exprs.contains(workspace)); - CHECK_EQ("workspace", result.exprs[workspace]); + CHECK_EQ("workspace", result.exprs[workspace].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") @@ -93,9 +94,10 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz local n = m.Quux + require(n) )"); - REQUIRE_EQ(2, block->body.size); + REQUIRE_EQ(3, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -104,13 +106,13 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") REQUIRE_EQ(1, local->vars.size); REQUIRE(result.exprs.contains(local->values.data[0])); - CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]]); + CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") { AstStatBlock* block = parse(R"( - local M = require(workspace.Game.Thing, workspace.Something.Else) + local M = require(workspace.Game.Thing) )"); REQUIRE_EQ(1, block->body.size); @@ -124,52 +126,9 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") AstExprCall* call = local->values.data[0]->as(); REQUIRE(call != nullptr); - REQUIRE_EQ(2, call->args.size); - - CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]]); - CHECK_EQ("workspace/Something/Else", result.exprs[call->args.data[1]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_GetService_calls") -{ - AstStatBlock* block = parse(R"( - local R = game:GetService('ReplicatedStorage').Roact - local Roact = require(R) - )"); - REQUIRE_EQ(2, block->body.size); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - AstStatLocal* local = block->body.data[0]->as(); - REQUIRE(local != nullptr); - - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[local->values.data[0]]); - - AstStatLocal* local2 = block->body.data[1]->as(); - REQUIRE(local2 != nullptr); - REQUIRE_EQ(1, local2->values.size); - - AstExprCall* call = local2->values.data[0]->as(); - REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[call->args.data[0]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_WaitForChild_calls") -{ - ScopedFastFlag luauTraceRequireLookupChild("LuauTraceRequireLookupChild", true); - - AstStatBlock* block = parse(R"( -local A = require(workspace:WaitForChild('ReplicatedStorage').Content) -local B = require(workspace:FindFirstChild('ReplicatedFirst').Data) - )"); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - REQUIRE_EQ(2, result.requires.size()); - CHECK_EQ("workspace/ReplicatedStorage/Content", result.requires[0].first); - CHECK_EQ("workspace/ReplicatedFirst/Data", result.requires[1].first); + CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") @@ -200,22 +159,23 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]]); + CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr") { AstStatBlock* block = parse(R"( local R = game["Test"] + require(R) )"); - REQUIRE_EQ(1, block->body.size); + REQUIRE_EQ(2, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); AstStatLocal* local = block->body.data[0]->as(); REQUIRE(local != nullptr); - CHECK_EQ("game/Test", result.exprs[local->values.data[0]]); + CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name); } TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index d7d68c4..e18bf7c 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Fixture.h" @@ -416,8 +417,6 @@ function foo(a, b) return a(b) end TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack") { - ScopedFastFlag sff{"LuauToStringFollowsBoundTo", true}; - TypeVar tv1{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp new file mode 100644 index 0000000..045f023 --- /dev/null +++ b/tests/TypeInfer.aliases.test.cpp @@ -0,0 +1,557 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeAliases"); + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type F = () -> F? + local function f() + return f + end + + local g: F = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") +{ + CheckResult result = check(R"( + --!strict + type Node = { Parent: Node?; } + local node: Node; + node.Parent = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Node?", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: a, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = "lo", i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: b, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_ERRORS(result); + + // We had a UAF in this example caused by not cloning type function arguments + ModulePtr module = frontend.moduleResolver.getModule("MainModule"); + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + module->internalTypes.clear(); + module->astTypes.clear(); + + // Make sure the error strings don't include "VALUELESS" + for (auto error : module->errors) + CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); +} + +TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") +{ + CheckResult result = check(R"( + type Pair = {first: T, second: U} + local a: Pair + local b: Pair + + a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK_EQ("Pair", toString(tm->wantedType)); + CHECK_EQ("Pair", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") +{ + CheckResult result = check(R"( + type A = number + type A = string -- Redefinition of type 'A', previously defined at line 1 + local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = Table + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Wrapped", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = (Table) -> string + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +// Check that recursive intersection type doesn't generate an OOM +TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") +{ + CheckResult result = check(R"( + function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any + end + type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) + _(_) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") +{ + CheckResult result = check(R"( + local foo: Id = 1 + type Id = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") +{ + const std::string code = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb = aa + )"; + + const std::string expected = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type A = () -> (number, B) + type B = () -> (string, A) + local a: A + local b: B + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); + CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "generic_param_remap") +{ + const std::string code = R"( + -- An example of a forwarded use of a type that has different type arguments than parameters + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb = aa + )"; + + const std::string expected = R"( + + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + export type Foo = number + type Foo = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "Foo"); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") +{ + ScopedFastFlag sffs3{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + type Node = { value: T, child: Node? } + + local function visitor(node: Node?) + local a: Node + + if node then + a = node.child -- Observe the output of the error message. + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto e = get(result.errors[0]); + CHECK_EQ("Node?", toString(e->givenType)); + CHECK_EQ("Node", toString(e->wantedType)); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") +{ + fileResolver.source["workspace/A"] = R"( + export type myvec2 = {x: number, y: number} + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + export type myvec3 = {x: number, y: number, z: number} + return {} + )"; + + fileResolver.source["workspace/C"] = R"( + local Foo, Bar = require(workspace.A), require(workspace.B) + + local a: Foo.myvec2 + local b: Bar.myvec3 + )"; + + CheckResult result = frontend.check("workspace/C"); + LUAU_REQUIRE_NO_ERRORS(result); + ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; + + REQUIRE(m != nullptr); + + std::optional aTypeId = lookupName(m->getModuleScope(), "a"); + REQUIRE(aTypeId); + const Luau::TableTypeVar* aType = get(follow(*aTypeId)); + REQUIRE(aType); + REQUIRE(aType->props.size() == 2); + + std::optional bTypeId = lookupName(m->getModuleScope(), "b"); + REQUIRE(bTypeId); + const Luau::TableTypeVar* bType = get(follow(*bTypeId)); + REQUIRE(bType); + REQUIRE(bType->props.size() == 3); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") +{ + CheckResult result = check("type t10 = typeof(table)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); + CHECK_EQ(toString(ty), "table"); + + const TableTypeVar* ttv = get(ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +local c: Cool = { a = 1, b = "s" } +type NotCool = Cool +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +type NotCool = Cool +local c: Cool = { a = 1, b = "s" } +local d: NotCool = { a = 1, b = "s" } +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + ty = requireType("d"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "NotCool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") +{ + CheckResult result = check(R"( +local c = { a = 1, b = "s" } +type Cool = typeof(c) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK_EQ(ttv->name, "Cool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: number, b: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(follow(*ty1), follow(*ty2)); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: T, b: U, C: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); + + bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); + CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return function(obj) return true end +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return {a = 1, b = function(obj) return true end} +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: Forest } + type Forest = {Tree} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- OK because forwarded types are used with their parameters. + type Tree = { data: T, children: Forest } + type Forest = {Tree<{T}>} + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- Not OK because forwarded types are used with different types than their parameters. + type Forest = {Tree<{T}>} + type Tree = { data: T, children: Forest } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") +{ + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") +{ + CheckResult result = check(R"( + function f(x) return x[1] end + -- x has type X? for a free type variable X + local x = f ({}) + type ContainsFree = { this: a, that: typeof(x) } + type ContainsContainsFree = { that: ContainsFree } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") +{ + ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true}; + + CheckResult result = check(R"( + type Array = { [number]: T } + type Tuple = Array + + local p: Tuple + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{number | string}", toString(requireType("p"), {true})); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 8108e7f..b00fddc 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -30,6 +30,8 @@ TEST_SUITE_BEGIN("ProvisionalTests"); */ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") { + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + const std::string code = R"( function f(a) if type(a) == "boolean" then @@ -41,11 +43,11 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") )"; const std::string expected = R"( - function f(a:{fn:()->(free)}): () + function f(a:{fn:()->(free,free...)}): () if type(a) == 'boolean'then local a1:boolean=a elseif a.fn()then - local a2:{fn:()->(free)}=a + local a2:{fn:()->(free,free...)}=a end end )"; @@ -231,16 +233,7 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") local r2 = b == a )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| x: string |}?' could not be converted into 'number | string'"); - CHECK_EQ(toString(result.errors[1]), "Type 'number | string' could not be converted into '{| x: string |}?'"); - } + LUAU_REQUIRE_NO_ERRORS(result); } // Belongs in TypeInfer.refinements.test.cpp. @@ -542,6 +535,25 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doct } } +TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") +{ + CheckResult result = check(R"( + --!strict + local function setNumber(t: { p: number? }, x:number) t.p = x end + local function getString(t: { p: string? }):string return t.p or "" end + -- This shouldn't type-check! + local function oh(x:number): string + local t: {} = {} + setNumber(t, x) + return getString(t) + end + local s: string = oh(37) + )"); + + // Really this should return an error, but it doesn't + LUAU_REQUIRE_NO_ERRORS(result); +} + // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f2ba0dd..31739cd 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Fixture.h" @@ -6,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) LUAU_FASTFLAG(LuauOrPredicate) using namespace Luau; @@ -199,16 +199,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'string' could not be converted into 'boolean'", toString(result.errors[0])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -526,8 +518,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x) if type(x) == "vector" then @@ -544,8 +534,6 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local t = {"hello"} local v = t[2] @@ -573,8 +561,6 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true" TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | boolean) if type(x) ~= "string" then @@ -593,8 +579,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | {x: number} | {y: boolean}) if type(x) == "table" then @@ -613,8 +597,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function weird(x: string | ((number) -> string)) if type(x) == "function" then @@ -698,8 +680,6 @@ struct RefinementClassFixture : Fixture TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(vec) local X, Y, Z = vec.X, vec.Y, vec.Z @@ -726,8 +706,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: Instance | Vector3) if typeof(x) == "Vector3" then @@ -746,8 +724,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | Instance | Vector3) if type(x) == "userdata" then @@ -766,10 +742,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; + ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; CheckResult result = check(R"( local function f(x: Part | Folder | string) @@ -789,10 +762,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; + ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; CheckResult result = check(R"( local function f(x: Part | Folder | Instance | string | Vector3 | any) @@ -812,10 +782,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( --!nonstrict @@ -839,7 +806,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { ScopedFastFlag sffs[] = { {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, {"LuauTypeGuardPeelsAwaySubclasses", true}, }; @@ -861,8 +827,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( type XYCoord = {x: number} & {y: number} local function f(t: XYCoord?) @@ -882,8 +846,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) local function f(g: SomeOverloadedFunction?) @@ -903,8 +865,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(t: {x: number}) if type(t) ~= "table" then @@ -999,10 +959,7 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(Fixture, "either_number_or_string") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local function f(x: any) @@ -1036,10 +993,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local a: (number | string)? @@ -1057,10 +1011,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( @@ -1081,10 +1032,7 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local function f(a: string | number | boolean) diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 1d1b2fa..b7f0dc7 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -46,6 +46,21 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") CHECK(tType->props.find("foo") != tType->props.end()); } +TEST_CASE_FIXTURE(Fixture, "augment_nested_table") +{ + CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'"); + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* tType = getMutable(requireType("t")); + REQUIRE(tType != nullptr); + + REQUIRE(tType->props.find("p") != tType->props.end()); + const TableTypeVar* pType = get(tType->props["p"].type); + REQUIRE(pType != nullptr); + + CHECK(pType->props.find("foo") != pType->props.end()); +} + TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); @@ -260,6 +275,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local a = {} a.x = 99 @@ -272,10 +289,11 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ(error->key, "y"); + CHECK_EQ("y", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time // CHECK_EQ(err.location, Location(Position{5, 19}, Position{5, 25})); @@ -328,6 +346,8 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( --!strict function foo(o) @@ -340,14 +360,17 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") LUAU_REQUIRE_ERROR_COUNT(1, result); - UnknownProperty* error = get(result.errors[0]); + MissingProperties* error = get(result.errors[0]); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ("baz", error->key); + CHECK_EQ("baz", error->properties[0]); } TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local T = {} T.bar = 'hello' @@ -359,8 +382,11 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("baz", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time /* @@ -448,6 +474,73 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + local t = { u = {} } + t = { u = { p = 37 } } + t = { u = { q = "hi" } } + local x = t.u.p + local y = t.u.q + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireType("x"))); + CHECK_EQ("string?", toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call") +{ + CheckResult result = check(R"( + --!strict + function get(x) return x.opts["MYOPT"] end + function set(x,y) x.opts["MYOPT"] = y end + local t = { opts = {} } + set(t,37) + local x = get(t) + )"); + + // Currently this errors but it shouldn't, since set only needs write access + // TODO: file a JIRA for this + LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ("number?", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + function f(x : { q : number }) + x.q = 8 + end + local t : { q : number, r : string } = { q = 8, r = "hi" } + f(t) + local x : string = t.r + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping_needs_covariance") +{ + CheckResult result = check(R"( + --!strict + function f(x : { p : { q : number }}) + x.p = { q = 8, r = 5 } + end + local t : { p : { q : number, r : string } } = { p = { q = 8, r = "hi" } } + f(t) -- Shouldn't typecheck + local x : string = t.p.r -- x is 5 + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "infer_array") { CheckResult result = check(R"( @@ -676,16 +769,27 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sealed_table_value_must_not_infer_an_indexer") +TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_NO_ERRORS(result); +} - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm != nullptr); +TEST_CASE_FIXTURE(Fixture, "array_factory_function") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + function empty() return {} end + local array: {string} = empty() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify") @@ -756,37 +860,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_string") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t: { a: string } - function f(x: string) return t[x] end - local a = f("a") - local b = f("b") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.anyType, *requireType("a")); - CHECK_EQ(*typeChecker.anyType, *requireType("b")); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_number") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t = { a = true } - function f(x: number) return t[x] end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); -} - TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer") { CheckResult result = check(R"( @@ -1392,6 +1465,8 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end foo({ a = 1 }) @@ -1402,8 +1477,21 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("string", toString(tm->wantedType, o)); - CHECK_EQ("number", toString(tm->givenType, o)); + CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); + CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") +{ + CheckResult result = check(R"( + local function foo(a: {[string]: number, a: string}, i: string) + return a[i] + end + local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi" + )"); + + // This typechecks but shouldn't + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") @@ -1446,22 +1534,32 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {x = 1, y = 2, z = 3} - local vec1 = {x = 1} + local vec3 = {{x = 1, y = 2, z = 3}} + local vec1 = {{x = 1}} vec1 = vec3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* mp = get(result.errors[0]); - REQUIRE(mp); - CHECK_EQ(mp->context, MissingProperties::Extra); - REQUIRE_EQ(2, mp->properties.size()); - CHECK_EQ(mp->properties[0], "y"); - CHECK_EQ(mp->properties[1], "z"); - CHECK_EQ("vec1", toString(mp->superType)); - CHECK_EQ("vec3", toString(mp->subType)); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("vec1", toString(tm->wantedType)); + CHECK_EQ("vec3", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + local vec3 = {x = 1, y = 2, z = 3} + local vec1 = {x = 1} + + vec1 = vec3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") @@ -1824,4 +1922,32 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 338698e..dbc4538 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" @@ -978,23 +979,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type F = () -> F? - local function f() - return f - end - - local g: F = f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); -} - // TODO: File a Jira about this /* TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") @@ -1257,23 +1241,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") REQUIRE_EQ(follow(*methodArg), follow(arg)); } -TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") -{ - CheckResult result = check(R"( - --!strict - type Node = { Parent: Node?; } - local node: Node; - node.Parent = 1 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Node?", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") { CheckResult result = check(R"( @@ -2591,48 +2558,6 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: a, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = "lo", i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: b, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = 5, i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_ERRORS(result); - - // We had a UAF in this example caused by not cloning type function arguments - ModulePtr module = frontend.moduleResolver.getModule("MainModule"); - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); - module->internalTypes.clear(); - module->astTypes.clear(); - - // Make sure the error strings don't include "VALUELESS" - for (auto error : module->errors) - CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); -} - TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") { // CLI-30902 @@ -3388,16 +3313,7 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") end )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") @@ -3407,18 +3323,8 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable print((x == true and (x .. "y")) .. 1) )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ("Type 'boolean' could not be converted into 'number | string'", toString(result.errors[0])); - CHECK_EQ("Type 'boolean | string' could not be converted into 'number | string'", toString(result.errors[1])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") @@ -3530,25 +3436,6 @@ _(...)(...,setfenv,_):_G() )"); } -TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") -{ - CheckResult result = check(R"( - type Pair = {first: T, second: U} - local a: Pair - local b: Pair - - a = b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - - CHECK_EQ("Pair", toString(tm->wantedType)); - CHECK_EQ("Pair", toString(tm->givenType)); -} - TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") { // this has a risk of creating cyclic type packs, causing infinite loops / OOMs @@ -3658,17 +3545,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") )"); } -TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") -{ - CheckResult result = check(R"( - type A = number - type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") { CheckResult result = check(R"( @@ -3771,38 +3647,6 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); } -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = Table - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Wrapped", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = (Table) -> string - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") { CheckResult result = check(R"( @@ -3928,19 +3772,6 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } -// Check that recursive intersection type doesn't generate an OOM -TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") -{ - CheckResult result = check(R"( - function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any - end - type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) - _(_) - )"); - - LUAU_REQUIRE_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") { // In non-strict mode, global definition is still allowed @@ -3993,16 +3824,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") LUAU_REQUIRE_ERROR_COUNT(2, result); } -TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") -{ - CheckResult result = check(R"( - local foo: Id = 1 - type Id = T - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") { CheckResult result = check(R"( @@ -4033,81 +3854,6 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") -{ - const std::string code = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb = aa - )"; - - const std::string expected = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type A = () -> (number, B) - type B = () -> (string, A) - local a: A - local b: B - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); - CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "generic_param_remap") -{ - const std::string code = R"( - -- An example of a forwarded use of a type that has different type arguments than parameters - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb = aa - )"; - - const std::string expected = R"( - - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") -{ - CheckResult result = check(R"( - export type Foo = number - type Foo = number - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto dtd = get(result.errors[0]); - REQUIRE(dtd); - CHECK_EQ(dtd->name, "Foo"); -} - TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { CheckResult result = check(R"( @@ -4212,30 +3958,6 @@ TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); } -TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") -{ - ScopedFastFlag sffs3{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - - CheckResult result = check(R"( - type Node = { value: T, child: Node? } - - local function visitor(node: Node?) - local a: Node - - if node then - a = node.child -- Observe the output of the error message. - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto e = get(result.errors[0]); - CHECK_EQ("Node?", toString(e->givenType)); - CHECK_EQ("Node", toString(e->wantedType)); -} - TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") { CheckResult result = check(R"( @@ -4291,181 +4013,6 @@ local tbl: string = require(game.A) CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") -{ - fileResolver.source["workspace/A"] = R"( - export type myvec2 = {x: number, y: number} - return {} - )"; - - fileResolver.source["workspace/B"] = R"( - export type myvec3 = {x: number, y: number, z: number} - return {} - )"; - - fileResolver.source["workspace/C"] = R"( - local Foo, Bar = require(workspace.A), require(workspace.B) - - local a: Foo.myvec2 - local b: Bar.myvec3 - )"; - - CheckResult result = frontend.check("workspace/C"); - LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; - - REQUIRE(m != nullptr); - - std::optional aTypeId = lookupName(m->getModuleScope(), "a"); - REQUIRE(aTypeId); - const Luau::TableTypeVar* aType = get(follow(*aTypeId)); - REQUIRE(aType); - REQUIRE(aType->props.size() == 2); - - std::optional bTypeId = lookupName(m->getModuleScope(), "b"); - REQUIRE(bTypeId); - const Luau::TableTypeVar* bType = get(follow(*bTypeId)); - REQUIRE(bType); - REQUIRE(bType->props.size() == 3); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") -{ - CheckResult result = check("type t10 = typeof(table)"); - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); - CHECK_EQ(toString(ty), "table"); - - const TableTypeVar* ttv = get(ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -local c: Cool = { a = 1, b = "s" } -type NotCool = Cool -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -type NotCool = Cool -local c: Cool = { a = 1, b = "s" } -local d: NotCool = { a = 1, b = "s" } -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - ty = requireType("d"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "NotCool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") -{ - CheckResult result = check(R"( -local c = { a = 1, b = "s" } -type Cool = typeof(c) -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK_EQ(ttv->name, "Cool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: number, b: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(follow(*ty1), follow(*ty2)); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: T, b: U, C: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); - - bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); - CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); -} - TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") { CheckResult result = check(R"( @@ -4579,32 +4126,6 @@ local c = a(2) -- too many arguments CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return function(obj) return true end -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return {a = 1, b = function(obj) return true end} -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "custom_require_global") { CheckResult result = check(R"( @@ -4787,8 +4308,6 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") { - ScopedFastFlag sff{"LuauLogTableTypeVarBoundTo", true}; - fileResolver.source["Module/Backend/Types"] = R"( export type Fiber = { return_: Fiber? @@ -4868,8 +4387,8 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") ModulePtr module = getMainModule(); auto it = module->astOverloadResolvedTypes.find(parentExpr); - REQUIRE(it != module->astOverloadResolvedTypes.end()); - CHECK_EQ(toString(it->second), "(number) -> number"); + REQUIRE(it); + CHECK_EQ(toString(*it), "(number) -> number"); } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") @@ -5032,76 +4551,6 @@ g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") -{ - CheckResult result = check(R"( - type Tree = { data: T, children: Forest } - type Forest = {Tree} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- OK because forwarded types are used with their parameters. - type Tree = { data: T, children: Forest } - type Forest = {Tree<{T}>} - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- Not OK because forwarded types are used with different types than their parameters. - type Forest = {Tree<{T}>} - type Tree = { data: T, children: Forest } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") -{ - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") -{ - CheckResult result = check(R"( - function f(x) return x[1] end - -- x has type X? for a free type variable X - local x = f ({}) - type ContainsFree = { this: a, that: typeof(x) } - type ContainsContainsFree = { that: ContainsFree } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 91ac9f0..1f4b63e 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5f7f284..3e1dedd 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -294,4 +294,370 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed = (T...) -> T... +local a: Packed<> +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "(T...) -> (T...)"); + CHECK_EQ(toString(requireType("a")), "() -> ()"); + CHECK_EQ(toString(requireType("b")), "(number) -> number"); + CHECK_EQ(toString(requireType("c")), "(string, number) -> (string, number)"); + + result = check(R"( +-- (U..., T) cannot be parsed right now +type Packed = { f: (a: T, U...) -> (T, U...) } +local a: Packed +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}"); + + auto ttvA = get(requireType("a")); + REQUIRE(ttvA); + CHECK_EQ(toString(requireType("a")), "Packed"); + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + REQUIRE(ttvA->instantiatedTypeParams.size() == 1); + REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); + CHECK_EQ(toString(ttvA->instantiatedTypePackParams[0], {true}), ""); + + auto ttvB = get(requireType("b")); + REQUIRE(ttvB); + CHECK_EQ(toString(requireType("b")), "Packed"); + CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}"); + REQUIRE(ttvB->instantiatedTypeParams.size() == 1); + REQUIRE(ttvB->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvB->instantiatedTypePackParams[0], {true}), "number"); + + auto ttvC = get(requireType("c")); + REQUIRE(ttvC); + CHECK_EQ(toString(requireType("c")), "Packed"); + CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}"); + REQUIRE(ttvC->instantiatedTypeParams.size() == 1); + REQUIRE(ttvC->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvC->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +local a: Import.Packed +local b: Import.Packed +local c: Import.Packed +local d: { a: typeof(c) } + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + auto tf = lookupImportedType("Import", "Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}"); + CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}"); + CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult cResult = check(R"( +local Import = require(game.A) +type Alias = Import.Packed +local a: Alias + +type B = Import.Packed +type C = Import.Packed + )"); + LUAU_REQUIRE_NO_ERRORS(cResult); + + auto tf = lookupType("Alias"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Alias"); + CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + + tf = lookupType("B"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "B"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}"); + + tf = lookupType("C"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "C"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed1 = (T...) -> (T...) +type Packed2 = (Packed1, T...) -> (Packed1, T...) +type Packed3 = (Packed2, T...) -> (Packed2, T...) +type Packed4 = (Packed3, T...) -> (Packed3, T...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed4"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...) -> " + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = (T...) -> (string, T...) + +type D = X<...number> +type E = X<(number, ...string)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("D")), "(...number) -> (string, ...number)"); + CHECK_EQ(toString(*lookupType("E")), "(number, ...string) -> (string, number, ...string)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Y = (T...) -> (U...) +type A = Y +type B = Y<(number, ...string), S...> + +type Z = (T) -> (U...) +type E = Z +type F = Z + +type W = (T, U...) -> (T, V...) +type H = W +type I = W + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "(number, ...string) -> (S...)"); + + CHECK_EQ(toString(*lookupType("E")), "(number) -> (S...)"); + CHECK_EQ(toString(*lookupType("F")), "(number) -> (string, S...)"); + + CHECK_EQ(toString(*lookupType("H")), "(number, S...) -> (number, R...)"); + CHECK_EQ(toString(*lookupType("I")), "(number, string, S...) -> (number, R...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = (T...) -> (T...) + +type A = X<(S...)> +type B = X<()> +type C = X<(number)> +type D = X<(number, string)> +type E = X<(...number)> +type F = X<(string, ...number)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> number"); + CHECK_EQ(toString(*lookupType("D")), "(number, string) -> (number, string)"); + CHECK_EQ(toString(*lookupType("E")), "(...number) -> (...number)"); + CHECK_EQ(toString(*lookupType("F")), "(string, ...number) -> (string, ...number)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Y = (T...) -> (U...) + +type A = Y<(number, string), (boolean)> +type B = Y<(), ()> +type C = Y<...string, (number, S...)> +type D = Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(number, string) -> boolean"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(...string) -> (number, S...)"); + CHECK_EQ(toString(*lookupType("D")), "(X...) -> (number, string, X...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block + + CheckResult result = check(R"( +type Y = { f: (T...) -> (U...) } + +local a: Y<(number, string), (boolean)> +local b: Y<(), ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (boolean)>"); + CHECK_EQ(toString(requireType("b")), "Y<(), ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = () -> T +type Y = (T) -> U + +type A = X<(number)> +type B = Y<(number), (boolean)> +type C = Y<(number), boolean> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "() -> number"); + CHECK_EQ(toString(*lookupType("B")), "(number) -> boolean"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed = (T, U) -> (V...) +local b: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects at least 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T, U) -> () +type B = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 0 type pack arguments, but 1 is specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameters must come before type pack parameters"); + + result = check(R"( +type Packed = (T) -> U +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T...) -> T... +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but none are specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 456dbad..c0ed25d 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -237,21 +237,7 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") local z = a == c )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.booleanType, *requireType("x")); - CHECK_EQ(*typeChecker.booleanType, *requireType("y")); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("(number | string)?", toString(*tm->wantedType)); - CHECK_EQ("boolean | number", toString(*tm->givenType)); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "optional_union_members") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 98ce9f9..a679e3f 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tools/tracegraph.py b/tools/tracegraph.py new file mode 100644 index 0000000..a46423e --- /dev/null +++ b/tools/tracegraph.py @@ -0,0 +1,95 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a trace event file, this tool generates a flame graph based on the event scopes present in the file +# The result of analysis is a .svg file which can be viewed in a browser + +import sys +import svg +import json + +class Node(svg.Node): + def __init__(self): + svg.Node.__init__(self) + self.caption = "" + self.description = "" + self.ticks = 0 + + def text(self): + return self.caption + + def title(self): + return self.caption + + def details(self, root): + return "{} ({:,} usec, {:.1%}); self: {:,} usec".format(self.description, self.width, self.width / root.width, self.ticks) + +with open(sys.argv[1]) as f: + dump = f.read() + +root = Node() + +# Finish the file +if not dump.endswith("]"): + dump += "{}]" + +data = json.loads(dump) + +stacks = {} + +for l in data: + if len(l) == 0: + continue + + # Track stack of each thread, but aggregate values together + tid = l["tid"] + + if not tid in stacks: + stacks[tid] = [] + stack = stacks[tid] + + if l["ph"] == 'B': + stack.append(l) + elif l["ph"] == 'E': + node = root + + for e in stack: + caption = e["name"] + description = '' + + if "args" in e: + for arg in e["args"]: + if len(description) != 0: + description += ", " + + description += "{}: {}".format(arg, e["args"][arg]) + + child = node.child(caption + description) + child.caption = caption + child.description = description + + node = child + + begin = stack[-1] + + ticks = l["ts"] - begin["ts"] + rawticks = ticks + + # Flame graph requires ticks without children duration + if "childts" in begin: + ticks -= begin["childts"] + + node.ticks += int(ticks) + + stack.pop() + + if len(stack): + parent = stack[-1] + + if "childts" in parent: + parent["childts"] += rawticks + else: + parent["childts"] = rawticks + +svg.layout(root, lambda n: n.ticks) +svg.display(root, "Flame Graph", "hot", flip = True)