diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 06f53e4..9d5aadf 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -7,6 +7,7 @@ #include "Luau/Constraint.h" #include "Luau/TypeVar.h" #include "Luau/ToString.h" +#include "Luau/Normalize.h" #include @@ -44,6 +45,7 @@ struct ConstraintSolver TypeArena* arena; NotNull singletonTypes; InternalErrorReporter iceReporter; + NotNull normalizer; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; @@ -74,9 +76,12 @@ struct ConstraintSolver DcrLogger* logger; - explicit ConstraintSolver(TypeArena* arena, NotNull singletonTypes, NotNull rootScope, ModuleName moduleName, + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + // Randomize the order in which to dispatch constraints + void randomize(unsigned seed); + /** * Attempts to dispatch all pending constraints and reach a type solution * that satisfies all of the constraints. @@ -85,8 +90,9 @@ struct ConstraintSolver bool done(); - /** Attempt to dispatch a constraint. Returns true if it was successful. - * If tryDispatch() returns false, the constraint remains in the unsolved set and will be retried later. + /** Attempt to dispatch a constraint. Returns true if it was successful. If + * tryDispatch() returns false, the constraint remains in the unsolved set + * and will be retried later. */ bool tryDispatch(NotNull c, bool force); diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index f373586..6775488 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -16,7 +16,7 @@ struct TypeMismatch TypeMismatch() = default; TypeMismatch(TypeId wantedType, TypeId givenType); TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason); - TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error); TypeId wantedType = nullptr; TypeId givenType = nullptr; diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 5df6f4b..b2662c6 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -83,8 +83,13 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // Run typechecking only in mode required for autocomplete (strict mode in order to get more precise type information) + // Run typechecking only in mode required for autocomplete (strict mode in + // order to get more precise type information) bool forAutocomplete = false; + + // If not empty, randomly shuffle the constraint set before attempting to + // solve. Use this value to seed the random number generator. + std::optional randomizeConstraintResolutionSeed; }; struct CheckResult diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 8e8b889..41e50d1 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -1,9 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Module.h" #include "Luau/NotNull.h" #include "Luau/TypeVar.h" +#include "Luau/UnifierSharedState.h" #include @@ -29,4 +29,231 @@ std::pair normalize( std::pair normalize(TypePackId ty, NotNull module, NotNull singletonTypes, InternalErrorReporter& ice); std::pair normalize(TypePackId ty, const ModulePtr& module, NotNull singletonTypes, InternalErrorReporter& ice); +class TypeIds +{ +private: + std::unordered_set types; + std::vector order; + std::size_t hash = 0; + +public: + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + TypeIds(const TypeIds&) = delete; + TypeIds(TypeIds&&) = default; + TypeIds() = default; + ~TypeIds() = default; + TypeIds& operator=(TypeIds&&) = default; + + void insert(TypeId ty); + /// Erase every element that does not also occur in tys + void retain(const TypeIds& tys); + void clear(); + + iterator begin(); + iterator end(); + const_iterator begin() const; + const_iterator end() const; + iterator erase(const_iterator it); + + size_t size() const; + bool empty() const; + size_t count(TypeId ty) const; + + template + void insert(Iterator begin, Iterator end) + { + for (Iterator it = begin; it != end; ++it) + insert(*it); + } + + bool operator ==(const TypeIds& there) const; + size_t getHash() const; +}; + +} // namespace Luau + +template<> struct std::hash +{ + std::size_t operator()(const Luau::TypeIds& tys) const + { + return tys.getHash(); + } +}; + +template<> struct std::hash +{ + std::size_t operator()(const Luau::TypeIds* tys) const + { + return tys->getHash(); + } +}; + +template<> struct std::equal_to +{ + bool operator()(const Luau::TypeIds& here, const Luau::TypeIds& there) const + { + return here == there; + } +}; + +template<> struct std::equal_to +{ + bool operator()(const Luau::TypeIds* here, const Luau::TypeIds* there) const + { + return *here == *there; + } +}; + +namespace Luau +{ + +// A normalized string type is either `string` (represented by `nullopt`) +// or a union of string singletons. +using NormalizedStringType = std::optional>; + +// A normalized function type is either `never` (represented by `nullopt`) +// or an intersection of function types. +// NOTE: type normalization can fail on function types with generics +// (e.g. because we do not support unions and intersections of generic type packs), +// so this type may contain `error`. +using NormalizedFunctionType = std::optional; + +// A normalized generic/free type is a union, where each option is of the form (X & T) where +// * X is either a free type or a generic +// * T is a normalized type. +struct NormalizedType; +using NormalizedTyvars = std::unordered_map>; + +// A normalized type is either any, unknown, or one of the form P | T | F | G where +// * P is a union of primitive types (including singletons, classes and the error type) +// * T is a union of table types +// * F is a union of an intersection of function types +// * G is a union of generic/free normalized types, intersected with a normalized type +struct NormalizedType +{ + // The top part of the type. + // This type is either never, unknown, or any. + // If this type is not never, all the other fields are null. + TypeId tops; + + // The boolean part of the type. + // This type is either never, boolean type, or a boolean singleton. + TypeId booleans; + + // The class part of the type. + // Each element of this set is a class, and none of the classes are subclasses of each other. + TypeIds classes; + + // The error part of the type. + // This type is either never or the error type. + TypeId errors; + + // The nil part of the type. + // This type is either never or nil. + TypeId nils; + + // The number part of the type. + // This type is either never or number. + TypeId numbers; + + // The string part of the type. + // This may be the `string` type, or a union of singletons. + NormalizedStringType strings = std::map{}; + + // The thread part of the type. + // This type is either never or thread. + TypeId threads; + + // The (meta)table part of the type. + // Each element of this set is a (meta)table type. + TypeIds tables; + + // The function part of the type. + NormalizedFunctionType functions; + + // The generic/free part of the type. + NormalizedTyvars tyvars; + + NormalizedType(NotNull singletonTypes); + + NormalizedType(const NormalizedType&) = delete; + NormalizedType(NormalizedType&&) = default; + NormalizedType() = delete; + ~NormalizedType() = default; + NormalizedType& operator=(NormalizedType&&) = default; + NormalizedType& operator=(NormalizedType&) = delete; +}; + +class Normalizer +{ + std::unordered_map> cachedNormals; + std::unordered_map cachedIntersections; + std::unordered_map cachedUnions; + std::unordered_map> cachedTypeIds; + bool withinResourceLimits(); + +public: + TypeArena* arena; + NotNull singletonTypes; + NotNull sharedState; + + Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState); + Normalizer(const Normalizer&) = delete; + Normalizer(Normalizer&&) = delete; + Normalizer() = delete; + ~Normalizer() = default; + Normalizer& operator=(Normalizer&&) = delete; + Normalizer& operator=(Normalizer&) = delete; + + // If this returns null, the typechecker should emit a "too complex" error + const NormalizedType* normalize(TypeId ty); + void clearNormal(NormalizedType& norm); + + // ------- Cached TypeIds + TypeId unionType(TypeId here, TypeId there); + TypeId intersectionType(TypeId here, TypeId there); + const TypeIds* cacheTypeIds(TypeIds tys); + void clearCaches(); + + // ------- Normalizing unions + void unionTysWithTy(TypeIds& here, TypeId there); + TypeId unionOfTops(TypeId here, TypeId there); + TypeId unionOfBools(TypeId here, TypeId there); + void unionClassesWithClass(TypeIds& heres, TypeId there); + void unionClasses(TypeIds& heres, const TypeIds& theres); + void unionStrings(NormalizedStringType& here, const NormalizedStringType& there); + std::optional unionOfTypePacks(TypePackId here, TypePackId there); + std::optional unionOfFunctions(TypeId here, TypeId there); + std::optional unionSaturatedFunctions(TypeId here, TypeId there); + void unionFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); + void unionFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); + void unionTablesWithTable(TypeIds& heres, TypeId there); + void unionTables(TypeIds& heres, const TypeIds& theres); + bool unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + bool unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars = -1); + + // ------- Normalizing intersections + void intersectTysWithTy(TypeIds& here, TypeId there); + TypeId intersectionOfTops(TypeId here, TypeId there); + TypeId intersectionOfBools(TypeId here, TypeId there); + void intersectClasses(TypeIds& heres, const TypeIds& theres); + void intersectClassesWithClass(TypeIds& heres, TypeId there); + void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); + std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); + std::optional intersectionOfTables(TypeId here, TypeId there); + void intersectTablesWithTable(TypeIds& heres, TypeId there); + void intersectTables(TypeIds& heres, const TypeIds& theres); + std::optional intersectionOfFunctions(TypeId here, TypeId there); + void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); + void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); + bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there); + bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); + bool intersectNormalWithTy(NormalizedType& here, TypeId there); + + // -------- Convert back from a normalized type to a type + TypeId typeFromNormal(const NormalizedType& norm); +}; + } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index e5675eb..3184b0d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -234,6 +234,8 @@ public: TypeId anyify(const ScopePtr& scope, TypeId ty, Location location); TypePackId anyify(const ScopePtr& scope, TypePackId ty, Location location); + TypePackId anyifyModuleReturnTypePackGenerics(TypePackId ty); + void reportError(const TypeError& error); void reportError(const Location& location, TypeErrorData error); void reportErrors(const ErrorVec& errors); @@ -359,6 +361,7 @@ public: InternalErrorReporter* iceHandler; UnifierSharedState unifierState; + Normalizer normalizer; std::vector requireCycles; diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 0ea175c..c43daa2 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -96,7 +96,7 @@ struct Free bool forwardedTypeAlias = false; private: - static int nextIndex; + static int DEPRECATED_nextIndex; }; template @@ -127,7 +127,7 @@ struct Generic bool explicitName = false; private: - static int nextIndex; + static int DEPRECATED_nextIndex; }; struct Error diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 26a922f..f6219df 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -9,6 +9,7 @@ #include "Luau/TxnLog.h" #include "Luau/TypeArena.h" #include "Luau/UnifierSharedState.h" +#include "Normalize.h" #include @@ -52,6 +53,7 @@ struct Unifier { TypeArena* const types; NotNull singletonTypes; + NotNull normalizer; Mode mode; NotNull scope; // const Scope maybe @@ -60,13 +62,14 @@ struct Unifier Location location; Variance variance = Covariant; bool anyIsTop = false; // If true, we consider any to be a top type. If false, it is a familiar but weird mix of top and bottom all at once. + bool normalize; // Normalize unions and intersections if necessary bool useScopes = false; // If true, we use the scope hierarchy rather than TypeLevels CountMismatch::Context ctx = CountMismatch::Arg; UnifierSharedState& sharedState; - Unifier(TypeArena* types, NotNull singletonTypes, Mode mode, NotNull scope, const Location& location, Variance variance, - UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); + Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, + TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); @@ -84,6 +87,7 @@ private: void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); + void tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error = std::nullopt); void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); @@ -92,6 +96,8 @@ private: void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + TypePackId tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args); + TypeId widen(TypeId ty); TypePackId widen(TypePackId tp); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1f594fe..224e944 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauSelfCallAutocompleteFix3) - static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -139,7 +137,8 @@ static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, T { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, singletonTypes, Mode::Strict, scope, Location(), Variance::Covariant, unifierState); + Normalizer normalizer{typeArena, singletonTypes, NotNull{&unifierState}}; + Unifier unifier(NotNull{&normalizer}, Mode::Strict, scope, Location(), Variance::Covariant); return unifier.canUnify(subTy, superTy).empty(); } @@ -151,18 +150,6 @@ static TypeCorrectKind checkTypeCorrectKind( NotNull moduleScope{module.getModuleScope().get()}; - auto canUnify = [&typeArena, singletonTypes, moduleScope](TypeId subTy, TypeId superTy) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); - - InternalErrorReporter iceReporter; - UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, singletonTypes, Mode::Strict, moduleScope, Location(), Variance::Covariant, unifierState); - - unifier.tryUnify(subTy, superTy); - bool ok = unifier.errors.empty(); - return ok; - }; - auto typeAtPosition = findExpectedTypeAt(module, node, position); if (!typeAtPosition) @@ -170,30 +157,11 @@ static TypeCorrectKind checkTypeCorrectKind( TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [typeArena, singletonTypes, moduleScope, &canUnify, &expectedType](const FunctionTypeVar* ftv) { - if (FFlag::LuauSelfCallAutocompleteFix3) - { - if (std::optional firstRetTy = first(ftv->retTypes)) - return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, singletonTypes); + auto checkFunctionType = [typeArena, singletonTypes, moduleScope, &expectedType](const FunctionTypeVar* ftv) { + if (std::optional firstRetTy = first(ftv->retTypes)) + return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, singletonTypes); - return false; - } - else - { - auto [retHead, retTail] = flatten(ftv->retTypes); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return true; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return true; - } - - return false; - } + return false; }; // We also want to suggest functions that return compatible result @@ -212,11 +180,8 @@ static TypeCorrectKind checkTypeCorrectKind( } } - if (FFlag::LuauSelfCallAutocompleteFix3) - return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes) ? TypeCorrectKind::Correct - : TypeCorrectKind::None; - else - return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena, singletonTypes) ? TypeCorrectKind::Correct + : TypeCorrectKind::None; } enum class PropIndexType @@ -230,51 +195,14 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) { - if (FFlag::LuauSelfCallAutocompleteFix3) - rootTy = follow(rootTy); - + rootTy = follow(rootTy); ty = follow(ty); if (seen.count(ty)) return; seen.insert(ty); - auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { - LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); - - if (indexType == PropIndexType::Key) - return false; - - bool colonIndex = indexType == PropIndexType::Colon; - - if (const FunctionTypeVar* ftv = get(type)) - { - return useStrictFunctionIndexers ? colonIndex != ftv->hasSelf : false; - } - else if (const IntersectionTypeVar* itv = get(type)) - { - bool allHaveSelf = true; - for (auto subType : itv->parts) - { - if (const FunctionTypeVar* ftv = get(Luau::follow(subType))) - { - allHaveSelf &= ftv->hasSelf; - } - else - { - return colonIndex; - } - } - return useStrictFunctionIndexers ? colonIndex != allHaveSelf : false; - } - else - { - return colonIndex; - } - }; auto isWrongIndexer = [typeArena, singletonTypes, &module, rootTy, indexType](Luau::TypeId type) { - LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix3); - if (indexType == PropIndexType::Key) return false; @@ -337,7 +265,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul AutocompleteEntryKind::Property, type, prop.deprecated, - FFlag::LuauSelfCallAutocompleteFix3 ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), + isWrongIndexer(type), typeCorrect, containingClass, &prop, @@ -380,31 +308,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul { autocompleteProps(module, typeArena, singletonTypes, rootTy, mt->table, indexType, nodes, result, seen); - if (FFlag::LuauSelfCallAutocompleteFix3) - { - if (auto mtable = get(mt->metatable)) - fillMetatableProps(mtable); - } - else - { - auto mtable = get(mt->metatable); - if (!mtable) - return; - - auto indexIt = mtable->props.find("__index"); - if (indexIt != mtable->props.end()) - { - TypeId followed = follow(indexIt->second.type); - if (get(followed) || get(followed)) - autocompleteProps(module, typeArena, singletonTypes, rootTy, followed, indexType, nodes, result, seen); - else if (auto indexFunction = get(followed)) - { - std::optional indexFunctionResult = first(indexFunction->retTypes); - if (indexFunctionResult) - autocompleteProps(module, typeArena, singletonTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); - } - } - } + if (auto mtable = get(mt->metatable)) + fillMetatableProps(mtable); } else if (auto i = get(ty)) { @@ -446,9 +351,6 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul AutocompleteEntryMap inner; std::unordered_set innerSeen; - if (!FFlag::LuauSelfCallAutocompleteFix3) - innerSeen = seen; - if (isNil(*iter)) { ++iter; @@ -472,7 +374,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul ++iter; } } - else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix3) + else if (auto pt = get(ty)) { if (pt->metatable) { @@ -480,7 +382,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul fillMetatableProps(mtable); } } - else if (FFlag::LuauSelfCallAutocompleteFix3 && get(get(ty))) + else if (get(get(ty))) { autocompleteProps(module, typeArena, singletonTypes, rootTy, singletonTypes->stringType, indexType, nodes, result, seen); } @@ -1416,11 +1318,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (!FFlag::LuauSelfCallAutocompleteFix3 && isString(ty)) - return {autocompleteProps(*module, &typeArena, singletonTypes, globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), - ancestry, AutocompleteContext::Property}; - else - return {autocompleteProps(*module, &typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + return {autocompleteProps(*module, &typeArena, singletonTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; } else if (auto typeReference = node->as()) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 35b8387..5b3ec03 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -14,6 +14,8 @@ #include "Luau/VisitTypeVar.h" #include "Luau/TypeUtils.h" +#include + LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); LUAU_FASTFLAG(LuauFixNameMaps) @@ -251,10 +253,11 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } -ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull singletonTypes, NotNull rootScope, ModuleName moduleName, +ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) - : arena(arena) - , singletonTypes(singletonTypes) + : arena(normalizer->arena) + , singletonTypes(normalizer->singletonTypes) + , normalizer(normalizer) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) @@ -278,6 +281,12 @@ ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull sin LUAU_ASSERT(logger); } +void ConstraintSolver::randomize(unsigned seed) +{ + std::mt19937 g(seed); + std::shuffle(begin(unsolvedConstraints), end(unsolvedConstraints), g); +} + void ConstraintSolver::run() { if (done()) @@ -1355,8 +1364,7 @@ bool ConstraintSolver::isBlocked(NotNull constraint) void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull scope) { - UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; u.useScopes = true; u.tryUnify(subType, superType); @@ -1379,7 +1387,7 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) { UnifierSharedState sharedState{&iceReporter}; - Unifier u{arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Unifier u{normalizer, Mode::Strict, scope, Location{}, Covariant}; u.useScopes = true; u.tryUnify(subPack, superPack); diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index aa496ee..d13e26c 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -511,11 +511,11 @@ TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reas { } -TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error) +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, std::optional error) : wantedType(wantedType) , givenType(givenType) , reason(reason) - , error(std::make_shared(std::move(error))) + , error(error ? std::make_shared(std::move(*error)) : nullptr) { } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 1890e08..5705ac1 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -860,12 +860,18 @@ ModulePtr Frontend::check( const NotNull mr{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}; const ScopePtr& globalScope{forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope}; + Normalizer normalizer{&result->internalTypes, singletonTypes, NotNull{&typeChecker.unifierState}}; + ConstraintGraphBuilder cgb{ sourceModule.name, result, &result->internalTypes, mr, singletonTypes, NotNull(&iceHandler), globalScope, logger.get()}; cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{&result->internalTypes, singletonTypes, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles, logger.get()}; + + if (options.randomizeConstraintResolutionSeed) + cs.randomize(*options.randomizeConstraintResolutionSeed); + cs.run(); for (TypeError& e : cs.errors) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b9deac7..45eb87d 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,6 +14,7 @@ #include +LUAU_FASTFLAG(LuauAnyifyModuleReturnGenerics) LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauForceExportSurfacesToBeNormal, false); @@ -285,13 +286,16 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern } } - for (TypeId ty : returnType) + if (!FFlag::LuauAnyifyModuleReturnGenerics) { - if (get(follow(ty))) + for (TypeId ty : returnType) { - auto t = asMutable(ty); - t->ty = AnyTypeVar{}; - t->normal = true; + if (get(follow(ty))) + { + auto t = asMutable(ty); + t->ty = AnyTypeVar{}; + t->normal = true; + } } } diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 42f6151..c008bcf 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Normalize.h" +#include "Luau/ToString.h" #include @@ -10,15 +11,1703 @@ #include "Luau/VisitTypeVar.h" LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false) +LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); +LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); +LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); +LUAU_FASTFLAGVARIABLE(LuauTypeNormalization2, false); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) namespace Luau { +void TypeIds::insert(TypeId ty) +{ + ty = follow(ty); + auto [_, fresh] = types.insert(ty); + if (fresh) + { + order.push_back(ty); + hash ^= std::hash{}(ty); + } +} + +void TypeIds::clear() +{ + order.clear(); + types.clear(); + hash = 0; +} + +TypeIds::iterator TypeIds::begin() +{ + return order.begin(); +} + +TypeIds::iterator TypeIds::end() +{ + return order.end(); +} + +TypeIds::const_iterator TypeIds::begin() const +{ + return order.begin(); +} + +TypeIds::const_iterator TypeIds::end() const +{ + return order.end(); +} + +TypeIds::iterator TypeIds::erase(TypeIds::const_iterator it) +{ + TypeId ty = *it; + types.erase(ty); + hash ^= std::hash{}(ty); + return order.erase(it); +} + +size_t TypeIds::size() const +{ + return types.size(); +} + +bool TypeIds::empty() const +{ + return types.empty(); +} + +size_t TypeIds::count(TypeId ty) const +{ + ty = follow(ty); + return types.count(ty); +} + +void TypeIds::retain(const TypeIds& there) +{ + for (auto it = begin(); it != end();) + { + if (there.count(*it)) + it++; + else + it = erase(it); + } +} + +size_t TypeIds::getHash() const +{ + return hash; +} + +bool TypeIds::operator==(const TypeIds& there) const +{ + return hash == there.hash && types == there.types; +} + +NormalizedType::NormalizedType(NotNull singletonTypes) + : tops(singletonTypes->neverType) + , booleans(singletonTypes->neverType) + , errors(singletonTypes->neverType) + , nils(singletonTypes->neverType) + , numbers(singletonTypes->neverType) + , threads(singletonTypes->neverType) +{ +} + +static bool isInhabited(const NormalizedType& norm) +{ + return !get(norm.tops) + || !get(norm.booleans) + || !norm.classes.empty() + || !get(norm.errors) + || !get(norm.nils) + || !get(norm.numbers) + || !norm.strings || !norm.strings->empty() + || !get(norm.threads) + || norm.functions + || !norm.tables.empty() + || !norm.tyvars.empty(); +} + +static int tyvarIndex(TypeId ty) +{ + if (const GenericTypeVar* gtv = get(ty)) + return gtv->index; + else if (const FreeTypeVar* ftv = get(ty)) + return ftv->index; + else + return 0; +} + +#ifdef LUAU_ASSERTENABLED + +static bool isNormalizedTop(TypeId ty) +{ + return get(ty) || get(ty) || get(ty); +} + +static bool isNormalizedBoolean(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Boolean; + else if (const SingletonTypeVar* stv = get(ty)) + return get(stv); + else + return false; +} + +static bool isNormalizedError(TypeId ty) +{ + if (get(ty) || get(ty)) + return true; + else + return false; +} + +static bool isNormalizedNil(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::NilType; + else + return false; +} + +static bool isNormalizedNumber(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Number; + else + return false; +} + +static bool isNormalizedString(const NormalizedStringType& ty) +{ + if (!ty) + return true; + + for (auto& [str, ty] : *ty) + { + if (const SingletonTypeVar* stv = get(ty)) + { + if (const StringSingleton* sstv = get(stv)) + { + if (sstv->value != str) + return false; + } + else + return false; + } + else + return false; + } + + return true; +} + +static bool isNormalizedThread(TypeId ty) +{ + if (get(ty)) + return true; + else if (const PrimitiveTypeVar* ptv = get(ty)) + return ptv->type == PrimitiveTypeVar::Thread; + else + return false; +} + +static bool areNormalizedFunctions(const NormalizedFunctionType& tys) +{ + if (tys) + for (TypeId ty : *tys) + if (!get(ty) && !get(ty)) + return false; + return true; +} + +static bool areNormalizedTables(const TypeIds& tys) +{ + for (TypeId ty : tys) + if (!get(ty) && !get(ty)) + return false; + return true; +} + +static bool areNormalizedClasses(const TypeIds& tys) +{ + for (TypeId ty : tys) + if (!get(ty)) + return false; + return true; +} + +static bool isPlainTyvar(TypeId ty) +{ + return (get(ty) || get(ty)); +} + +static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) +{ + for (auto& [tyvar, intersect] : tyvars) + { + if (!isPlainTyvar(tyvar)) + return false; + if (!isInhabited(*intersect)) + return false; + for (auto& [other, _] : intersect->tyvars) + if (tyvarIndex(other) <= tyvarIndex(tyvar)) + return false; + } + return true; +} + +#endif // LUAU_ASSERTENABLED + +static void assertInvariant(const NormalizedType& norm) +{ + #ifdef LUAU_ASSERTENABLED + if (!FFlag::DebugLuauCheckNormalizeInvariant) + return; + + LUAU_ASSERT(isNormalizedTop(norm.tops)); + LUAU_ASSERT(isNormalizedBoolean(norm.booleans)); + LUAU_ASSERT(areNormalizedClasses(norm.classes)); + LUAU_ASSERT(isNormalizedError(norm.errors)); + LUAU_ASSERT(isNormalizedNil(norm.nils)); + LUAU_ASSERT(isNormalizedNumber(norm.numbers)); + LUAU_ASSERT(isNormalizedString(norm.strings)); + LUAU_ASSERT(isNormalizedThread(norm.threads)); + LUAU_ASSERT(areNormalizedFunctions(norm.functions)); + LUAU_ASSERT(areNormalizedTables(norm.tables)); + LUAU_ASSERT(isNormalizedTyvar(norm.tyvars)); + for (auto& [_, child] : norm.tyvars) + assertInvariant(*child); + #endif +} + +Normalizer::Normalizer(TypeArena* arena, NotNull singletonTypes, NotNull sharedState) + : arena(arena) + , singletonTypes(singletonTypes) + , sharedState(sharedState) +{ +} + +const NormalizedType* Normalizer::normalize(TypeId ty) +{ + if (!arena) + sharedState->iceHandler->ice("Normalizing types outside a module"); + + auto found = cachedNormals.find(ty); + if (found != cachedNormals.end()) + return found->second.get(); + + NormalizedType norm{singletonTypes}; + if (!unionNormalWithTy(norm, ty)) + return nullptr; + std::unique_ptr uniq = std::make_unique(std::move(norm)); + const NormalizedType* result = uniq.get(); + cachedNormals[ty] = std::move(uniq); + return result; +} + +void Normalizer::clearNormal(NormalizedType& norm) +{ + norm.tops = singletonTypes->neverType; + norm.booleans = singletonTypes->neverType; + norm.classes.clear(); + norm.errors = singletonTypes->neverType; + norm.nils = singletonTypes->neverType; + norm.numbers = singletonTypes->neverType; + if (norm.strings) + norm.strings->clear(); + else + norm.strings.emplace(); + norm.threads = singletonTypes->neverType; + norm.tables.clear(); + norm.functions = std::nullopt; + norm.tyvars.clear(); +} + +// ------- Cached TypeIds +const TypeIds* Normalizer::cacheTypeIds(TypeIds tys) +{ + auto found = cachedTypeIds.find(&tys); + if (found != cachedTypeIds.end()) + return found->first; + + std::unique_ptr uniq = std::make_unique(std::move(tys)); + const TypeIds* result = uniq.get(); + cachedTypeIds[result] = std::move(uniq); + return result; +} + +TypeId Normalizer::unionType(TypeId here, TypeId there) +{ + here = follow(here); + there = follow(there); + + if (here == there) + return here; + if (get(here) || get(there)) + return there; + if (get(there) || get(here)) + return here; + + TypeIds tmps; + + if (const UnionTypeVar* utv = get(here)) + { + TypeIds heres; + heres.insert(begin(utv), end(utv)); + tmps.insert(heres.begin(), heres.end()); + cachedUnions[cacheTypeIds(std::move(heres))] = here; + } + else + tmps.insert(here); + + if (const UnionTypeVar* utv = get(there)) + { + TypeIds theres; + theres.insert(begin(utv), end(utv)); + tmps.insert(theres.begin(), theres.end()); + cachedUnions[cacheTypeIds(std::move(theres))] = there; + } + else + tmps.insert(there); + + auto cacheHit = cachedUnions.find(&tmps); + if (cacheHit != cachedUnions.end()) + return cacheHit->second; + + std::vector parts; + parts.insert(parts.end(), tmps.begin(), tmps.end()); + TypeId result = arena->addType(UnionTypeVar{std::move(parts)}); + cachedUnions[cacheTypeIds(std::move(tmps))] = result; + + return result; +} + +TypeId Normalizer::intersectionType(TypeId here, TypeId there) +{ + here = follow(here); + there = follow(there); + + if (here == there) + return here; + if (get(here) || get(there)) + return here; + if (get(there) || get(here)) + return there; + + TypeIds tmps; + + if (const IntersectionTypeVar* utv = get(here)) + { + TypeIds heres; + heres.insert(begin(utv), end(utv)); + tmps.insert(heres.begin(), heres.end()); + cachedIntersections[cacheTypeIds(std::move(heres))] = here; + } + else + tmps.insert(here); + + if (const IntersectionTypeVar* utv = get(there)) + { + TypeIds theres; + theres.insert(begin(utv), end(utv)); + tmps.insert(theres.begin(), theres.end()); + cachedIntersections[cacheTypeIds(std::move(theres))] = there; + } + else + tmps.insert(there); + + if (tmps.size() == 1) + return *tmps.begin(); + + auto cacheHit = cachedIntersections.find(&tmps); + if (cacheHit != cachedIntersections.end()) + return cacheHit->second; + + std::vector parts; + parts.insert(parts.end(), tmps.begin(), tmps.end()); + TypeId result = arena->addType(IntersectionTypeVar{std::move(parts)}); + cachedIntersections[cacheTypeIds(std::move(tmps))] = result; + + return result; +} + +void Normalizer::clearCaches() +{ + cachedNormals.clear(); + cachedIntersections.clear(); + cachedUnions.clear(); + cachedTypeIds.clear(); +} + +// ------- Normalizing unions +TypeId Normalizer::unionOfTops(TypeId here, TypeId there) +{ + if (get(here) || get(there)) + return there; + else + return here; +} + +TypeId Normalizer::unionOfBools(TypeId here, TypeId there) +{ + if (get(here)) + return there; + if (get(there)) + return here; + if (const BooleanSingleton* hbool = get(get(here))) + if (const BooleanSingleton* tbool = get(get(there))) + if (hbool->value == tbool->value) + return here; + return singletonTypes->booleanType; +} + +void Normalizer::unionClassesWithClass(TypeIds& heres, TypeId there) +{ + if (heres.count(there)) + return; + + const ClassTypeVar* tctv = get(there); + + for (auto it = heres.begin(); it != heres.end();) + { + TypeId here = *it; + const ClassTypeVar* hctv = get(here); + if (isSubclass(tctv, hctv)) + return; + else if (isSubclass(hctv, tctv)) + it = heres.erase(it); + else + it++; + } + + heres.insert(there); +} + +void Normalizer::unionClasses(TypeIds& heres, const TypeIds& theres) +{ + for (TypeId there : theres) + unionClassesWithClass(heres, there); +} + +void Normalizer::unionStrings(NormalizedStringType& here, const NormalizedStringType& there) +{ + if (!there) + here.reset(); + else if (here) + here->insert(there->begin(), there->end()); +} + +std::optional Normalizer::unionOfTypePacks(TypePackId here, TypePackId there) +{ + if (here == there) + return here; + + std::vector head; + std::optional tail; + + bool hereSubThere = true; + bool thereSubHere = true; + + TypePackIterator ith = begin(here); + TypePackIterator itt = begin(there); + + while (ith != end(here) && itt != end(there)) + { + TypeId hty = *ith; + TypeId tty = *itt; + TypeId ty = unionType(hty, tty); + if (ty != hty) + thereSubHere = false; + if (ty != tty) + hereSubThere = false; + head.push_back(ty); + ith++; + itt++; + } + + auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) + { + if (ith != end(here)) + { + TypeId tty = singletonTypes->nilType; + if (std::optional ttail = itt.tail()) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + tty = tvtp->ty; + else + // Luau doesn't have unions of type pack variables + return false; + } + else + // Type packs of different arities are incomparable + return false; + + while (ith != end(here)) + { + TypeId hty = *ith; + TypeId ty = unionType(hty, tty); + if (ty != hty) + thereSubHere = false; + if (ty != tty) + hereSubThere = false; + head.push_back(ty); + ith++; + } + } + return true; + }; + + if (!dealWithDifferentArities(ith, itt, here, there, hereSubThere, thereSubHere)) + return std::nullopt; + + if (!dealWithDifferentArities(itt, ith, there, here, thereSubHere, hereSubThere)) + return std::nullopt; + + if (std::optional htail = ith.tail()) + { + if (std::optional ttail = itt.tail()) + { + if (*htail == *ttail) + tail = htail; + else if (const VariadicTypePack* hvtp = get(*htail)) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + { + TypeId ty = unionType(hvtp->ty, tvtp->ty); + if (ty != hvtp->ty) + thereSubHere = false; + if (ty != tvtp->ty) + hereSubThere = false; + bool hidden = hvtp->hidden & tvtp->hidden; + tail = arena->addTypePack(VariadicTypePack{ty,hidden}); + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (get(*htail)) + { + hereSubThere = false; + tail = htail; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (std::optional ttail = itt.tail()) + { + if (get(*ttail)) + { + thereSubHere = false; + tail = htail; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + + if (hereSubThere) + return there; + else if (thereSubHere) + return here; + if (!head.empty()) + return arena->addTypePack(TypePack{head,tail}); + else if (tail) + return *tail; + else + // TODO: Add an emptyPack to singleton types + return arena->addTypePack({}); +} + +std::optional Normalizer::unionOfFunctions(TypeId here, TypeId there) +{ + if (get(here)) + return here; + + if (get(there)) + return there; + + const FunctionTypeVar* hftv = get(here); + LUAU_ASSERT(hftv); + const FunctionTypeVar* tftv = get(there); + LUAU_ASSERT(tftv); + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + + std::optional argTypes = intersectionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypes) + return std::nullopt; + + std::optional retTypes = unionOfTypePacks(hftv->retTypes, tftv->retTypes); + if (!retTypes) + return std::nullopt; + + if (*argTypes == hftv->argTypes && *retTypes == hftv->retTypes) + return here; + if (*argTypes == tftv->argTypes && *retTypes == tftv->retTypes) + return there; + + FunctionTypeVar result{*argTypes, *retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +void Normalizer::unionFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) +{ + if (!theres) + return; + + TypeIds tmps; + + if (!heres) + { + tmps.insert(theres->begin(), theres->end()); + heres = std::move(tmps); + return; + } + + for (TypeId here : *heres) + for (TypeId there : *theres) + { + if (std::optional fun = unionOfFunctions(here, there)) + tmps.insert(*fun); + else + tmps.insert(singletonTypes->errorRecoveryType(there)); + } + + heres = std::move(tmps); +} + +void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) +{ + if (!heres) + { + TypeIds tmps; + tmps.insert(there); + heres = std::move(tmps); + return; + } + + TypeIds tmps; + for (TypeId here : *heres) + { + if (std::optional fun = unionOfFunctions(here, there)) + tmps.insert(*fun); + else + tmps.insert(singletonTypes->errorRecoveryType(there)); + } + heres = std::move(tmps); +} + +void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) +{ + // TODO: remove unions of tables where possible + heres.insert(there); +} + +void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) +{ + for (TypeId there : theres) + unionTablesWithTable(heres, there); +} + +// So why `ignoreSmallerTyvars`? +// +// First up, what it does... Every tyvar has an index, and this parameter says to ignore +// any tyvars in `there` if their index is less than or equal to the parameter. +// The parameter is always greater than any tyvars mentioned in here, so the result is +// a lower bound on any tyvars in `here.tyvars`. +// +// This is used to maintain in invariant, which is that in any tyvar `X&T`, any any tyvar +// `Y&U` in `T`, the index of `X` is less than the index of `Y`. This is an implementation +// of *ordered decision diagrams* (https://en.wikipedia.org/wiki/Binary_decision_diagram#Variable_ordering) +// which are a compression technique used to save memory usage when representing boolean formulae. +// +// The idea is that if you have an out-of-order decision diagram +// like `Z&(X|Y)`, to re-order it in this case to `(X&Z)|(Y&Z)`. +// The hope is that by imposing a global order, there's a higher chance of sharing opportunities, +// and hence reduced memory. +// +// And yes, this is essentially a SAT solver hidden inside a typechecker. +// That's what you get for having a type system with generics, intersection and union types. +bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +{ + TypeId tops = unionOfTops(here.tops, there.tops); + if (!get(tops)) + { + clearNormal(here); + here.tops = tops; + return true; + } + + for (auto it = there.tyvars.begin(); it != there.tyvars.end(); it++) + { + TypeId tyvar = it->first; + const NormalizedType& inter = *it->second; + int index = tyvarIndex(tyvar); + if (index <= ignoreSmallerTyvars) + continue; + auto [emplaced, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{singletonTypes})); + if (fresh) + if (!unionNormals(*emplaced->second, here, index)) + return false; + if (!unionNormals(*emplaced->second, inter, index)) + return false; + } + + here.booleans = unionOfBools(here.booleans, there.booleans); + unionClasses(here.classes, there.classes); + here.errors = (get(there.errors) ? here.errors : there.errors); + here.nils = (get(there.nils) ? here.nils : there.nils); + here.numbers = (get(there.numbers) ? here.numbers : there.numbers); + unionStrings(here.strings, there.strings); + here.threads = (get(there.threads) ? here.threads : there.threads); + unionFunctions(here.functions, there.functions); + unionTables(here.tables, there.tables); + return true; +} + +bool Normalizer::withinResourceLimits() +{ + // If cache is too large, clear it + if (FInt::LuauNormalizeCacheLimit > 0) + { + size_t cacheUsage = cachedNormals.size() + cachedIntersections.size() + cachedUnions.size() + cachedTypeIds.size(); + if (cacheUsage > size_t(FInt::LuauNormalizeCacheLimit)) + { + clearCaches(); + return false; + } + } + + // Check the recursion count + if (sharedState->counters.recursionLimit > 0) + if (sharedState->counters.recursionLimit < sharedState->counters.recursionCount) + return false; + + return true; +} + +// See above for an explaination of `ignoreSmallerTyvars`. +bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignoreSmallerTyvars) +{ + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return false; + + there = follow(there); + if (get(there) || get(there)) + { + TypeId tops = unionOfTops(here.tops, there); + clearNormal(here); + here.tops = tops; + return true; + } + else if (get(there) || !get(here.tops)) + return true; + else if (const UnionTypeVar* utv = get(there)) + { + for (UnionTypeVarIterator it = begin(utv); it != end(utv); ++it) + if (!unionNormalWithTy(here, *it)) + return false; + return true; + } + else if (const IntersectionTypeVar* itv = get(there)) + { + NormalizedType norm{singletonTypes}; + norm.tops = singletonTypes->anyType; + for (IntersectionTypeVarIterator it = begin(itv); it != end(itv); ++it) + if (!intersectNormalWithTy(norm, *it)) + return false; + return unionNormals(here, norm); + } + else if (get(there) || get(there)) + { + if (tyvarIndex(there) <= ignoreSmallerTyvars) + return true; + NormalizedType inter{singletonTypes}; + inter.tops = singletonTypes->unknownType; + here.tyvars.insert_or_assign(there, std::make_unique(std::move(inter))); + } + else if (get(there)) + unionFunctionsWithFunction(here.functions, there); + else if (get(there) || get(there)) + unionTablesWithTable(here.tables, there); + else if (get(there)) + unionClassesWithClass(here.classes, there); + else if (get(there)) + here.errors = there; + else if (const PrimitiveTypeVar* ptv = get(there)) + { + if (ptv->type == PrimitiveTypeVar::Boolean) + here.booleans = there; + else if (ptv->type == PrimitiveTypeVar::NilType) + here.nils = there; + else if (ptv->type == PrimitiveTypeVar::Number) + here.numbers = there; + else if (ptv->type == PrimitiveTypeVar::String) + here.strings = std::nullopt; + else if (ptv->type == PrimitiveTypeVar::Thread) + here.threads = there; + else + LUAU_ASSERT(!"Unreachable"); + } + else if (const SingletonTypeVar* stv = get(there)) + { + if (get(stv)) + here.booleans = unionOfBools(here.booleans, there); + else if (const StringSingleton* sstv = get(stv)) + { + if (here.strings) + here.strings->insert({sstv->value, there}); + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + LUAU_ASSERT(!"Unreachable"); + + for (auto& [tyvar, intersect] : here.tyvars) + if (!unionNormalWithTy(*intersect, there, tyvarIndex(tyvar))) + return false; + + assertInvariant(here); + return true; +} + +// ------- Normalizing intersections +TypeId Normalizer::intersectionOfTops(TypeId here, TypeId there) +{ + if (get(here) || get(there)) + return here; + else + return there; +} + +TypeId Normalizer::intersectionOfBools(TypeId here, TypeId there) +{ + if (get(here)) + return here; + if (get(there)) + return there; + if (const BooleanSingleton* hbool = get(get(here))) + if (const BooleanSingleton* tbool = get(get(there))) + return (hbool->value == tbool->value ? here : singletonTypes->neverType); + else + return here; + else + return there; +} + +void Normalizer::intersectClasses(TypeIds& heres, const TypeIds& theres) +{ + TypeIds tmp; + for (auto it = heres.begin(); it != heres.end();) + { + const ClassTypeVar* hctv = get(*it); + LUAU_ASSERT(hctv); + bool keep = false; + for (TypeId there : theres) + { + const ClassTypeVar* tctv = get(there); + LUAU_ASSERT(tctv); + if (isSubclass(hctv, tctv)) + { + keep = true; + break; + } + else if (isSubclass(tctv, hctv)) + { + keep = false; + tmp.insert(there); + break; + } + } + if (keep) + it++; + else + it = heres.erase(it); + } + heres.insert(tmp.begin(), tmp.end()); +} + +void Normalizer::intersectClassesWithClass(TypeIds& heres, TypeId there) +{ + bool foundSuper = false; + const ClassTypeVar* tctv = get(there); + LUAU_ASSERT(tctv); + for (auto it = heres.begin(); it != heres.end();) + { + const ClassTypeVar* hctv = get(*it); + LUAU_ASSERT(hctv); + if (isSubclass(hctv, tctv)) + it++; + else if (isSubclass(tctv, hctv)) + { + foundSuper = true; + break; + } + else + it = heres.erase(it); + } + if (foundSuper) + { + heres.clear(); + heres.insert(there); + } +} + +void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) +{ + if (!there) + return; + if (!here) + here.emplace(); + + for (auto it = here->begin(); it != here->end();) + { + if (there->count(it->first)) + it++; + else + it = here->erase(it); + } +} + +std::optional Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there) +{ + if (here == there) + return here; + + std::vector head; + std::optional tail; + + bool hereSubThere = true; + bool thereSubHere = true; + + TypePackIterator ith = begin(here); + TypePackIterator itt = begin(there); + + while (ith != end(here) && itt != end(there)) + { + TypeId hty = *ith; + TypeId tty = *itt; + TypeId ty = intersectionType(hty, tty); + if (ty != hty) + hereSubThere = false; + if (ty != tty) + thereSubHere = false; + head.push_back(ty); + ith++; + itt++; + } + + auto dealWithDifferentArities = [&](TypePackIterator& ith, TypePackIterator itt, TypePackId here, TypePackId there, bool& hereSubThere, bool& thereSubHere) + { + if (ith != end(here)) + { + TypeId tty = singletonTypes->nilType; + if (std::optional ttail = itt.tail()) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + tty = tvtp->ty; + else + // Luau doesn't have intersections of type pack variables + return false; + } + else + // Type packs of different arities are incomparable + return false; + + while (ith != end(here)) + { + TypeId hty = *ith; + TypeId ty = intersectionType(hty, tty); + if (ty != hty) + hereSubThere = false; + if (ty != tty) + thereSubHere = false; + head.push_back(ty); + ith++; + } + } + return true; + }; + + if (!dealWithDifferentArities(ith, itt, here, there, hereSubThere, thereSubHere)) + return std::nullopt; + + if (!dealWithDifferentArities(itt, ith, there, here, thereSubHere, hereSubThere)) + return std::nullopt; + + if (std::optional htail = ith.tail()) + { + if (std::optional ttail = itt.tail()) + { + if (*htail == *ttail) + tail = htail; + else if (const VariadicTypePack* hvtp = get(*htail)) + { + if (const VariadicTypePack* tvtp = get(*ttail)) + { + TypeId ty = intersectionType(hvtp->ty, tvtp->ty); + if (ty != hvtp->ty) + thereSubHere = false; + if (ty != tvtp->ty) + hereSubThere = false; + bool hidden = hvtp->hidden & tvtp->hidden; + tail = arena->addTypePack(VariadicTypePack{ty,hidden}); + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (get(*htail)) + hereSubThere = false; + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + else if (std::optional ttail = itt.tail()) + { + if (get(*ttail)) + thereSubHere = false; + else + // Luau doesn't have unions of type pack variables + return std::nullopt; + } + + if (hereSubThere) + return here; + else if (thereSubHere) + return there; + if (!head.empty()) + return arena->addTypePack(TypePack{head,tail}); + else if (tail) + return *tail; + else + // TODO: Add an emptyPack to singleton types + return arena->addTypePack({}); +} + +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there) +{ + if (here == there) + return here; + + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (sharedState->counters.recursionLimit > 0 && sharedState->counters.recursionLimit < sharedState->counters.recursionCount) + return std::nullopt; + + TypeId htable = here; + TypeId hmtable = nullptr; + if (const MetatableTypeVar* hmtv = get(here)) + { + htable = hmtv->table; + hmtable = hmtv->metatable; + } + TypeId ttable = there; + TypeId tmtable = nullptr; + if (const MetatableTypeVar* tmtv = get(there)) + { + ttable = tmtv->table; + tmtable = tmtv->metatable; + } + + const TableTypeVar* httv = get(htable); + LUAU_ASSERT(httv); + const TableTypeVar* tttv = get(ttable); + LUAU_ASSERT(tttv); + + if (httv->state == TableState::Free || tttv->state == TableState::Free) + return std::nullopt; + if (httv->state == TableState::Generic || tttv->state == TableState::Generic) + return std::nullopt; + + TableState state = httv->state; + if (tttv->state == TableState::Unsealed) + state = tttv->state; + + TypeLevel level = max(httv->level, tttv->level); + TableTypeVar result{state, level}; + + bool hereSubThere = true; + bool thereSubHere = true; + + for (const auto& [name, hprop] : httv->props) + { + Property prop = hprop; + auto tfound = tttv->props.find(name); + if (tfound == tttv->props.end()) + thereSubHere = false; + else + { + const auto& [_name, tprop] = *tfound; + // TODO: variance issues here, which can't be fixed until we have read/write property types + prop.type = intersectionType(hprop.type, tprop.type); + hereSubThere &= (prop.type == hprop.type); + thereSubHere &= (prop.type == tprop.type); + } + // TODO: string indexers + result.props[name] = prop; + } + + for (const auto& [name, tprop] : tttv->props) + { + if (httv->props.count(name) == 0) + { + result.props[name] = tprop; + hereSubThere = false; + } + } + + if (httv->indexer && tttv->indexer) + { + // TODO: What should intersection of indexes be? + TypeId index = unionType(httv->indexer->indexType, tttv->indexer->indexType); + TypeId indexResult = intersectionType(httv->indexer->indexResultType, tttv->indexer->indexResultType); + result.indexer = {index, indexResult}; + hereSubThere &= (httv->indexer->indexType == index) && (httv->indexer->indexResultType == indexResult); + thereSubHere &= (tttv->indexer->indexType == index) && (tttv->indexer->indexResultType == indexResult); + } + else if (httv->indexer) + { + result.indexer = httv->indexer; + thereSubHere = false; + } + else if (tttv->indexer) + { + result.indexer = tttv->indexer; + hereSubThere = false; + } + + TypeId table; + if (hereSubThere) + table = htable; + else if (thereSubHere) + table = ttable; + else + table = arena->addType(std::move(result)); + + if (tmtable && hmtable) + { + // NOTE: this assumes metatables are ivariant + if (std::optional mtable = intersectionOfTables(hmtable, tmtable)) + { + if (table == htable && *mtable == hmtable) + return here; + else if (table == ttable && *mtable == tmtable) + return there; + else + return arena->addType(MetatableTypeVar{table, *mtable}); + } + else + return std::nullopt; + + } + else if (hmtable) + { + if (table == htable) + return here; + else + return arena->addType(MetatableTypeVar{table, hmtable}); + } + else if (tmtable) + { + if (table == ttable) + return there; + else + return arena->addType(MetatableTypeVar{table, tmtable}); + } + else + return table; +} + +void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there) +{ + TypeIds tmp; + for (TypeId here : heres) + if (std::optional inter = intersectionOfTables(here, there)) + tmp.insert(*inter); + heres.retain(tmp); + heres.insert(tmp.begin(), tmp.end()); +} + +void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) +{ + TypeIds tmp; + for (TypeId here : heres) + for (TypeId there : theres) + if (std::optional inter = intersectionOfTables(here, there)) + tmp.insert(*inter); + heres.retain(tmp); + heres.insert(tmp.begin(), tmp.end()); +} + +std::optional Normalizer::intersectionOfFunctions(TypeId here, TypeId there) +{ + const FunctionTypeVar* hftv = get(here); + LUAU_ASSERT(hftv); + const FunctionTypeVar* tftv = get(there); + LUAU_ASSERT(tftv); + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + if (hftv->retTypes != tftv->retTypes) + return std::nullopt; + + std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypes) + return std::nullopt; + + if (*argTypes == hftv->argTypes) + return here; + if (*argTypes == tftv->argTypes) + return there; + + FunctionTypeVar result{*argTypes, hftv->retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +std::optional Normalizer::unionSaturatedFunctions(TypeId here, TypeId there) +{ + // Deep breath... + // + // When we come to check overloaded functions for subtyping, + // we have to compare (F1 & ... & FM) <: (G1 & ... G GN) + // where each Fi or Gj is a function type. Now that intersection on the right is no + // problem, since that's true if and only if (F1 & ... & FM) <: Gj for every j. + // But the intersection on the left is annoying, since we might have + // (F1 & ... & FM) <: G but no Fi <: G. For example + // + // ((number? -> number?) & (string? -> string?)) <: (nil -> nil) + // + // So in this case, what we do is define Apply for the result of applying + // a function of type F to an argument of type T, and then F <: (T -> U) + // if and only if Apply <: U. For example: + // + // if f : ((number? -> number?) & (string? -> string?)) + // then f(nil) must be nil, so + // Apply<((number? -> number?) & (string? -> string?)), nil> is nil + // + // So subtyping on overloaded functions "just" boils down to defining Apply. + // + // Now for non-overloaded functions, this is easy! + // Apply<(R -> S), T> is S if T <: R, and an error type otherwise. + // + // But for overloaded functions it's not so simple. We'd like Apply + // to just be Apply & ... & Apply but oh dear + // + // if f : ((number -> number) & (string -> string)) + // and x : (number | string) + // then f(x) : (number | string) + // + // so we want + // + // Apply<((number -> number) & (string -> string)), (number | string)> is (number | string) + // + // but + // + // Apply<(number -> number), (number | string)> is an error + // Apply<(string -> string), (number | string)> is an error + // + // that is Apply should consider all possible combinations of overloads of F, + // not just individual overloads. + // + // For this reason, when we're normalizing function types (in order to check subtyping + // or perform overload resolution) we should first *union-saturate* them. An overloaded + // function is union-saturated whenever: + // + // if (R -> S) is an overload of F + // and (T -> U) is an overload of F + // then ((R | T) -> (S | U)) is a subtype of an overload of F + // + // Any overloaded function can be normalized to a union-saturated one by adding enough extra overloads. + // For example, union-saturating + // + // ((number -> number) & (string -> string)) + // + // is + // + // ((number -> number) & (string -> string) & ((number | string) -> (number | string))) + // + // For union-saturated overloaded functions, the "obvious" algorithm works: + // + // Apply is Apply & ... & Apply + // + // so we can define Apply, so we can perform overloaded function resolution + // and check subtyping on overloaded function types, yay! + // + // This is yet another potential source of exponential blow-up, sigh, since + // the union-saturation of a function with N overloads may have 2^N overloads + // (one for every subset). In practice, that hopefully won't happen that often, + // in particular we only union-saturate overloads with different return types, + // and there are hopefully not very many cases of that. + // + // All of this is mechanically verified in Agda, at https://github.com/luau-lang/agda-typeck + // + // It is essentially the algorithm defined in https://pnwamk.github.io/sst-tutorial/ + // except that we're precomputing the union-saturation rather than converting + // to disjunctive normal form on the fly. + // + // This is all built on semantic subtyping: + // + // Covariance and Contravariance, Giuseppe Castagna, + // Logical Methods in Computer Science 16(1), 2022 + // https://arxiv.org/abs/1809.01427 + // + // A gentle introduction to semantic subtyping, Giuseppe Castagna and Alain Frisch, + // Proc. Principles and practice of declarative programming 2005, pp 198–208 + // https://doi.org/10.1145/1069774.1069793 + + const FunctionTypeVar* hftv = get(here); + if (!hftv) + return std::nullopt; + const FunctionTypeVar* tftv = get(there); + if (!tftv) + return std::nullopt; + + if (hftv->generics != tftv->generics) + return std::nullopt; + if (hftv->genericPacks != tftv->genericPacks) + return std::nullopt; + + std::optional argTypes = unionOfTypePacks(hftv->argTypes, tftv->argTypes); + if (!argTypes) + return std::nullopt; + std::optional retTypes = unionOfTypePacks(hftv->retTypes, tftv->retTypes); + if (!retTypes) + return std::nullopt; + + FunctionTypeVar result{*argTypes, *retTypes}; + result.generics = hftv->generics; + result.genericPacks = hftv->genericPacks; + return arena->addType(std::move(result)); +} + +void Normalizer::intersectFunctionsWithFunction(NormalizedFunctionType& heres, TypeId there) +{ + if (!heres) + return; + + for (auto it = heres->begin(); it != heres->end();) + { + TypeId here = *it; + if (get(here)) + it++; + else if (std::optional tmp = intersectionOfFunctions(here, there)) + { + heres->erase(it); + heres->insert(*tmp); + return; + } + else + it++; + } + + TypeIds tmps; + for (TypeId here : *heres) + { + if (std::optional tmp = unionSaturatedFunctions(here, there)) + tmps.insert(*tmp); + } + heres->insert(there); + heres->insert(tmps.begin(), tmps.end()); +} + +void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const NormalizedFunctionType& theres) +{ + if (!heres) + return; + else if (!theres) + { + heres = std::nullopt; + return; + } + else + { + for (TypeId there : *theres) + intersectFunctionsWithFunction(heres, there); + } +} + +bool Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there) +{ + for (auto it = here.begin(); it != here.end();) + { + NormalizedType& inter = *it->second; + if (!intersectNormalWithTy(inter, there)) + return false; + if (isInhabited(inter)) + ++it; + else + it = here.erase(it); + } + return true; +} + +// See above for an explaination of `ignoreSmallerTyvars`. +bool Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) +{ + if (!get(there.tops)) + { + here.tops = intersectionOfTops(here.tops, there.tops); + return true; + } + else if (!get(here.tops)) + { + clearNormal(here); + return unionNormals(here, there, ignoreSmallerTyvars); + } + + here.booleans = intersectionOfBools(here.booleans, there.booleans); + intersectClasses(here.classes, there.classes); + here.errors = (get(there.errors) ? there.errors : here.errors); + here.nils = (get(there.nils) ? there.nils : here.nils); + here.numbers = (get(there.numbers) ? there.numbers : here.numbers); + intersectStrings(here.strings, there.strings); + here.threads = (get(there.threads) ? there.threads : here.threads); + intersectFunctions(here.functions, there.functions); + intersectTables(here.tables, there.tables); + + for (auto& [tyvar, inter] : there.tyvars) + { + int index = tyvarIndex(tyvar); + if (ignoreSmallerTyvars < index) + { + auto [found, fresh] = here.tyvars.emplace(tyvar, std::make_unique(NormalizedType{singletonTypes})); + if (fresh) + { + if (!unionNormals(*found->second, here, index)) + return false; + } + } + } + for (auto it = here.tyvars.begin(); it != here.tyvars.end();) + { + TypeId tyvar = it->first; + NormalizedType& inter = *it->second; + int index = tyvarIndex(tyvar); + LUAU_ASSERT(ignoreSmallerTyvars < index); + auto found = there.tyvars.find(tyvar); + if (found == there.tyvars.end()) + { + if (!intersectNormals(inter, there, index)) + return false; + } + else + { + if (!intersectNormals(inter, *found->second, index)) + return false; + } + if (isInhabited(inter)) + it++; + else + it = here.tyvars.erase(it); + } + return true; +} + +bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) +{ + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return false; + + there = follow(there); + if (get(there) || get(there)) + { + here.tops = intersectionOfTops(here.tops, there); + return true; + } + else if (!get(here.tops)) + { + clearNormal(here); + return unionNormalWithTy(here, there); + } + else if (const UnionTypeVar* utv = get(there)) + { + NormalizedType norm{singletonTypes}; + for (UnionTypeVarIterator it = begin(utv); it != end(utv); ++it) + if (!unionNormalWithTy(norm, *it)) + return false; + return intersectNormals(here, norm); + } + else if (const IntersectionTypeVar* itv = get(there)) + { + for (IntersectionTypeVarIterator it = begin(itv); it != end(itv); ++it) + if (!intersectNormalWithTy(here, *it)) + return false; + return true; + } + else if (get(there) || get(there)) + { + NormalizedType thereNorm{singletonTypes}; + NormalizedType topNorm{singletonTypes}; + topNorm.tops = singletonTypes->unknownType; + thereNorm.tyvars.insert_or_assign(there, std::make_unique(std::move(topNorm))); + return intersectNormals(here, thereNorm); + } + + NormalizedTyvars tyvars = std::move(here.tyvars); + + if (const FunctionTypeVar* utv = get(there)) + { + NormalizedFunctionType functions = std::move(here.functions); + clearNormal(here); + intersectFunctionsWithFunction(functions, there); + here.functions = std::move(functions); + } + else if (get(there) || get(there)) + { + TypeIds tables = std::move(here.tables); + clearNormal(here); + intersectTablesWithTable(tables, there); + here.tables = std::move(tables); + } + else if (get(there)) + { + TypeIds classes = std::move(here.classes); + clearNormal(here); + intersectClassesWithClass(classes, there); + here.classes = std::move(classes); + } + else if (get(there)) + { + TypeId errors = here.errors; + clearNormal(here); + here.errors = errors; + } + else if (const PrimitiveTypeVar* ptv = get(there)) + { + TypeId booleans = here.booleans; + TypeId nils = here.nils; + TypeId numbers = here.numbers; + NormalizedStringType strings = std::move(here.strings); + TypeId threads = here.threads; + + clearNormal(here); + + if (ptv->type == PrimitiveTypeVar::Boolean) + here.booleans = booleans; + else if (ptv->type == PrimitiveTypeVar::NilType) + here.nils = nils; + else if (ptv->type == PrimitiveTypeVar::Number) + here.numbers = numbers; + else if (ptv->type == PrimitiveTypeVar::String) + here.strings = std::move(strings); + else if (ptv->type == PrimitiveTypeVar::Thread) + here.threads = threads; + else + LUAU_ASSERT(!"Unreachable"); + } + else if (const SingletonTypeVar* stv = get(there)) + { + TypeId booleans = here.booleans; + NormalizedStringType strings = std::move(here.strings); + + clearNormal(here); + + if (get(stv)) + here.booleans = intersectionOfBools(booleans, there); + else if (const StringSingleton* sstv = get(stv)) + { + if (!strings || strings->count(sstv->value)) + here.strings->insert({sstv->value, there}); + } + else + LUAU_ASSERT(!"Unreachable"); + } + else + LUAU_ASSERT(!"Unreachable"); + + if (!intersectTyvarsWithTy(tyvars, there)) + return false; + here.tyvars = std::move(tyvars); + + return true; +} + +// -------- Convert back from a normalized type to a type +TypeId Normalizer::typeFromNormal(const NormalizedType& norm) +{ + assertInvariant(norm); + if (!get(norm.tops)) + return norm.tops; + + std::vector result; + + if (!get(norm.booleans)) + result.push_back(norm.booleans); + result.insert(result.end(), norm.classes.begin(), norm.classes.end()); + if (!get(norm.errors)) + result.push_back(norm.errors); + if (norm.functions) + { + if (norm.functions->size() == 1) + result.push_back(*norm.functions->begin()); + else + { + std::vector parts; + parts.insert(parts.end(), norm.functions->begin(), norm.functions->end()); + result.push_back(arena->addType(IntersectionTypeVar{std::move(parts)})); + } + } + if (!get(norm.nils)) + result.push_back(norm.nils); + if (!get(norm.numbers)) + result.push_back(norm.numbers); + if (norm.strings) + for (auto& [_, ty] : *norm.strings) + result.push_back(ty); + else + result.push_back(singletonTypes->stringType); + result.insert(result.end(), norm.tables.begin(), norm.tables.end()); + for (auto& [tyvar, intersect] : norm.tyvars) + { + if (get(intersect->tops)) + { + TypeId ty = typeFromNormal(*intersect); + result.push_back(arena->addType(IntersectionTypeVar{{tyvar, ty}})); + } + else + result.push_back(tyvar); + } + + if (result.size() == 0) + return singletonTypes->neverType; + else if (result.size() == 1) + return result[0]; + else + return arena->addType(UnionTypeVar{std::move(result)}); +} namespace { @@ -59,7 +1748,8 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, N { UnifierSharedState sharedState{&ice}; TypeArena arena; - Unifier u{&arena, singletonTypes, Mode::Strict, scope, Location{}, Covariant, sharedState}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; u.anyIsTop = anyIsTop; u.tryUnify(subPack, superPack); @@ -686,3 +2377,4 @@ std::pair normalize(TypePackId tp, const ModulePtr& module, No } } // namespace Luau + diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index f98a212..b2f3cfd 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -280,7 +280,8 @@ struct TypeChecker2 TypePackId actualRetType = reconstructPack(ret->list, arena); UnifierSharedState sharedState{&ice}; - Unifier u{&arena, singletonTypes, Mode::Strict, stack.back(), ret->location, Covariant, sharedState}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, stack.back(), ret->location, Covariant}; u.anyIsTop = true; u.tryUnify(actualRetType, expectedRetType); @@ -1206,7 +1207,8 @@ struct TypeChecker2 ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) { UnifierSharedState sharedState{&ice}; - Unifier u{&module->internalTypes, singletonTypes, Mode::Strict, scope, location, Covariant, sharedState}; + Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; u.anyIsTop = true; u.tryUnify(subTy, superTy); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b96046b..cb21aa7 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,19 +32,21 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(LuauFunctionArgMismatchDetails, false) -LUAU_FASTFLAGVARIABLE(LuauInplaceDemoteSkipAllBound, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) -LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix3, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) +LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauCallUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) +LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) LUAU_FASTFLAGVARIABLE(LuauUnionOfTypesFollow, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) @@ -255,6 +257,7 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull singl , singletonTypes(singletonTypes) , iceHandler(iceHandler) , unifierState(iceHandler) + , normalizer(nullptr, singletonTypes, NotNull{&unifierState}) , nilType(singletonTypes->nilType) , numberType(singletonTypes->numberType) , stringType(singletonTypes->stringType) @@ -301,12 +304,13 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); - currentModule.reset(new Module()); + currentModule.reset(new Module); currentModule->type = module.type; currentModule->allocator = module.allocator; currentModule->names = module.names; iceHandler->moduleName = module.name; + normalizer.arena = ¤tModule->internalTypes; if (FFlag::LuauAutocompleteDynamicLimits) { @@ -351,15 +355,23 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else - { moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); - } + + if (FFlag::LuauAnyifyModuleReturnGenerics) + moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); prepareErrorsForDisplay(currentModule->errors); + if (FFlag::LuauTypeNormalization2) + { + // Clear the normalizer caches, since they contain types from the internal type surface + normalizer.clearCaches(); + normalizer.arena = nullptr; + } + currentModule->clonePublicInterface(singletonTypes, *iceHandler); // Clear unifier cache since it's keyed off internal types that get deallocated @@ -474,7 +486,7 @@ struct InplaceDemoter : TypeVarOnceVisitor TypeArena* arena; InplaceDemoter(TypeLevel level, TypeArena* arena) - : TypeVarOnceVisitor(/* skipBoundTypes= */ FFlag::LuauInplaceDemoteSkipAllBound) + : TypeVarOnceVisitor(/* skipBoundTypes= */ true) , newLevel(level) , arena(arena) { @@ -494,12 +506,6 @@ struct InplaceDemoter : TypeVarOnceVisitor return false; } - bool visit(TypeId ty, const BoundTypeVar& btyRef) override - { - LUAU_ASSERT(!FFlag::LuauInplaceDemoteSkipAllBound); - return true; - } - bool visit(TypeId ty) override { if (ty->owningArena != arena) @@ -1029,8 +1035,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (right) { - if (!maybeGeneric(left) && isGeneric(right)) - right = instantiate(scope, right, loc); + if (!FFlag::LuauInstantiateInSubtyping) + { + if (!maybeGeneric(left) && isGeneric(right)) + right = instantiate(scope, right, loc); + } // Setting a table entry to nil doesn't mean nil is the type of the indexer, it is just deleting the entry const TableTypeVar* destTableTypeReceivingNil = nullptr; @@ -1104,7 +1113,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) variableTypes.push_back(ty); expectedTypes.push_back(ty); - instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); + // with FFlag::LuauInstantiateInSubtyping enabled, we shouldn't need to produce instantiateGenerics at all. + if (!FFlag::LuauInstantiateInSubtyping) + instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); } if (local.values.size > 0) @@ -1729,9 +1740,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar { ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - - if (FFlag::LuauSelfCallAutocompleteFix3) - ftv->hasSelf = true; + ftv->hasSelf = true; } } @@ -1905,8 +1914,18 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (get(varargPack)) { - std::vector types = flatten(varargPack).first; - return {!types.empty() ? types[0] : nilType}; + if (FFlag::LuauFixVarargExprHeadType) + { + if (std::optional ty = first(varargPack)) + return {*ty}; + + return {nilType}; + } + else + { + std::vector types = flatten(varargPack).first; + return {!types.empty() ? types[0] : nilType}; + } } else if (get(varargPack)) { @@ -3967,7 +3986,10 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam } else { - unifyWithInstantiationIfNeeded(*argIter, *paramIter, scope, state); + if (FFlag::LuauInstantiateInSubtyping) + state.tryUnify(*argIter, *paramIter, /*isFunctionCall*/ false); + else + unifyWithInstantiationIfNeeded(*argIter, *paramIter, scope, state); ++argIter; ++paramIter; } @@ -4523,8 +4545,11 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; - if (instantiateGenerics.size() > i && instantiateGenerics[i]) - actualType = instantiate(scope, actualType, expr->location); + if (!FFlag::LuauInstantiateInSubtyping) + { + if (instantiateGenerics.size() > i && instantiateGenerics[i]) + actualType = instantiate(scope, actualType, expr->location); + } if (expectedType) { @@ -4686,6 +4711,8 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state) { + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + if (!maybeGeneric(subTy)) // Quick check to see if we definitely can't instantiate state.tryUnify(subTy, superTy, /*isFunctionCall*/ false); @@ -4828,6 +4855,33 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo } } +TypePackId TypeChecker::anyifyModuleReturnTypePackGenerics(TypePackId tp) +{ + tp = follow(tp); + + if (const VariadicTypePack* vtp = get(tp)) + return get(vtp->ty) ? anyTypePack : tp; + + if (!get(follow(tp))) + return tp; + + std::vector resultTypes; + std::optional resultTail; + + TypePackIterator it = begin(tp); + + for (TypePackIterator e = end(tp); it != e; ++it) + { + TypeId ty = follow(*it); + resultTypes.push_back(get(ty) ? anyType : ty); + } + + if (std::optional tail = it.tail()) + resultTail = anyifyModuleReturnTypePackGenerics(*tail); + + return addTypePack(resultTypes, resultTail); +} + void TypeChecker::reportError(const TypeError& error) { if (currentModule->mode == Mode::NoCheck) @@ -4955,8 +5009,7 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location) { - return Unifier{ - ¤tModule->internalTypes, singletonTypes, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant, unifierState}; + return Unifier{NotNull{&normalizer}, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant}; } TypeId TypeChecker::freshType(const ScopePtr& scope) diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index bf6bf34..b143268 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAGVARIABLE(LuauStringFormatArgumentErrorFix, false) LUAU_FASTFLAGVARIABLE(LuauNoMoreGlobalSingletonTypes, false) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) namespace Luau { @@ -339,6 +340,8 @@ bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub) // then instantiate U if `isGeneric(U)` is true, and `maybeGeneric(T)` is false. bool isGeneric(TypeId ty) { + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + ty = follow(ty); if (auto ftv = get(ty)) return ftv->generics.size() > 0 || ftv->genericPacks.size() > 0; @@ -350,6 +353,8 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { + LUAU_ASSERT(!FFlag::LuauInstantiateInSubtyping); + if (FFlag::LuauMaybeGenericIntersectionTypes) { ty = follow(ty); diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index a3d4540..e0cc141 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -1,60 +1,64 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Unifiable.h" +LUAU_FASTFLAG(LuauTypeNormalization2); + namespace Luau { namespace Unifiable { +static int nextIndex = 0; + Free::Free(TypeLevel level) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) { } Free::Free(Scope* scope) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , scope(scope) { } Free::Free(Scope* scope, TypeLevel level) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) , scope(scope) { } -int Free::nextIndex = 0; +int Free::DEPRECATED_nextIndex = 0; Generic::Generic() - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , name("g" + std::to_string(index)) { } Generic::Generic(TypeLevel level) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) , name("g" + std::to_string(index)) { } Generic::Generic(const Name& name) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , name(name) , explicitName(true) { } Generic::Generic(Scope* scope) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , scope(scope) { } Generic::Generic(TypeLevel level, const Name& name) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , level(level) , name(name) , explicitName(true) @@ -62,14 +66,14 @@ Generic::Generic(TypeLevel level, const Name& name) } Generic::Generic(Scope* scope, const Name& name) - : index(++nextIndex) + : index(FFlag::LuauTypeNormalization2 ? ++nextIndex : ++DEPRECATED_nextIndex) , scope(scope) , name(name) , explicitName(true) { } -int Generic::nextIndex = 0; +int Generic::DEPRECATED_nextIndex = 0; Error::Error() : index(++nextIndex) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index c13a6f8..5a01c93 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -2,6 +2,7 @@ #include "Luau/Unifier.h" #include "Luau/Common.h" +#include "Luau/Instantiation.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypePack.h" @@ -20,7 +21,9 @@ LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000); LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauCallUnifyPackTails) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -343,17 +346,19 @@ static bool subsumes(bool useScopes, TY_A* left, TY_B* right) return left->level.subsumes(right->level); } -Unifier::Unifier(TypeArena* types, NotNull singletonTypes, Mode mode, NotNull scope, const Location& location, - Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) - : types(types) - , singletonTypes(singletonTypes) +Unifier::Unifier(NotNull normalizer, Mode mode, NotNull scope, const Location& location, + Variance variance, TxnLog* parentLog) + : types(normalizer->arena) + , singletonTypes(normalizer->singletonTypes) + , normalizer(normalizer) , mode(mode) , scope(scope) , log(parentLog) , location(location) , variance(variance) - , sharedState(sharedState) + , sharedState(*normalizer->sharedState) { + normalize = FFlag::LuauSubtypeNormalizer; LUAU_ASSERT(sharedState.iceHandler); } @@ -524,7 +529,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyUnionWithType(subTy, subUnion, superTy); } - else if (const UnionTypeVar* uv = log.getMutable(superTy)) + else if (const UnionTypeVar* uv = (FFlag::LuauSubtypeNormalizer? nullptr: log.getMutable(superTy))) { tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } @@ -532,6 +537,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyTypeWithIntersection(subTy, superTy, uv); } + else if (const UnionTypeVar* uv = log.getMutable(superTy)) + { + tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); + } else if (const IntersectionTypeVar* uv = log.getMutable(subTy)) { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); @@ -585,7 +594,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* subUnion, TypeId superTy) { - // A | B <: T if A <: T and B <: T + // A | B <: T if and only if A <: T and B <: T bool failed = false; std::optional unificationTooComplex; std::optional firstFailedOption; @@ -715,6 +724,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; Unifier innerState = makeChildUnifier(); + innerState.normalize = false; innerState.tryUnify_(subTy, type, isFunctionCall); if (innerState.errors.empty()) @@ -741,6 +751,20 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { reportError(*unificationTooComplex); } + else if (!found && normalize) + { + // It is possible that T <: A | B even though T normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + reportError(TypeError{location, UnificationTooComplex{}}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + } else if (!found) { if ((failedOptionCount == 1 || foundHeuristic) && failedOption) @@ -755,7 +779,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I std::optional unificationTooComplex; std::optional firstFailedOption; - // T <: A & B if T <: A and T <: B + // T <: A & B if and only if T <: A and T <: B for (TypeId type : uv->parts) { Unifier innerState = makeChildUnifier(); @@ -806,6 +830,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); + innerState.normalize = false; innerState.tryUnify_(type, superTy, isFunctionCall); if (innerState.errors.empty()) @@ -822,12 +847,207 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (unificationTooComplex) reportError(*unificationTooComplex); + else if (!found && normalize) + { + // It is possible that A & B <: T even though A normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + else + reportError(TypeError{location, UnificationTooComplex{}}); + } else if (!found) { reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } +void Unifier::tryUnifyNormalizedTypes(TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) +{ + LUAU_ASSERT(FFlag::LuauSubtypeNormalizer); + + if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) + return; + else if (get(subNorm.tops)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (get(subNorm.errors)) + if (!get(superNorm.errors)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (get(subNorm.booleans)) + { + if (!get(superNorm.booleans)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + else if (const SingletonTypeVar* stv = get(subNorm.booleans)) + { + if (!get(superNorm.booleans) && stv != get(superNorm.booleans)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + + if (get(subNorm.nils)) + if (!get(superNorm.nils)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (get(subNorm.numbers)) + if (!get(superNorm.numbers)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (subNorm.strings && superNorm.strings) + { + for (auto [name, ty] : *subNorm.strings) + if (!superNorm.strings->count(name)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + else if (!subNorm.strings && superNorm.strings) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + if (get(subNorm.threads)) + if (!get(superNorm.errors)) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + + for (TypeId subClass : subNorm.classes) + { + bool found = false; + const ClassTypeVar* subCtv = get(subClass); + for (TypeId superClass : superNorm.classes) + { + const ClassTypeVar* superCtv = get(superClass); + if (isSubclass(subCtv, superCtv)) + { + found = true; + break; + } + } + if (!found) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + + for (TypeId subTable : subNorm.tables) + { + bool found = false; + for (TypeId superTable : superNorm.tables) + { + Unifier innerState = makeChildUnifier(); + if (get(superTable)) + innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); + else if (get(subTable)) + innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); + else + innerState.tryUnifyTables(subTable, superTable); + if (innerState.errors.empty()) + { + found = true; + log.concat(std::move(innerState.log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); + } + if (!found) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + + if (subNorm.functions) + { + if (!superNorm.functions) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + if (superNorm.functions->empty()) + return; + for (TypeId superFun : *superNorm.functions) + { + Unifier innerState = makeChildUnifier(); + const FunctionTypeVar* superFtv = get(superFun); + if (!superFtv) + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); + innerState.tryUnify_(tgt, superFtv->retTypes); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); + else + return reportError(TypeError{location, TypeMismatch{superTy, subTy, reason, error}}); + } + } + + for (auto& [tyvar, subIntersect] : subNorm.tyvars) + { + auto found = superNorm.tyvars.find(tyvar); + if (found == superNorm.tyvars.end()) + tryUnifyNormalizedTypes(subTy, superTy, *subIntersect, superNorm, reason, error); + else + tryUnifyNormalizedTypes(subTy, superTy, *subIntersect, *found->second, reason, error); + if (!errors.empty()) + return; + } +} + +TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const NormalizedFunctionType& overloads, TypePackId args) +{ + if (!overloads || overloads->empty()) + { + reportError(TypeError{location, CannotCallNonFunction{function}}); + return singletonTypes->errorRecoveryTypePack(); + } + + std::optional result; + const FunctionTypeVar* firstFun = nullptr; + for (TypeId overload : *overloads) + { + if (const FunctionTypeVar* ftv = get(overload)) + { + // TODO: instantiate generics? + if (ftv->generics.empty() && ftv->genericPacks.empty()) + { + if (!firstFun) + firstFun = ftv; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(args, ftv->argTypes); + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + if (result) + { + // Annoyingly, since we don't support intersection of generic type packs, + // the intersection may fail. We rather arbitrarily use the first matching overload + // in that case. + if (std::optional intersect = normalizer->intersectionOfTypePacks(*result, ftv->retTypes)) + result = intersect; + } + else + result = ftv->retTypes; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + reportError(*e); + return singletonTypes->errorRecoveryTypePack(args); + } + } + } + } + + if (result) + return *result; + else if (firstFun) + { + // TODO: better error reporting? + // The logic for error reporting overload resolution + // is currently over in TypeInfer.cpp, should we move it? + reportError(TypeError{location, GenericError{"No matching overload."}}); + return singletonTypes->errorRecoveryTypePack(firstFun->retTypes); + } + else + { + reportError(TypeError{location, CannotCallNonFunction{function}}); + return singletonTypes->errorRecoveryTypePack(); + } +} + bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) { bool* superTyInfo = sharedState.skipCacheForType.find(superTy); @@ -1253,14 +1473,38 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal ice("passed non-function types to unifyFunction"); size_t numGenerics = superFunction->generics.size(); - if (numGenerics != subFunction->generics.size()) + size_t numGenericPacks = superFunction->genericPacks.size(); + + bool shouldInstantiate = (numGenerics == 0 && subFunction->generics.size() > 0) || (numGenericPacks == 0 && subFunction->genericPacks.size() > 0); + + if (FFlag::LuauInstantiateInSubtyping && variance == Covariant && shouldInstantiate) + { + Instantiation instantiation{&log, types, scope->level, scope}; + + std::optional instantiated = instantiation.substitute(subTy); + if (instantiated.has_value()) + { + subFunction = log.getMutable(*instantiated); + + if (!subFunction) + ice("instantiation made a function type into a non-function type in unifyFunction"); + + numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); + numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); + + } + else + { + reportError(TypeError{location, UnificationTooComplex{}}); + } + } + else if (numGenerics != subFunction->generics.size()) { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); } - size_t numGenericPacks = superFunction->genericPacks.size(); if (numGenericPacks != subFunction->genericPacks.size()) { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); @@ -1376,6 +1620,27 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) std::vector missingProperties; std::vector extraProperties; + if (FFlag::LuauInstantiateInSubtyping) + { + if (variance == Covariant && subTable->state == TableState::Generic && superTable->state != TableState::Generic) + { + Instantiation instantiation{&log, types, subTable->level, scope}; + + std::optional instantiated = instantiation.substitute(subTy); + if (instantiated.has_value()) + { + subTable = log.getMutable(*instantiated); + + if (!subTable) + ice("instantiation made a table type into a non-table type in tryUnifyTables"); + } + else + { + reportError(TypeError{location, UnificationTooComplex{}}); + } + } + } + // Optimization: First test that the property sets are compatible without doing any recursive unification if (!subTable->indexer && subTable->state != TableState::Free) { @@ -2344,8 +2609,9 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { - Unifier u = Unifier{types, singletonTypes, mode, scope, location, variance, sharedState, &log}; + Unifier u = Unifier{normalizer, mode, scope, location, variance, &log}; u.anyIsTop = anyIsTop; + u.normalize = normalize; return u; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index cf3eaaa..c20c084 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -25,6 +25,8 @@ LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauTypeAnnotationLocationChange, false) +LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) + bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; bool lua_telemetry_parsed_double_prefix_hex_integer = false; @@ -1062,6 +1064,12 @@ void Parser::parseExprList(TempVector& result) { nextLexeme(); + if (FFlag::LuauCommaParenWarnings && lexer.current().type == ')') + { + report(lexer.current().location, "Expected expression after ',' but got ')' instead"); + break; + } + result.push_back(parseExpr()); } } @@ -1148,7 +1156,14 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector, AstArray> Parser::parseG } if (lexer.current().type == ',') + { nextLexeme(); + + if (FFlag::LuauCommaParenWarnings && lexer.current().type == '>') + { + report(lexer.current().location, "Expected type after ',' but got '>' instead"); + break; + } + } else break; } diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp index 30a171f..d3ad4e9 100644 --- a/CLI/Profiler.cpp +++ b/CLI/Profiler.cpp @@ -82,11 +82,13 @@ static void profilerLoop() if (now - last >= 1.0 / double(gProfiler.frequency)) { - gProfiler.ticks += uint64_t((now - last) * 1e6); + int64_t ticks = int64_t((now - last) * 1e6); + + gProfiler.ticks += ticks; gProfiler.samples++; gProfiler.callbacks->interrupt = profilerTrigger; - last = now; + last += ticks * 1e-6; } else { diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index cb799d3..15db7a1 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -152,6 +152,7 @@ private: void placeModRegMem(OperandX64 rhs, uint8_t regop); void placeRex(RegisterX64 op); void placeRex(OperandX64 op); + void placeRexNoW(OperandX64 op); void placeRex(RegisterX64 lhs, OperandX64 rhs); void placeVex(OperandX64 dst, OperandX64 src1, OperandX64 src2, bool setW, uint8_t mode, uint8_t prefix); void placeImm8Or32(int32_t imm); diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index 01e1312..2ea6463 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -24,13 +24,15 @@ struct CodeAllocator void* context = nullptr; // Called when new block is created to create and setup the unwinding information for all the code in the block - // If data is placed inside the block itself (some platforms require this), we also return 'unwindDataSizeInBlock' - void* (*createBlockUnwindInfo)(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) = nullptr; + // 'startOffset' reserves space for data at the beginning of the page + void* (*createBlockUnwindInfo)(void* context, uint8_t* block, size_t blockSize, size_t& startOffset) = nullptr; // Called to destroy unwinding information returned by 'createBlockUnwindInfo' void (*destroyBlockUnwindInfo)(void* context, void* unwindData) = nullptr; - static const size_t kMaxUnwindDataSize = 128; + // Unwind information can be placed inside the block with some implementation-specific reservations at the beginning + // But to simplify block space checks, we limit the max size of all that data + static const size_t kMaxReservedDataSize = 256; bool allocateNewBlock(size_t& unwindInfoSize); diff --git a/CodeGen/include/Luau/CodeBlockUnwind.h b/CodeGen/include/Luau/CodeBlockUnwind.h index ddae33a..0f7af3a 100644 --- a/CodeGen/include/Luau/CodeBlockUnwind.h +++ b/CodeGen/include/Luau/CodeBlockUnwind.h @@ -10,7 +10,7 @@ namespace CodeGen { // context must be an UnwindBuilder -void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock); +void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& startOffset); void destroyBlockUnwindInfo(void* context, void* unwindData); } // namespace CodeGen diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index c6f611b..b723731 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -14,7 +14,10 @@ namespace CodeGen class UnwindBuilder { public: - virtual ~UnwindBuilder() {} + virtual ~UnwindBuilder() = default; + + virtual void setBeginOffset(size_t beginOffset) = 0; + virtual size_t getBeginOffset() const = 0; virtual void start() = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 25dbc55..dab6e95 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -12,6 +12,9 @@ namespace CodeGen class UnwindBuilderDwarf2 : public UnwindBuilder { public: + void setBeginOffset(size_t beginOffset) override; + size_t getBeginOffset() const override; + void start() override; void spill(int espOffset, RegisterX64 reg) override; @@ -26,6 +29,8 @@ public: void finalize(char* target, void* funcAddress, size_t funcSize) const override; private: + size_t beginOffset = 0; + static const unsigned kRawDataLimit = 128; uint8_t rawData[kRawDataLimit]; uint8_t* pos = rawData; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 801eb6e..0051377 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -22,6 +22,9 @@ struct UnwindCodeWin class UnwindBuilderWin : public UnwindBuilder { public: + void setBeginOffset(size_t beginOffset) override; + size_t getBeginOffset() const override; + void start() override; void spill(int espOffset, RegisterX64 reg) override; @@ -36,6 +39,8 @@ public: void finalize(char* target, void* funcAddress, size_t funcSize) const override; private: + size_t beginOffset = 0; + // Windows unwind codes are written in reverse, so we have to collect them all first std::vector unwindCodes; diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 32325b0..cd3079a 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -354,10 +354,15 @@ void AssemblyBuilderX64::jmp(Label& label) void AssemblyBuilderX64::jmp(OperandX64 op) { + LUAU_ASSERT((op.cat == CategoryX64::reg ? op.base.size : op.memSize) == SizeX64::qword); + if (logText) log("jmp", op); - placeRex(op); + // Indirect absolute calls always work in 64 bit width mode, so REX.W is optional + // While we could keep an optional prefix, in Windows x64 ABI it signals a tail call return statement to the unwinder + placeRexNoW(op); + place(0xff); placeModRegMem(op, 4); commit(); @@ -376,10 +381,14 @@ void AssemblyBuilderX64::call(Label& label) void AssemblyBuilderX64::call(OperandX64 op) { + LUAU_ASSERT((op.cat == CategoryX64::reg ? op.base.size : op.memSize) == SizeX64::qword); + if (logText) log("call", op); - placeRex(op); + // Indirect absolute calls always work in 64 bit width mode, so REX.W is optional + placeRexNoW(op); + place(0xff); placeModRegMem(op, 2); commit(); @@ -838,6 +847,21 @@ void AssemblyBuilderX64::placeRex(OperandX64 op) place(code | 0x40); } +void AssemblyBuilderX64::placeRexNoW(OperandX64 op) +{ + uint8_t code = 0; + + if (op.cat == CategoryX64::reg) + code = REX_B(op.base); + else if (op.cat == CategoryX64::mem) + code = REX_X(op.index) | REX_B(op.base); + else + LUAU_ASSERT(!"No encoding for left operand of this category"); + + if (code != 0) + place(code | 0x40); +} + void AssemblyBuilderX64::placeRex(RegisterX64 lhs, OperandX64 rhs) { uint8_t code = REX_W(lhs.size == SizeX64::qword); diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index aacf40a..b3787d1 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -91,7 +91,7 @@ CodeAllocator::CodeAllocator(size_t blockSize, size_t maxTotalSize) : blockSize(blockSize) , maxTotalSize(maxTotalSize) { - LUAU_ASSERT(blockSize > kMaxUnwindDataSize); + LUAU_ASSERT(blockSize > kMaxReservedDataSize); LUAU_ASSERT(maxTotalSize >= blockSize); } @@ -116,15 +116,15 @@ bool CodeAllocator::allocate( size_t totalSize = alignedDataSize + codeSize; // Function has to fit into a single block with unwinding information - if (totalSize > blockSize - kMaxUnwindDataSize) + if (totalSize > blockSize - kMaxReservedDataSize) return false; - size_t unwindInfoSize = 0; + size_t startOffset = 0; // We might need a new block if (totalSize > size_t(blockEnd - blockPos)) { - if (!allocateNewBlock(unwindInfoSize)) + if (!allocateNewBlock(startOffset)) return false; LUAU_ASSERT(totalSize <= size_t(blockEnd - blockPos)); @@ -132,20 +132,20 @@ bool CodeAllocator::allocate( LUAU_ASSERT((uintptr_t(blockPos) & (kPageSize - 1)) == 0); // Allocation starts on page boundary - size_t dataOffset = unwindInfoSize + alignedDataSize - dataSize; - size_t codeOffset = unwindInfoSize + alignedDataSize; + size_t dataOffset = startOffset + alignedDataSize - dataSize; + size_t codeOffset = startOffset + alignedDataSize; if (dataSize) memcpy(blockPos + dataOffset, data, dataSize); if (codeSize) memcpy(blockPos + codeOffset, code, codeSize); - size_t pageAlignedSize = alignToPageSize(unwindInfoSize + totalSize); + size_t pageAlignedSize = alignToPageSize(startOffset + totalSize); makePagesExecutable(blockPos, pageAlignedSize); flushInstructionCache(blockPos + codeOffset, codeSize); - result = blockPos + unwindInfoSize; + result = blockPos + startOffset; resultSize = totalSize; resultCodeStart = blockPos + codeOffset; @@ -190,7 +190,7 @@ bool CodeAllocator::allocateNewBlock(size_t& unwindInfoSize) // 'Round up' to preserve 16 byte alignment of the following data and code unwindInfoSize = (unwindInfoSize + 15) & ~15; - LUAU_ASSERT(unwindInfoSize <= kMaxUnwindDataSize); + LUAU_ASSERT(unwindInfoSize <= kMaxReservedDataSize); if (!unwindInfo) return false; diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index 6191cee..c045ba6 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -51,7 +51,7 @@ namespace Luau namespace CodeGen { -void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) +void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) { #if defined(_WIN32) && defined(_M_X64) UnwindBuilder* unwind = (UnwindBuilder*)context; @@ -75,7 +75,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz return nullptr; } - unwindDataSizeInBlock = unwindSize; + beginOffset = unwindSize + unwind->getBeginOffset(); return block; #elif !defined(_WIN32) UnwindBuilder* unwind = (UnwindBuilder*)context; @@ -94,7 +94,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz __register_frame(unwindData); #endif - unwindDataSizeInBlock = unwindSize; + beginOffset = unwindSize + unwind->getBeginOffset(); return block; #endif diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index f3886d9..8d06864 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -129,6 +129,16 @@ namespace Luau namespace CodeGen { +void UnwindBuilderDwarf2::setBeginOffset(size_t beginOffset) +{ + this->beginOffset = beginOffset; +} + +size_t UnwindBuilderDwarf2::getBeginOffset() const +{ + return beginOffset; +} + void UnwindBuilderDwarf2::start() { uint8_t* cieLength = pos; diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 5405fcf..1b3279e 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -32,6 +32,16 @@ struct UnwindInfoWin uint8_t frameregoff : 4; }; +void UnwindBuilderWin::setBeginOffset(size_t beginOffset) +{ + this->beginOffset = beginOffset; +} + +size_t UnwindBuilderWin::getBeginOffset() const +{ + return beginOffset; +} + void UnwindBuilderWin::start() { stackOffset = 8; // Return address was pushed by calling the function diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index ce47cd9..7cff70a 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,6 +13,7 @@ inline bool isFlagExperimental(const char* flag) static const char* kList[] = { "LuauLowerBoundsCalculation", "LuauInterpolatedStringBaseSupport", + "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code // makes sure we always have at least one entry nullptr, }; diff --git a/Makefile b/Makefile index 5b8eb93..8400717 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ MAKEFLAGS+=-r -j8 COMMA=, config=debug +protobuf=system BUILD=build/$(config) @@ -95,12 +96,22 @@ ifeq ($(config),fuzz) CXX=clang++ # our fuzzing infra relies on llvm fuzzer CXXFLAGS+=-fsanitize=address,fuzzer -Ibuild/libprotobuf-mutator -O2 LDFLAGS+=-fsanitize=address,fuzzer + LPROTOBUF=-lprotobuf + DPROTOBUF=-D CMAKE_BUILD_TYPE=Release -D LIB_PROTO_MUTATOR_TESTING=OFF + EPROTOC=protoc endif ifeq ($(config),profile) CXXFLAGS+=-O2 -DNDEBUG -gdwarf-4 -DCALLGRIND=1 endif +ifeq ($(protobuf),download) + CXXFLAGS+=-Ibuild/libprotobuf-mutator/external.protobuf/include + LPROTOBUF=build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a + DPROTOBUF+=-D LIB_PROTO_MUTATOR_DOWNLOAD_PROTOBUF=ON + EPROTOC=../build/libprotobuf-mutator/external.protobuf/bin/protoc +endif + # target-specific flags $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include @@ -115,7 +126,7 @@ $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/ $(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread -fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a -lprotobuf +fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a $(LPROTOBUF) # pseudo targets .PHONY: all test clean coverage format luau-size aliases @@ -199,7 +210,7 @@ $(BUILD)/%.c.o: %.c # protobuf fuzzer setup fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator - cd fuzz && protoc luau.proto --cpp_out=. + cd fuzz && $(EPROTOC) luau.proto --cpp_out=. mv fuzz/luau.pb.cc fuzz/luau.pb.cpp $(BUILD)/fuzz/proto.cpp.o: fuzz/luau.pb.cpp @@ -207,7 +218,7 @@ $(BUILD)/fuzz/protoprint.cpp.o: fuzz/luau.pb.cpp build/libprotobuf-mutator: git clone https://github.com/google/libprotobuf-mutator build/libprotobuf-mutator - CXX= cmake -S build/libprotobuf-mutator -B build/libprotobuf-mutator -D CMAKE_BUILD_TYPE=Release -D LIB_PROTO_MUTATOR_TESTING=OFF + CXX= cmake -S build/libprotobuf-mutator -B build/libprotobuf-mutator $(DPROTOBUF) make -C build/libprotobuf-mutator -j8 # picks up include dependencies for all object files diff --git a/VM/include/lua.h b/VM/include/lua.h index 33e6851..ea658a4 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -401,13 +401,15 @@ struct lua_Debug const char* name; // (n) const char* what; // (s) `Lua', `C', `main', `tail' const char* source; // (s) + const char* short_src; // (s) int linedefined; // (s) int currentline; // (l) unsigned char nupvals; // (u) number of upvalues unsigned char nparams; // (a) number of parameters char isvararg; // (a) - char short_src[LUA_IDSIZE]; // (s) void* userdata; // only valid in luau_callhook + + char ssbuf[LUA_IDSIZE]; }; // }====================================================================== diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index e695cd2..82af5d3 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauFasterGetInfo, false) + static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -89,9 +91,9 @@ const char* lua_setlocal(lua_State* L, int level, int n) return name; } -static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) +static Closure* auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) { - int status = 1; + Closure* cl = NULL; for (; *what; what++) { switch (*what) @@ -103,14 +105,23 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, ar->source = "=[C]"; ar->what = "C"; ar->linedefined = -1; + if (FFlag::LuauFasterGetInfo) + ar->short_src = "[C]"; } else { - ar->source = getstr(f->l.p->source); + TString* source = f->l.p->source; + ar->source = getstr(source); ar->what = "Lua"; ar->linedefined = f->l.p->linedefined; + if (FFlag::LuauFasterGetInfo) + ar->short_src = luaO_chunkid(ar->ssbuf, sizeof(ar->ssbuf), getstr(source), source->len); + } + if (!FFlag::LuauFasterGetInfo) + { + luaO_chunkid(ar->ssbuf, LUA_IDSIZE, ar->source, 0); + ar->short_src = ar->ssbuf; } - luaO_chunkid(ar->short_src, ar->source, LUA_IDSIZE); break; } case 'l': @@ -150,10 +161,15 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, ar->name = ci ? getfuncname(ci_func(ci)) : getfuncname(f); break; } + case 'f': + { + cl = f; + break; + } default:; } } - return status; + return cl; } int lua_stackdepth(lua_State* L) @@ -163,7 +179,6 @@ int lua_stackdepth(lua_State* L) int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) { - int status = 0; Closure* f = NULL; CallInfo* ci = NULL; if (level < 0) @@ -180,15 +195,28 @@ int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) } if (f) { - status = auxgetinfo(L, what, ar, f, ci); - if (strchr(what, 'f')) + if (FFlag::LuauFasterGetInfo) { - luaC_threadbarrier(L); - setclvalue(L, L->top, f); - incr_top(L); + // auxgetinfo fills ar and optionally requests to put closure on stack + if (Closure* fcl = auxgetinfo(L, what, ar, f, ci)) + { + luaC_threadbarrier(L); + setclvalue(L, L->top, fcl); + incr_top(L); + } + } + else + { + auxgetinfo(L, what, ar, f, ci); + if (strchr(what, 'f')) + { + luaC_threadbarrier(L); + setclvalue(L, L->top, f); + incr_top(L); + } } } - return status; + return f ? 1 : 0; } static const char* getfuncname(Closure* cl) @@ -284,10 +312,11 @@ static void pusherror(lua_State* L, const char* msg) CallInfo* ci = L->ci; if (isLua(ci)) { - char buff[LUA_IDSIZE]; // add file:line information - luaO_chunkid(buff, getstr(getluaproto(ci)->source), LUA_IDSIZE); + TString* source = getluaproto(ci)->source; + char chunkbuf[LUA_IDSIZE]; // add file:line information + const char* chunkid = luaO_chunkid(chunkbuf, sizeof(chunkbuf), getstr(source), source->len); int line = currentline(L, ci); - luaO_pushfstring(L, "%s:%d: %s", buff, line, msg); + luaO_pushfstring(L, "%s:%d: %s", chunkid, line, msg); } else { diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 4b9fbb6..f50b33d 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBetterThreadMark, false) - /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * @@ -473,54 +471,25 @@ static size_t propagatemark(global_State* g) bool active = th->isactive || th == th->global->mainthread; - if (FFlag::LuauBetterThreadMark) + traversestack(g, th); + + // active threads will need to be rescanned later to mark new stack writes so we mark them gray again + if (active) { - traversestack(g, th); + th->gclist = g->grayagain; + g->grayagain = o; - // active threads will need to be rescanned later to mark new stack writes so we mark them gray again - if (active) - { - th->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); - } - - // the stack needs to be cleared after the last modification of the thread state before sweep begins - // if the thread is inactive, we might not see the thread in this cycle so we must clear it now - if (!active || g->gcstate == GCSatomic) - clearstack(th); - - // we could shrink stack at any time but we opt to do it during initial mark to do that just once per cycle - if (g->gcstate == GCSpropagate) - shrinkstack(th); + black2gray(o); } - else - { - // TODO: Refactor this logic! - if (!active && g->gcstate == GCSpropagate) - { - traversestack(g, th); - clearstack(th); - } - else - { - th->gclist = g->grayagain; - g->grayagain = o; - black2gray(o); + // the stack needs to be cleared after the last modification of the thread state before sweep begins + // if the thread is inactive, we might not see the thread in this cycle so we must clear it now + if (!active || g->gcstate == GCSatomic) + clearstack(th); - traversestack(g, th); - - // final traversal? - if (g->gcstate == GCSatomic) - clearstack(th); - } - - // we could shrink stack at any time but we opt to skip it during atomic since it's redundant to do that more than once per cycle - if (g->gcstate != GCSatomic) - shrinkstack(th); - } + // we could shrink stack at any time but we opt to do it during initial mark to do that just once per cycle + if (g->gcstate == GCSpropagate) + shrinkstack(th); return sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; } diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index f5f1cd0..8b3e478 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -15,6 +15,8 @@ +LUAU_FASTFLAG(LuauFasterGetInfo) + const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL}; int luaO_log2(unsigned int x) @@ -117,44 +119,68 @@ const char* luaO_pushfstring(lua_State* L, const char* fmt, ...) return msg; } -void luaO_chunkid(char* out, const char* source, size_t bufflen) +const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t srclen) { if (*source == '=') { - source++; // skip the `=' - size_t srclen = strlen(source); - size_t dstlen = srclen < bufflen ? srclen : bufflen - 1; - memcpy(out, source, dstlen); - out[dstlen] = '\0'; + if (FFlag::LuauFasterGetInfo) + { + if (srclen <= buflen) + return source + 1; + // truncate the part after = + memcpy(buf, source + 1, buflen - 1); + buf[buflen - 1] = '\0'; + } + else + { + source++; // skip the `=' + size_t len = strlen(source); + size_t dstlen = len < buflen ? len : buflen - 1; + memcpy(buf, source, dstlen); + buf[dstlen] = '\0'; + } } else if (*source == '@') { - size_t l; - source++; // skip the `@' - bufflen -= sizeof("..."); - l = strlen(source); - strcpy(out, ""); - if (l > bufflen) + if (FFlag::LuauFasterGetInfo) { - source += (l - bufflen); // get last part of file name - strcat(out, "..."); - } - strcat(out, source); - } - else - { // out = [string "string"] - size_t len = strcspn(source, "\n\r"); // stop at first newline - bufflen -= sizeof("[string \"...\"]"); - if (len > bufflen) - len = bufflen; - strcpy(out, "[string \""); - if (source[len] != '\0') - { // must truncate? - strncat(out, source, len); - strcat(out, "..."); + if (srclen <= buflen) + return source + 1; + // truncate the part after @ + memcpy(buf, "...", 3); + memcpy(buf + 3, source + srclen - (buflen - 4), buflen - 4); + buf[buflen - 1] = '\0'; } else - strcat(out, source); - strcat(out, "\"]"); + { + size_t l; + source++; // skip the `@' + buflen -= sizeof("..."); + l = strlen(source); + strcpy(buf, ""); + if (l > buflen) + { + source += (l - buflen); // get last part of file name + strcat(buf, "..."); + } + strcat(buf, source); + } } + else + { // buf = [string "string"] + size_t len = strcspn(source, "\n\r"); // stop at first newline + buflen -= sizeof("[string \"...\"]"); + if (len > buflen) + len = buflen; + strcpy(buf, "[string \""); + if (source[len] != '\0') + { // must truncate? + strncat(buf, source, len); + strcat(buf, "..."); + } + else + strcat(buf, source); + strcat(buf, "\"]"); + } + return buf; } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 5f5e7b1..41bf338 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -460,4 +460,4 @@ LUAI_FUNC int luaO_rawequalKey(const TKey* t1, const TValue* t2); LUAI_FUNC int luaO_str2d(const char* s, double* result); LUAI_FUNC const char* luaO_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUAI_FUNC const char* luaO_pushfstring(lua_State* L, const char* fmt, ...); -LUAI_FUNC void luaO_chunkid(char* out, const char* source, size_t len); +LUAI_FUNC const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t srclen); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 490358c..7ee3ee9 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -781,6 +781,7 @@ reentry: default: LUAU_ASSERT(!"Unknown upvalue capture type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks } } @@ -1184,7 +1185,9 @@ reentry: // slow path after switch() break; - default:; + default: + LUAU_ASSERT(!"Unknown value type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks } // slow-path: tables with metatables and userdata values @@ -1296,7 +1299,9 @@ reentry: // slow path after switch() break; - default:; + default: + LUAU_ASSERT(!"Unknown value type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks } // slow-path: tables with metatables and userdata values diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index bd40bad..3edec68 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -148,16 +148,16 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // 0 means the rest of the bytecode is the error message if (version == 0) { - char chunkid[LUA_IDSIZE]; - luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + char chunkbuf[LUA_IDSIZE]; + const char* chunkid = luaO_chunkid(chunkbuf, sizeof(chunkbuf), chunkname, strlen(chunkname)); lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); return 1; } if (version < LBC_VERSION_MIN || version > LBC_VERSION_MAX) { - char chunkid[LUA_IDSIZE]; - luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + char chunkbuf[LUA_IDSIZE]; + const char* chunkid = luaO_chunkid(chunkbuf, sizeof(chunkbuf), chunkname, strlen(chunkname)); lua_pushfstring(L, "%s: bytecode version mismatch (expected [%d..%d], got %d)", chunkid, LBC_VERSION_MIN, LBC_VERSION_MAX, version); return 1; } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 5c55515..05d3975 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -220,8 +220,13 @@ int luaV_strcmp(const TString* ls, const TString* rs) return 0; const char* l = getstr(ls); - size_t ll = ls->len; const char* r = getstr(rs); + + // always safe to read one character because even empty strings are nul terminated + if (*l != *r) + return uint8_t(*l) - uint8_t(*r); + + size_t ll = ls->len; size_t lr = rs->len; size_t lmin = ll < lr ? ll : lr; diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 08f241e..a051088 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -240,12 +240,12 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfLea") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfAbsoluteJumps") { - SINGLE_COMPARE(jmp(rax), 0x48, 0xff, 0xe0); - SINGLE_COMPARE(jmp(r14), 0x49, 0xff, 0xe6); - SINGLE_COMPARE(jmp(qword[r14 + rdx * 4]), 0x49, 0xff, 0x24, 0x96); - SINGLE_COMPARE(call(rax), 0x48, 0xff, 0xd0); - SINGLE_COMPARE(call(r14), 0x49, 0xff, 0xd6); - SINGLE_COMPARE(call(qword[r14 + rdx * 4]), 0x49, 0xff, 0x14, 0x96); + SINGLE_COMPARE(jmp(rax), 0xff, 0xe0); + SINGLE_COMPARE(jmp(r14), 0x41, 0xff, 0xe6); + SINGLE_COMPARE(jmp(qword[r14 + rdx * 4]), 0x41, 0xff, 0x24, 0x96); + SINGLE_COMPARE(call(rax), 0xff, 0xd0); + SINGLE_COMPARE(call(r14), 0x41, 0xff, 0xd6); + SINGLE_COMPARE(call(qword[r14 + rdx * 4]), 0x41, 0xff, 0x14, 0x96); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "FormsOfImul") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index a64d372..9a5c341 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2955,8 +2955,6 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - loadDefinition(R"( declare class Foo function one(self): number @@ -2995,8 +2993,6 @@ t.@1 TEST_CASE_FIXTURE(ACFixture, "do_compatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local t = {} function t:m() end @@ -3011,8 +3007,6 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local t = {} function t.m() end @@ -3027,8 +3021,6 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end local t = {} @@ -3059,8 +3051,6 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "no_wrong_compatible_self_calls_with_generics") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local t = {} function t.m(a: T) end @@ -3076,8 +3066,6 @@ t:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local s = "hello" s:@1 @@ -3095,8 +3083,6 @@ s:@1 TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( local s = "hello" s.@1 @@ -3112,8 +3098,6 @@ s.@1 TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_non_self_calls_are_fine") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( string.@1 )"); @@ -3143,8 +3127,6 @@ table.@1 TEST_CASE_FIXTURE(ACBuiltinsFixture, "library_self_calls_are_invalid") { - ScopedFastFlag selfCallAutocompleteFix3{"LuauSelfCallAutocompleteFix3", true}; - check(R"( string:@1 )"); diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 758fb44..4e553f0 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -96,12 +96,12 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks") data.resize(8); allocator.context = &info; - allocator.createBlockUnwindInfo = [](void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) -> void* { + allocator.createBlockUnwindInfo = [](void* context, uint8_t* block, size_t blockSize, size_t& beginOffset) -> void* { Info& info = *(Info*)context; CHECK(info.unwind.size() == 8); memcpy(block, info.unwind.data(), info.unwind.size()); - unwindDataSizeInBlock = 8; + beginOffset = 8; info.block = block; @@ -194,10 +194,12 @@ TEST_CASE("Dwarf2UnwindCodesX64") // Windows x64 ABI constexpr RegisterX64 rArg1 = rcx; constexpr RegisterX64 rArg2 = rdx; +constexpr RegisterX64 rArg3 = r8; #else // System V AMD64 ABI constexpr RegisterX64 rArg1 = rdi; constexpr RegisterX64 rArg2 = rsi; +constexpr RegisterX64 rArg3 = rdx; #endif constexpr RegisterX64 rNonVol1 = r12; @@ -313,6 +315,119 @@ TEST_CASE("GeneratedCodeExecutionWithThrow") } } +TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGate") +{ + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->start(); + + // Prologue (some of these registers don't have to be saved, but we want to have a big prologue) + build.push(r10); + unwind->save(r10); + build.push(r11); + unwind->save(r11); + build.push(r12); + unwind->save(r12); + build.push(r13); + unwind->save(r13); + build.push(r14); + unwind->save(r14); + build.push(r15); + unwind->save(r15); + build.push(rbp); + unwind->save(rbp); + + int stackSize = 64; + int localsSize = 16; + + build.sub(rsp, stackSize + localsSize); + unwind->allocStack(stackSize + localsSize); + + build.lea(rbp, qword[rsp + stackSize]); + unwind->setupFrameReg(rbp, stackSize); + + unwind->finish(); + + size_t prologueSize = build.setLabel().location; + + // Body + build.mov(rax, rArg1); + build.mov(rArg1, 25); + build.jmp(rax); + + Label returnOffset = build.setLabel(); + + // Epilogue + build.lea(rsp, qword[rbp + localsSize]); + build.pop(rbp); + build.pop(r15); + build.pop(r14); + build.pop(r13); + build.pop(r12); + build.pop(r11); + build.pop(r10); + build.ret(); + + build.finalize(); + + size_t blockSize = 4096; // Force allocate to create a new block each time + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData1; + size_t sizeNativeData1; + uint8_t* nativeEntry1; + REQUIRE( + allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData1, sizeNativeData1, nativeEntry1)); + REQUIRE(nativeEntry1); + + // Now we set the offset at the begining so that functions in new blocks will not overlay the locations + // specified by the unwind information of the entry function + unwind->setBeginOffset(prologueSize); + + using FunctionType = int64_t(void*, void (*)(int64_t), void*); + FunctionType* f = (FunctionType*)nativeEntry1; + + uint8_t* nativeExit = nativeEntry1 + returnOffset.location; + + AssemblyBuilderX64 build2(/* logText= */ false); + + build2.mov(r12, rArg3); + build2.call(rArg2); + build2.jmp(r12); + + build2.finalize(); + + uint8_t* nativeData2; + size_t sizeNativeData2; + uint8_t* nativeEntry2; + REQUIRE(allocator.allocate( + build2.data.data(), build2.data.size(), build2.code.data(), build2.code.size(), nativeData2, sizeNativeData2, nativeEntry2)); + REQUIRE(nativeEntry2); + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(nativeEntry2, throwing, nativeExit); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } + + REQUIRE(nativeEntry2); +} + #endif TEST_SUITE_END(); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index 9976bd2..3420fd8 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -28,8 +28,11 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") cgb.visit(block); NotNull rootScope{cgb.rootScope}; + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; NullModuleResolver resolver; - ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; + ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); @@ -49,9 +52,11 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") cgb.visit(block); NotNull rootScope{cgb.rootScope}; + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; NullModuleResolver resolver; - ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; - + ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); TypeId idType = requireBinding(rootScope, "id"); @@ -79,7 +84,10 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") ToStringOptions opts; NullModuleResolver resolver; - ConstraintSolver cs{&arena, singletonTypes, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + ConstraintSolver cs{NotNull{&normalizer}, rootScope, "MainModule", NotNull(&resolver), {}, &logger}; cs.run(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index dcc0222..4d3c885 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -22,6 +22,8 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauReportShadowedTypeAlias) +extern std::optional randomSeed; // tests/main.cpp + namespace Luau { @@ -90,7 +92,7 @@ std::optional TestFileResolver::getEnvironmentForModule(const Modul Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) - , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true}) + , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) , typeChecker(frontend.typeChecker) , singletonTypes(frontend.singletonTypes) { diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 921e669..fc3ede7 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -21,14 +21,14 @@ end return math.max(fib(5), 1) )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnknownGlobal") { LintResult result = lint("--!nocheck\nreturn foo"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Unknown global 'foo'"); } @@ -39,7 +39,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") LintResult result = lintTyped("Wait(5)"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Wait' is deprecated, use 'wait' instead"); } @@ -53,7 +53,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") LintResult result = lintTyped("Version()"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'Version' is deprecated"); } @@ -64,7 +64,7 @@ local _ = 5 return _ )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); } @@ -75,7 +75,7 @@ _ = 5 print(_) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); } @@ -86,7 +86,7 @@ local _ = 5 _ = 6 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(BuiltinsFixture, "BuiltinGlobalWrite") @@ -100,7 +100,7 @@ end assert(5) )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Built-in global 'math' is overwritten here; consider using a local or changing the name"); CHECK_EQ(result.warnings[1].text, "Built-in global 'assert' is overwritten here; consider using a local or changing the name"); } @@ -111,7 +111,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlock") if true then print(1) print(2) print(3) end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "A new statement is on the same line; add semi-colon on previous statement to silence"); } @@ -121,7 +121,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlockSemicolonsWhitelisted") print(1); print(2); print(3) )"); - CHECK(result.warnings.empty()); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "MultilineBlockMissedSemicolon") @@ -130,7 +130,7 @@ TEST_CASE_FIXTURE(Fixture, "MultilineBlockMissedSemicolon") print(1); print(2) print(3) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "A new statement is on the same line; add semi-colon on previous statement to silence"); } @@ -142,7 +142,7 @@ local _x do end )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "ConfusingIndentation") @@ -152,7 +152,7 @@ print(math.max(1, 2)) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Statement spans multiple lines; use indentation to silence"); } @@ -167,7 +167,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'foo' is only used in the enclosing function 'bar'; consider changing it to local"); } @@ -188,7 +188,7 @@ end return bar() + baz() )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'foo' is never read before being written. Consider changing it to local"); } @@ -213,7 +213,7 @@ end return bar() + baz() + read() )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalWithConditional") @@ -233,7 +233,7 @@ end return bar() + baz() )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "GlobalAsLocal3WithConditionalRead") @@ -257,7 +257,7 @@ end return bar() + baz() + read() )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalInnerRead") @@ -275,7 +275,7 @@ function baz() bar = 0 end return foo() + baz() )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMulti") @@ -304,7 +304,7 @@ fnA() -- prints "true", "nil" fnB() -- prints "false", "nil" )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Global 'moreInternalLogic' is only used in the enclosing function defined at line 2; consider changing it to local"); } @@ -319,7 +319,7 @@ local arg = 5 print(arg) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'arg' shadows previous declaration at line 2"); } @@ -337,7 +337,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'global' shadows a global variable used at line 3"); } @@ -352,7 +352,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'a' shadows previous declaration at line 2"); } @@ -372,7 +372,7 @@ end return bar() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'arg' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[1].text, "Variable 'blarg' is never used; prefix with '_' to silence"); } @@ -387,7 +387,7 @@ local Roact = require(game.Packages.Roact) local _Roact = require(game.Packages.Roact) )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Import 'Roact' is never used; prefix with '_' to silence"); } @@ -412,7 +412,7 @@ end return foo() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Function 'bar' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[1].text, "Function 'qux' is never used; prefix with '_' to silence"); } @@ -427,7 +427,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 5); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always returns)"); } @@ -443,7 +443,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always breaks)"); } @@ -459,7 +459,7 @@ end print("hi!") )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always continues)"); } @@ -495,7 +495,7 @@ end return { foo1, foo2, foo3 } )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always returns)"); } @@ -515,7 +515,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeAssertFalseReturnSilent") @@ -532,7 +532,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeErrorReturnNonSilentBranchy") @@ -550,7 +550,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always errors)"); } @@ -571,7 +571,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 8); CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always errors)"); } @@ -589,7 +589,7 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnreachableCodeLoopRepeat") @@ -605,8 +605,8 @@ end return foo1 )"); - CHECK_EQ(result.warnings.size(), - 0); // this is technically a bug, since the repeat body always returns; fixing this bug is a bit more involved than I'd like + // this is technically a bug, since the repeat body always returns; fixing this bug is a bit more involved than I'd like + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "UnknownType") @@ -633,7 +633,7 @@ local _o02 = type(game) == "vector" local _o03 = typeof(game) == "Part" )"); - REQUIRE_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 2); CHECK_EQ(result.warnings[0].text, "Unknown type 'Part' (expected primitive type)"); CHECK_EQ(result.warnings[1].location.begin.line, 3); @@ -654,7 +654,7 @@ for i=#t,1,-1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[0].text, "For loop should iterate backwards; did you forget to specify -1 as step?"); } @@ -669,7 +669,7 @@ for i=8,1,-1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop should iterate backwards; did you forget to specify -1 as step?"); } @@ -684,7 +684,7 @@ for i=1.3,7.5,1 do end )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop ends at 7.3 instead of 7.5; did you forget to specify step?"); } @@ -702,7 +702,7 @@ for i=#t,0 do end )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 1); CHECK_EQ(result.warnings[0].text, "For loop starts at 0, but arrays start at 1"); CHECK_EQ(result.warnings[1].location.begin.line, 7); @@ -730,7 +730,7 @@ local _a,_b,_c = pcall(), nil end )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 5); CHECK_EQ(result.warnings[0].text, "Assigning 2 values to 3 variables initializes extra variables with nil; add 'nil' to value list to silence"); CHECK_EQ(result.warnings[1].location.begin.line, 11); @@ -795,7 +795,7 @@ end return f1,f2,f3,f4,f5,f6,f7 )"); - CHECK_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 4); CHECK_EQ(result.warnings[0].text, "Function 'f1' can implicitly return no values even though there's an explicit return at line 4; add explicit return to silence"); @@ -851,7 +851,7 @@ end return f1,f2,f3,f4 )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].location.begin.line, 25); CHECK_EQ(result.warnings[0].text, "Function 'f3' can implicitly return no values even though there's an explicit return at line 21; add explicit return to silence"); @@ -874,7 +874,7 @@ type InputData = { } )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "BreakFromInfiniteLoopMakesStatementReachable") @@ -893,7 +893,7 @@ until true return 1 )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "IgnoreLintAll") @@ -903,7 +903,7 @@ TEST_CASE_FIXTURE(Fixture, "IgnoreLintAll") return foo )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "IgnoreLintSpecific") @@ -914,7 +914,7 @@ local x = 1 return foo )"); - CHECK_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' is never used; prefix with '_' to silence"); } @@ -933,7 +933,7 @@ local _ = ("%"):format() string.format("hello %+10d %.02f %%", 4, 5) )"); - CHECK_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid format string: unfinished format specifier"); CHECK_EQ(result.warnings[1].text, "Invalid format string: invalid format specifier: must be a string format specifier or %"); CHECK_EQ(result.warnings[2].text, "Invalid format string: invalid format specifier: must be a string format specifier or %"); @@ -973,7 +973,7 @@ string.packsize("c99999999999999999999") string.packsize("=!1bbbI3c42") )"); - CHECK_EQ(result.warnings.size(), 11); + REQUIRE(11 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); CHECK_EQ(result.warnings[1].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); CHECK_EQ(result.warnings[2].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); @@ -1017,7 +1017,7 @@ local _ = s:match("%q") string.match(s, "[A-Z]+(%d)%1") )"); - CHECK_EQ(result.warnings.size(), 14); + REQUIRE(14 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[2].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); @@ -1049,7 +1049,7 @@ string.match(s, "((a)%1)") string.match(s, "((a)%3)") )~"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid capture reference, must refer to a closed capture"); CHECK_EQ(result.warnings[0].location.begin.line, 7); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid capture reference, must refer to a valid capture"); @@ -1087,7 +1087,7 @@ string.match(s, "[]|'[]") string.match(s, "[^]|'[]") )~"); - CHECK_EQ(result.warnings.size(), 7); + REQUIRE(7 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[2].text, "Invalid match pattern: character range can't include character sets"); @@ -1118,7 +1118,7 @@ string.find("foo"); ("foo"):find() )"); - CHECK_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); CHECK_EQ(result.warnings[0].location.begin.line, 4); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); @@ -1141,7 +1141,7 @@ string.gsub(s, '[A-Z]+(%d)', "%0%1") string.gsub(s, 'foo', "%0") )"); - CHECK_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match replacement: unfinished replacement"); CHECK_EQ(result.warnings[1].text, "Invalid match replacement: unexpected replacement character; must be a digit or %"); CHECK_EQ(result.warnings[2].text, "Invalid match replacement: invalid capture index, must refer to pattern capture"); @@ -1162,7 +1162,7 @@ os.date("it's %c now") os.date("!*t") )"); - CHECK_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid date format: unfinished replacement"); CHECK_EQ(result.warnings[1].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); CHECK_EQ(result.warnings[2].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); @@ -1181,7 +1181,7 @@ s:match("[]") nons:match("[]") )~"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); CHECK_EQ(result.warnings[0].location.begin.line, 3); CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); @@ -1231,7 +1231,7 @@ _ = { } )"); - CHECK_EQ(result.warnings.size(), 6); + REQUIRE(6 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Table field 'first' is a duplicate; previously defined at line 3"); CHECK_EQ(result.warnings[1].text, "Table field 'first' is a duplicate; previously defined at line 9"); CHECK_EQ(result.warnings[2].text, "Table index 1 is a duplicate; previously defined as a list entry"); @@ -1248,7 +1248,7 @@ TEST_CASE_FIXTURE(Fixture, "ImportOnlyUsedInTypeAnnotation") local x: Foo.Y = 1 )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' is never used; prefix with '_' to silence"); } @@ -1259,7 +1259,7 @@ TEST_CASE_FIXTURE(Fixture, "DisableUnknownGlobalWithTypeChecking") unknownGlobal() )"); - REQUIRE_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") @@ -1271,7 +1271,7 @@ TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") return exports )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") @@ -1294,7 +1294,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") LintResult result = frontend.lint("A"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DeadLocalsUsed") @@ -1320,7 +1320,7 @@ do end )"); - CHECK_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Variable 'x' defined at line 4 is never initialized or assigned; initialize with 'nil' to silence"); CHECK_EQ(result.warnings[1].text, "Assigning 2 values to 3 variables initializes extra variables with nil; add 'nil' to value list to silence"); CHECK_EQ(result.warnings[2].text, "Variable 'c' defined at line 12 is never initialized or assigned; initialize with 'nil' to silence"); @@ -1333,7 +1333,7 @@ local foo function foo() end )"); - CHECK_EQ(result.warnings.size(), 0); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DuplicateGlobalFunction") @@ -1408,7 +1408,7 @@ TEST_CASE_FIXTURE(Fixture, "DontTriggerTheWarningIfTheFunctionsAreInDifferentSco return c )"); - CHECK(result.warnings.empty()); + REQUIRE(0 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") @@ -1444,7 +1444,7 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") local h: Hooty.Pt )"); - CHECK_EQ(result.warnings.size(), 12); + REQUIRE(12 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") @@ -1478,7 +1478,7 @@ return function (i: Instance) end )"); - REQUIRE_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Member 'Instance.Wait' is deprecated"); CHECK_EQ(result.warnings[1].text, "Member 'toHSV' is deprecated, use 'Color3:ToHSV' instead"); CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); @@ -1511,7 +1511,7 @@ table.create(42, {}) table.create(42, {} :: {}) )"); - REQUIRE_EQ(result.warnings.size(), 10); + REQUIRE(10 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1556,7 +1556,7 @@ _ = true and true or false -- no warning since this is is a common pattern used _ = if true then 1 elseif true then 2 else 3 )"); - REQUIRE_EQ(result.warnings.size(), 8); + REQUIRE(8 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); CHECK_EQ(result.warnings[0].location.begin.line + 1, 4); CHECK_EQ(result.warnings[1].text, "Condition has already been checked on column 5"); @@ -1580,7 +1580,7 @@ elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", false)}) t end )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 4"); CHECK_EQ(result.warnings[0].location.begin.line + 1, 5); } @@ -1601,7 +1601,7 @@ end return foo, moo, a1, a2 )"); - REQUIRE_EQ(result.warnings.size(), 4); + REQUIRE(4 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Function parameter 'a1' already defined on column 14"); CHECK_EQ(result.warnings[1].text, "Variable 'a1' is never used; prefix with '_' to silence"); CHECK_EQ(result.warnings[2].text, "Variable 'a1' already defined on column 7"); @@ -1618,7 +1618,7 @@ _ = math.random() < 0.5 and 0 or 42 _ = (math.random() < 0.5 and false) or 42 -- currently ignored )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; " "consider using if-then-else expression instead"); CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; " @@ -1640,7 +1640,7 @@ do end --!nolint )"); - REQUIRE_EQ(result.warnings.size(), 6); + REQUIRE(6 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Unknown comment directive 'struct'; did you mean 'strict'?"); CHECK_EQ(result.warnings[1].text, "Unknown comment directive 'nolintGlobal'"); CHECK_EQ(result.warnings[2].text, "nolint directive refers to unknown lint rule 'Global'"); @@ -1656,7 +1656,7 @@ TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") --!struct )"); - REQUIRE_EQ(result.warnings.size(), 0); // --!nolint disables WrongComment lint :) + REQUIRE(0 == result.warnings.size()); // --!nolint disables WrongComment lint :) } TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsIfStatAndExpr") @@ -1668,7 +1668,7 @@ elseif if 0 then 5 else 4 then end )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); } @@ -1681,13 +1681,13 @@ TEST_CASE_FIXTURE(Fixture, "WrongCommentOptimize") --!optimize 2 )"); - REQUIRE_EQ(result.warnings.size(), 3); + REQUIRE(3 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); CHECK_EQ(result.warnings[1].text, "optimize directive uses unknown optimization level 'me', 0..2 expected"); CHECK_EQ(result.warnings[2].text, "optimize directive uses unknown optimization level '100500', 0..2 expected"); result = lint("--!optimize "); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); } @@ -1700,7 +1700,7 @@ TEST_CASE_FIXTURE(Fixture, "TestStringInterpolation") local _ = `unknown {foo}` )"); - REQUIRE_EQ(result.warnings.size(), 1); + REQUIRE(1 == result.warnings.size()); } TEST_CASE_FIXTURE(Fixture, "IntegerParsing") @@ -1710,7 +1710,7 @@ local _ = 0b10000000000000000000000000000000000000000000000000000000000000000 local _ = 0x10000000000000000 )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Binary number literal exceeded available precision and has been truncated to 2^64"); CHECK_EQ(result.warnings[1].text, "Hexadecimal number literal exceeded available precision and has been truncated to 2^64"); } @@ -1725,7 +1725,7 @@ local _ = 0x0x123 local _ = 0x0xffffffffffffffffffffffffffffffffff )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE(2 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "Hexadecimal number literal has a double prefix, which will fail to parse in the future; remove the extra 0x to fix"); CHECK_EQ(result.warnings[1].text, @@ -1756,7 +1756,7 @@ local _ = (a <= b) == 0 local _ = a <= (b == 0) )"); - REQUIRE_EQ(result.warnings.size(), 5); + REQUIRE(5 == result.warnings.size()); CHECK_EQ(result.warnings[0].text, "not X == Y is equivalent to (not X) == Y; consider using X ~= Y, or add parentheses to silence"); CHECK_EQ(result.warnings[1].text, "not X ~= Y is equivalent to (not X) ~= Y; consider using X == Y, or add parentheses to silence"); CHECK_EQ(result.warnings[2].text, "not X <= Y is equivalent to (not X) <= Y; add parentheses to silence"); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 31df707..156cbbc 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -356,6 +356,11 @@ TEST_CASE_FIXTURE(NormalizeFixture, "table_with_any_prop") TEST_CASE_FIXTURE(NormalizeFixture, "intersection") { + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + check(R"( local a: number & string local b: number @@ -374,8 +379,9 @@ TEST_CASE_FIXTURE(NormalizeFixture, "intersection") CHECK(!isSubtype(c, a)); CHECK(isSubtype(a, c)); - CHECK(!isSubtype(d, a)); - CHECK(!isSubtype(a, d)); + // These types are both equivalent to never + CHECK(isSubtype(d, a)); + CHECK(isSubtype(a, d)); } TEST_CASE_FIXTURE(NormalizeFixture, "union_and_intersection") diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index b4064cf..662c290 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2722,4 +2722,59 @@ TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation" result.errors[0].getMessage()); } +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_argument_list") +{ + ScopedFastFlag sff{"LuauCommaParenWarnings", true}; + + ParseResult result = tryParse(R"( + foo(a, b, c,) + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({1, 20}, {1, 21}) == result.errors[0].getLocation()); + CHECK("Expected expression after ',' but got ')' instead" == result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_function_parameter_list") +{ + ScopedFastFlag sff{"LuauCommaParenWarnings", true}; + + ParseResult result = tryParse(R"( + export type VisitFn = ( + any, + Array>, -- extra comma here + ) -> any + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({4, 8}, {4, 9}) == result.errors[0].getLocation()); + CHECK("Expected type after ',' but got ')' instead" == result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_an_extra_comma_at_the_end_of_a_generic_parameter_list") +{ + ScopedFastFlag sff{"LuauCommaParenWarnings", true}; + + ParseResult result = tryParse(R"( + export type VisitFn = (a: A, b: B) -> () + )"); + + REQUIRE(1 == result.errors.size()); + + CHECK(Location({1, 36}, {1, 37}) == result.errors[0].getLocation()); + CHECK("Expected type after ',' but got '>' instead" == result.errors[0].getMessage()); + + REQUIRE(1 == result.root->body.size); + + AstStatTypeAlias* t = result.root->body.data[0]->as(); + REQUIRE(t != nullptr); + + AstTypeFunction* f = t->type->as(); + REQUIRE(f != nullptr); + + CHECK(2 == f->generics.size); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 98883df..87ec58c 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -480,4 +480,38 @@ local a: ChildClass = i CHECK_EQ("Type 'ChildClass' from 'Test' could not be converted into 'ChildClass' from 'MainModule'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(ClassFixture, "intersections_of_unions_of_classes") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (BaseClass | Vector2) & (ChildClass | AnotherChild) + local y : (ChildClass | AnotherChild) + x = y + y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (BaseClass & ChildClass) | (BaseClass & AnotherChild) | (BaseClass & Vector2) + local y : (ChildClass | AnotherChild) + x = y + y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index a4420b9..fa99ff5 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Error.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -14,6 +15,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -1087,10 +1089,20 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' caused by: Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } // Infer from variadic packs into elements result = check(R"( @@ -1189,10 +1201,20 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' caused by: Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } // Infer from variadic packs into elements result = check(R"( diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3c86777..1b02abc 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -10,6 +10,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauCheckGenericHOFTypes) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSpecialTypesAsterisked) using namespace Luau; @@ -960,7 +961,11 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); + if (FFlag::LuauInstantiateInSubtyping) + CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); + else + CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); + } TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") @@ -980,7 +985,10 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); + if (FFlag::LuauInstantiateInSubtyping) + CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); + else + CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); } TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") @@ -1110,6 +1118,15 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i { LUAU_REQUIRE_NO_ERRORS(result); } + else if (FFlag::LuauInstantiateInSubtyping) + { + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ( + R"(Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a' +caused by: + Argument #1 type is not compatible. Generic subtype escaping scope)", + toString(result.errors[0])); + } else { LUAU_REQUIRE_ERRORS(result); @@ -1219,4 +1236,48 @@ TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "hof_subtype_instantiation_regression") +{ + CheckResult result = check(R"( +--!strict + +local function defaultSort(a: T, b: T) + return true +end +type A = any +return function(array: {T}): {T} + table.sort(array, defaultSort) + return array +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "higher_rank_polymorphism_should_not_accept_instantiated_arguments") +{ + ScopedFastFlag sffs[] = { + {"LuauInstantiateInSubtyping", true}, + {"LuauCheckGenericHOFTypes", true}, // necessary because of interactions with the test + }; + + CheckResult result = check(R"( +--!strict + +local function instantiate(f: (a) -> a): (number) -> number + return f +end + +instantiate(function(x: string) return "foo" end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + + CHECK_EQ("(a) -> a", toString(tm1->wantedType)); + CHECK_EQ("(string) -> string", toString(tm1->givenType)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 818d012..e49df10 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -446,4 +446,459 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "intersect_bool_and_false") +{ + CheckResult result = check(R"( + local x : (boolean & false) + local y : false = x -- OK + local z : true = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") +{ + CheckResult result = check(R"( + local x : false & (boolean & false) + local y : false = x -- OK + local z : true = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + // TODO: odd stringification of `false & (boolean & false)`.) + CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number?) -> number?) & ((string?) -> string?) + local y : (nil) -> nil = x -- OK + local z : (number) -> number = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> number?) & ((string?) -> string?)' could not be converted into '(number) -> number'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number) -> number) & ((string) -> string) + local y : ((number | string) -> (number | string)) = x -- OK + local z : ((number | boolean) -> (number | boolean)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number) & ((string) -> string)' could not be converted into '(boolean | number) -> boolean | number'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : { p : number?, q : string? } & { p : number?, q : number?, r : number? } + local y : { p : number?, q : nil, r : number? } = x -- OK + local z : { p : nil } = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into '{| p: nil |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") +{ + CheckResult result = check(R"( + local x : { p : number?, q : any } & { p : unknown, q : string? } + local y : { p : number?, q : string? } = x -- OK + local z : { p : string?, q : number? } = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, q: number? |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : { p : number?, q : never } & { p : never, q : string? } + local y : { p : never, q : never } = x -- OK + local z : never = x -- OK + )"); + + // TODO: this should not produce type errors, since never <: { p : never } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: never, q: string? |} & {| p: number?, q: never |}' could not be converted into 'never'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number?) -> ({ p : number } & { q : number })) & ((string?) -> ({ p : number } & { r : number })) + local y : (nil) -> { p : number, q : number, r : number} = x -- OK + local z : (number?) -> { p : number, q : number, r : number} = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into '(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number?) -> (a | number)) & ((string?) -> (a | string)) + local y : (nil) -> a = x -- OK + local z : (number?) -> a = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> a | number) & ((string?) -> a | string)' could not be converted into '(number?) -> a'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((a?) -> (a | b)) & ((c?) -> (b | c)) + local y : (nil) -> ((a & c) | b) = x -- OK + local z : (a?) -> ((a & c) | b) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((a?) -> a | b) & ((c?) -> b | c)' could not be converted into '(a?) -> (a & c) | b'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...)) + local y : ((nil, a...) -> (nil, b...)) = x -- OK + local z : ((nil, b...) -> (nil, a...)) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' could not be converted into '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number) & ((nil) -> unknown) + local y : (number?) -> unknown = x -- OK + local z : (number?) -> number? = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> unknown) & ((number) -> number)' could not be converted into '(number?) -> number?'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number?) & ((unknown) -> string?) + local y : (number) -> nil = x -- OK + local z : (number?) -> nil = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number?) & ((unknown) -> string?)' could not be converted into '(number?) -> nil'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number) & ((nil) -> never) + local y : (number?) -> number = x -- OK + local z : (number?) -> never = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> never) & ((number) -> number)' could not be converted into '(number?) -> never'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : ((number) -> number?) & ((never) -> string?) + local y : (never) -> nil = x -- OK + local z : (number?) -> nil = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((never) -> string?) & ((number) -> number?)' could not be converted into '(number?) -> nil'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_variadics") +{ + CheckResult result = check(R"( + local x : ((string?) -> (string | number)) & ((number?) -> ...number) + local y : ((nil) -> (number, number?)) = x -- OK + local z : ((string | number) -> (number, number?)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> (...number)) & ((string?) -> number | string)' could not be converted into '(number | string) -> (number, number?)'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_1") +{ + CheckResult result = check(R"( + function f() + local x : (() -> a...) & (() -> b...) + local y : (() -> b...) & (() -> a...) = x -- OK + local z : () -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(() -> (a...)) & (() -> (b...))' could not be converted into '() -> ()'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_2") +{ + CheckResult result = check(R"( + function f() + local x : ((a...) -> ()) & ((b...) -> ()) + local y : ((b...) -> ()) & ((a...) -> ()) = x -- OK + local z : () -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((b...) -> ())' could not be converted into '() -> ()'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_3") +{ + CheckResult result = check(R"( + function f() + local x : (() -> a...) & (() -> (number?,a...)) + local y : (() -> (number?,a...)) & (() -> a...) = x -- OK + local z : () -> (number) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(() -> (a...)) & (() -> (number?, a...))' could not be converted into '() -> number'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") +{ + CheckResult result = check(R"( + function f() + local x : ((a...) -> ()) & ((number,a...) -> number) + local y : ((number,a...) -> number) & ((a...) -> ()) = x -- OK + local z : (number?) -> () = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((number, a...) -> number)' could not be converted into '(number?) -> ()'; none of the intersection parts are compatible"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local a : string? = nil + local b : number? = nil + + local x = setmetatable({}, { p = 5, q = a }); + local y = setmetatable({}, { q = b, r = "hi" }); + local z = setmetatable({}, { p = 5, q = nil, r = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + local yx : Y&X = z; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_subtypes") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local y = setmetatable({ b = "hi" }, { p = 5, q = "hi" }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + local yx : Y&X = z; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables_with_properties") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local y = setmetatable({ b = "hi" }, { q = "hi" }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5, q = "hi" }); + + type X = typeof(x) + type Y = typeof(y) + type Z = typeof(z) + + local xy : X&Y = z; + z = xy; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatable_with table") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x = setmetatable({ a = 5 }, { p = 5 }); + local z = setmetatable({ a = 5, b = "hi" }, { p = 5 }); + + type X = typeof(x) + type Y = { b : string } + type Z = typeof(z) + + -- TODO: once we have shape types, we should be able to initialize these with z + local xy : X&Y; + local yx : Y&X; + z = xy; + z = yx; + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CLI-44817") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + type X = {x: number} + type Y = {y: number} + type Z = {z: number} + + type XY = {x: number, y: number} + type XYZ = {x:number, y: number, z: number} + + local xy: XY = {x = 0, y = 0} + local xyz: XYZ = {x = 0, y = 0, z = 0} + + local xNy: X&Y = xy + local xNyNz: X&Y&Z = xyz + + local t1: XY = xNy -- Type 'X & Y' could not be converted into 'XY' + local t2: XY = xNyNz -- Type 'X & Y & Z' could not be converted into 'XY' + local t3: XYZ = xNyNz -- Type 'X & Y & Z' could not be converted into 'XYZ' + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index ede84f4..36943ca 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -10,6 +10,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauInstantiateInSubtyping) + using namespace Luau; LUAU_FASTFLAG(LuauSpecialTypesAsterisked) @@ -248,7 +250,24 @@ end return m )"); - LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauInstantiateInSubtyping) + { + // though this didn't error before the flag, it seems as though it should error since fields of a table are invariant. + // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be unsound. + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(R"(Type 'n' could not be converted into 't1 where t1 = {- Clone: (t1) -> (a...) -}' +caused by: + Property 'Clone' is not compatible. Type '(a) -> ()' could not be converted into 't1 where t1 = ({- Clone: t1 -}) -> (a...)'; different number of generic type parameters)", + toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } + } TEST_CASE_FIXTURE(BuiltinsFixture, "custom_require_global") @@ -367,8 +386,6 @@ type Table = typeof(tbl) TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_5") { - ScopedFastFlag luauInplaceDemoteSkipAllBound{"LuauInplaceDemoteSkipAllBound", true}; - fileResolver.source["game/A"] = R"( export type Type = {x: number, y: number} local arrayops = {} diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 45740a0..2aac665 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -271,30 +271,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bail_early_if_unification_is_too_complicated } } -// Should be in TypeInfer.tables.test.cpp -// It's unsound to instantiate tables containing generic methods, -// since mutating properties means table properties should be invariant. -// We currently allow this but we shouldn't! -TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") -{ - CheckResult result = check(R"( - --!strict - local t = {} - function t.m(x) return x end - local a : string = t.m("hi") - local b : number = t.m(5) - function f(x : { m : (number)->number }) - x.m = function(x) return 1+x end - end - f(t) -- This shouldn't typecheck - local c : string = t.m("hi") - )"); - - // TODO: this should error! - // This should be fixed by replacing generic tables by generics with type bounds. - LUAU_REQUIRE_NO_ERRORS(result); -} - // FIXME: Move this test to another source file when removing FFlag::LuauLowerBoundsCalculation TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { @@ -608,7 +584,8 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") InternalErrorReporter iceHandler; UnifierSharedState sharedState{&iceHandler}; - Unifier u{&arena, singletonTypes, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant, sharedState}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant}; u.tryUnify(option1, option2); @@ -635,4 +612,87 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") LUAU_REQUIRE_NO_ERRORS(result); } +// Ideally, we would not try to export a function type with generic types from incorrect scope +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface") +{ + ScopedFastFlag LuauAnyifyModuleReturnGenerics{"LuauAnyifyModuleReturnGenerics", true}; + + fileResolver.source["game/A"] = R"( +local wrapStrictTable + +local metatable = { + __index = function(self, key) + local value = self.__tbl[key] + if type(value) == "table" then + -- unification of the free 'wrapStrictTable' with this function type causes generics of this function to leak out of scope + return wrapStrictTable(value, self.__name .. "." .. key) + end + return value + end, +} + +return wrapStrictTable + )"; + + frontend.check("game/A"); + + fileResolver.source["game/B"] = R"( +local wrapStrictTable = require(game.A) + +local Constants = {} + +return wrapStrictTable(Constants, "Constants") + )"; + + frontend.check("game/B"); + + ModulePtr m = frontend.moduleResolver.modules["game/B"]; + REQUIRE(m); + + std::optional result = first(m->getModuleScope()->returnType); + REQUIRE(result); + CHECK(get(*result)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") +{ + ScopedFastFlag LuauAnyifyModuleReturnGenerics{"LuauAnyifyModuleReturnGenerics", true}; + + fileResolver.source["game/A"] = R"( +local wrapStrictTable + +local metatable = { + __index = function(self, key, ...: T) + local value = self.__tbl[key] + if type(value) == "table" then + -- unification of the free 'wrapStrictTable' with this function type causes generics of this function to leak out of scope + return wrapStrictTable(value, self.__name .. "." .. key) + end + return ... + end, +} + +return wrapStrictTable + )"; + + frontend.check("game/A"); + + fileResolver.source["game/B"] = R"( +local wrapStrictTable = require(game.A) + +local Constants = {} + +return wrapStrictTable(Constants, "Constants") + )"; + + frontend.check("game/B"); + + ModulePtr m = frontend.moduleResolver.modules["game/B"]; + REQUIRE(m); + + std::optional result = first(m->getModuleScope()->returnType); + REQUIRE(result); + CHECK(get(*result)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index d183f65..a6d870f 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -11,7 +11,8 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLowerBoundsCalculation); +LUAU_FASTFLAG(LuauLowerBoundsCalculation) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) TEST_SUITE_BEGIN("TableTests"); @@ -2038,11 +2039,22 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + if (FFlag::LuauInstantiateInSubtyping) + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } + else + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } } TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") @@ -3173,4 +3185,53 @@ caused by: CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") +{ + ScopedFastFlag sff[]{ + {"LuauInstantiateInSubtyping", true}, + }; + + CheckResult result = check(R"( + --!strict + local t = {} + function t.m(x) return x end + local a : string = t.m("hi") + local b : number = t.m(5) + function f(x : { m : (number)->number }) + x.m = function(x) return 1+x end + end + f(t) -- This shouldn't typecheck + local c : string = t.m("hi") + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 't' could not be converted into '{| m: (number) -> number |}' +caused by: + Property 'm' is not compatible. Type '(a) -> a' could not be converted into '(number) -> number'; different number of generic type parameters)"); + // this error message is not great since the underlying issue is that the context is invariant, + // and `(number) -> number` cannot be a subtype of `(a) -> a`. +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_table_instantiation_potential_regression") +{ + CheckResult result = check(R"( +--!strict + +function f(x) + x.p = 5 + return x +end +local g : ({ p : number, q : string }) -> ({ p : number, r : boolean }) = f + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + MissingProperties* error = get(result.errors[0]); + REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("r", error->properties[0]); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 8ed61b4..26171c5 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -17,7 +17,9 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); +LUAU_FASTFLAG(LuauCheckGenericHOFTypes); using namespace Luau; @@ -999,7 +1001,26 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauInstantiateInSubtyping && !FFlag::LuauCheckGenericHOFTypes) + { + // though this didn't error before the flag, it seems as though it should error since fields of a table are invariant. + // the user's intent would likely be that these "method" fields would be read-only, but without an annotation, accepting this should be unsound. + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(R"(Type 't1 where t1 = {+ getStoreFieldName: (t1, {| fieldName: string |} & {| from: number? |}) -> (a, b...) +}' could not be converted into 'Policies' +caused by: + Property 'getStoreFieldName' is not compatible. Type 't1 where t1 = ({+ getStoreFieldName: t1 +}, {| fieldName: string |} & {| from: number? |}) -> (a, b...)' could not be converted into '(Policies, FieldSpecifier) -> string' +caused by: + Argument #2 type is not compatible. Type 'FieldSpecifier' could not be converted into 'FieldSpecifier & {| from: number? |}' +caused by: + Not all intersection parts are compatible. Table type 'FieldSpecifier' not compatible with type '{| from: number? |}' because the former has extra field 'fieldName')", + toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") @@ -1020,6 +1041,43 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_no_ice") CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") +{ + ScopedFastInt sfi("LuauTypeInferRecursionLimit", 10); + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + {"LuauAutocompleteDynamicLimits", true}, + }; + + CheckResult result = check(R"( + function f() + local x : a&b&c&d&e&f&g&h&(i?) + local y : (a&b&c&d&e&f&g&h&i)? = x + end + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") +{ + ScopedFastInt sfi("LuauNormalizeCacheLimit", 10); + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : ((number) -> number) & ((string) -> string) & ((nil) -> nil) & (({}) -> {}) + local y : (number | string | nil | {}) -> (number | string | nil | {}) = x + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Internal error: Code is too complex to typecheck! Consider adding type annotations around this area", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 3911c52..dedb7d2 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -17,8 +17,8 @@ struct TryUnifyFixture : Fixture ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - - Unifier state{&arena, singletonTypes, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant, unifierState}; + Normalizer normalizer{&arena, singletonTypes, NotNull{&unifierState}}; + Unifier state{NotNull{&normalizer}, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant}; }; TEST_SUITE_BEGIN("TryUnifyTests"); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 7d33809..eb61c39 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1000,4 +1000,23 @@ TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'boolean'"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") +{ + ScopedFastFlag luauFixVarargExprHeadType{"LuauFixVarargExprHeadType", true}; + + CheckResult result = check(R"( + local function wrapReject(fn: (self: any, ...TArg) -> ...TResult): (self: any, ...TArg) -> ...TResult + return function(self, ...) + local arguments = { ... } + local ok, result = pcall(function() + return fn(self, table.unpack(arguments)) + end) + return result + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 8eb485e..64c9b56 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -541,5 +541,182 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); } +TEST_CASE_FIXTURE(Fixture, "union_true_and_false") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : boolean + local y1 : (true | false) = x -- OK + local y2 : (true | false | (string & number)) = x -- OK + local y3 : (true | (string & number) | false) = x -- OK + local y4 : (true | (boolean & true) | false) = x -- OK + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> number? + local y : ((number?) -> number?) | ((number) -> number) = x -- OK + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_generic_functions") +{ + CheckResult result = check(R"( + local x : (a) -> a? + local y : ((a?) -> a?) | ((b) -> b) = x -- Not OK + )"); + + // TODO: should this example typecheck? + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_generic_typepack_functions") +{ + CheckResult result = check(R"( + local x : (number, a...) -> (number?, a...) + local y : ((number?, a...) -> (number?, a...)) | ((number, b...) -> (number, b...)) = x -- Not OK + )"); + + // TODO: should this example typecheck? + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : (a) -> a? + local y : ((a?) -> nil) | ((a) -> a) = x -- OK + local z : ((b?) -> nil) | ((b) -> b) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(a) -> a?' could not be converted into '((b) -> b) | ((b?) -> nil)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + function f() + local x : (number, a...) -> (number?, a...) + local y : ((number | string, a...) -> (number, a...)) | ((number?, a...) -> (nil, a...)) = x -- OK + local z : ((number) -> number) | ((number?, a...) -> (number?, a...)) = x -- Not OK + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(number, a...) -> (number?, a...)' could not be converted into '((number) -> number) | ((number?, a...) -> (number?, a...))'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> number? + local y : ((number?) -> number) | ((number | string) -> nil) = x -- OK + local z : ((number, string?) -> number) | ((number) -> nil) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(number) -> number?' could not be converted into '((number) -> nil) | ((number, string?) -> number)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : () -> (number | string) + local y : (() -> number) | (() -> string) = x -- OK + local z : (() -> number) | (() -> (string, string)) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '() -> number | string' could not be converted into '(() -> (string, string)) | (() -> number)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (...nil) -> (...number?) + local y : ((...string?) -> (...number)) | ((...number?) -> nil) = x -- OK + local z : ((...string?) -> (...number)) | ((...string?) -> nil) = x -- OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(...nil) -> (...number?)' could not be converted into '((...string?) -> (...number)) | ((...string?) -> nil)'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : (number) -> () + local y : ((number?) -> ()) | ((...number) -> ()) = x -- OK + local z : ((number?) -> ()) | ((...number?) -> ()) = x -- Not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '(number) -> ()' could not be converted into '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible"); +} + +TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics") +{ + ScopedFastFlag sffs[] { + {"LuauSubtypeNormalizer", true}, + {"LuauTypeNormalization2", true}, + }; + + CheckResult result = check(R"( + local x : () -> (number?, ...number) + local y : (() -> (...number)) | (() -> nil) = x -- OK + local z : (() -> (...number)) | (() -> number) = x -- OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '() -> (number?, ...number)' could not be converted into '(() -> (...number)) | (() -> number)'; none of the union options are compatible"); +} TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 9c19da5..b09b087 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -508,6 +508,9 @@ assert((function() function cmp(a,b) return ab,a>=b end return concat( assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('abc', 'abd')) end)() == "true,true,false,false") assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('ab\\0c', 'ab\\0d')) end)() == "true,true,false,false") assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('ab\\0c', 'ab\\0')) end)() == "false,false,true,true") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('\\0a', '\\0b')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', 'a\\0')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', '\200')) end)() == "true,true,false,false") -- array access assert((function() local a = {4,5,6} return a[3] end)() == 6) diff --git a/tests/main.cpp b/tests/main.cpp index 3e480c9..3f564c0 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -21,6 +21,7 @@ #include #endif +#include #include // Indicates if verbose output is enabled; can be overridden via --verbose @@ -30,6 +31,10 @@ bool verbose = false; // Default optimization level for conformance test; can be overridden via -On int optimizationLevel = 1; +// Something to seed a pseudorandom number generator with. Defaults to +// something from std::random_device. +std::optional randomSeed; + static bool skipFastFlag(const char* flagName) { if (strncmp(flagName, "Test", 4) == 0) @@ -261,6 +266,16 @@ int main(int argc, char** argv) optimizationLevel = level; } + int rseed = -1; + if (doctest::parseIntOption(argc, argv, "--random-seed=", doctest::option_int, rseed)) + randomSeed = unsigned(rseed); + + if (doctest::parseOption(argc, argv, "--randomize") && !randomSeed) + { + randomSeed = std::random_device()(); + printf("Using RNG seed %u\n", *randomSeed); + } + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) setFastFlags(flags); @@ -295,6 +310,8 @@ int main(int argc, char** argv) printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n"); printf(" --fflags= Sets specified fast flags\n"); printf(" --list-fflags List all fast flags\n"); + printf(" --randomize Use a random RNG seed\n"); + printf(" --random-seed=n Use a particular RNG seed\n"); } return result; } diff --git a/tools/lldb_formatters.lldb b/tools/lldb_formatters.lldb index 4a5acd7..f10faa9 100644 --- a/tools/lldb_formatters.lldb +++ b/tools/lldb_formatters.lldb @@ -5,3 +5,6 @@ type synthetic add -x "^Luau::Variant<.+>$" -l lldb_formatters.LuauVariantSynthe type summary add -x "^Luau::Variant<.+>$" -F lldb_formatters.luau_variant_summary type synthetic add -x "^Luau::AstArray<.+>$" -l lldb_formatters.AstArraySyntheticChildrenProvider + +type summary add --summary-string "${var.line}:${var.column}" Luau::Position +type summary add --summary-string "${var.begin}-${var.end}" Luau::Location diff --git a/tools/natvis/Ast.natvis b/tools/natvis/Ast.natvis index 322eb8f..18e7b76 100644 --- a/tools/natvis/Ast.natvis +++ b/tools/natvis/Ast.natvis @@ -22,4 +22,25 @@ + + {value,na} + + + + {name.value,na} + + + + local {local->name.value,na} + global {global.value,na} + + + + {line}:{column} + + + + {begin}-{end} + + diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 1e3a501..5f1c870 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -22,6 +22,10 @@ def safeParseInt(i, default=0): return default +def makeDottedName(path): + return ".".join(path) + + class Handler(x.ContentHandler): def __init__(self, failList): self.currentTest = [] @@ -41,7 +45,7 @@ class Handler(x.ContentHandler): if self.currentTest: passed = attrs["test_case_success"] == "true" - dottedName = ".".join(self.currentTest) + dottedName = makeDottedName(self.currentTest) # Sometimes we get multiple XML trees for the same test. All of # them must report a pass in order for us to consider the test @@ -60,6 +64,10 @@ class Handler(x.ContentHandler): self.currentTest.pop() +def print_stderr(*args, **kw): + print(*args, **kw, file=sys.stderr) + + def main(): parser = argparse.ArgumentParser( description="Run Luau.UnitTest with deferred constraint resolution enabled" @@ -80,6 +88,16 @@ def main(): help="Write a new faillist.txt after running tests.", ) + parser.add_argument("--randomize", action="store_true", help="Pick a random seed") + + parser.add_argument( + "--random-seed", + action="store", + dest="random_seed", + type=int, + help="Accept a specific RNG seed", + ) + args = parser.parse_args() failList = loadFailList() @@ -90,7 +108,12 @@ def main(): "--fflags=true,DebugLuauDeferredConstraintResolution=true", ] - print('>', ' '.join(commandLine), file=sys.stderr) + if args.random_seed: + commandLine.append("--random-seed=" + str(args.random_seed)) + elif args.randomize: + commandLine.append("--randomize") + + print_stderr(">", " ".join(commandLine)) p = sp.Popen( commandLine, @@ -104,15 +127,21 @@ def main(): sys.stdout.buffer.write(line) return else: - x.parse(p.stdout, handler) + try: + x.parse(p.stdout, handler) + except x.SAXParseException as e: + print_stderr( + f"XML parsing failed during test {makeDottedName(handler.currentTest)}. That probably means that the test crashed" + ) + sys.exit(1) p.wait() for testName, passed in handler.results.items(): if passed and testName in failList: - print("UNEXPECTED: {} should have failed".format(testName)) + print_stderr(f"UNEXPECTED: {testName} should have failed") elif not passed and testName not in failList: - print("UNEXPECTED: {} should have passed".format(testName)) + print_stderr(f"UNEXPECTED: {testName} should have passed") if args.write: newFailList = sorted( @@ -126,14 +155,11 @@ def main(): with open(FAIL_LIST_PATH, "w", newline="\n") as f: for name in newFailList: print(name, file=f) - print("Updated faillist.txt", file=sys.stderr) + print_stderr("Updated faillist.txt") if handler.numSkippedTests > 0: - print( - "{} test(s) were skipped! That probably means that a test segfaulted!".format( - handler.numSkippedTests - ), - file=sys.stderr, + print_stderr( + f"{handler.numSkippedTests} test(s) were skipped! That probably means that a test segfaulted!" ) sys.exit(1) @@ -143,7 +169,7 @@ def main(): ) if ok: - print("Everything in order!", file=sys.stderr) + print_stderr("Everything in order!") sys.exit(0 if ok else 1)