From 721f6e10fbdb25909e020351dc7130393201da8d Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 19 May 2023 12:37:30 -0700 Subject: [PATCH] Sync to upstream/release/577 (#934) Lots of things going on this week: * Fix a crash that could occur in the presence of a cyclic union. We shouldn't be creating cyclic unions, but we shouldn't be crashing when they arise either. * Minor cleanup of `luau_precall` * Internal change to make L->top handling slightly more uniform * Optimize SETGLOBAL & GETGLOBAL fallback C functions. * https://github.com/Roblox/luau/pull/929 * The syntax to the `luau-reduce` commandline tool has changed. It now accepts a script, a command to execute, and an error to search for. It no longer automatically passes the script to the command which makes it a lot more flexible. Also be warned that it edits the script it is passed **in place**. Do not point it at something that is not in source control! New solver * Switch to a greedier but more fallible algorithm for simplifying union and intersection types that are created as part of refinement calculation. This has much better and more predictable performance. * Fix a constraint cycle in recursive function calls. * Much improved inference of binary addition. Functions like `function add(x, y) return x + y end` can now be inferred without annotations. We also accurately typecheck calls to functions like this. * Many small bugfixes surrounding things like table indexers * Add support for indexers on class types. This was previously added to the old solver; we now add it to the new one for feature parity. JIT * https://github.com/Roblox/luau/pull/931 * Fuse key.value and key.tt loads for CEHCK_SLOT_MATCH in A64 * Implement remaining aliases of BFM for A64 * Implement new callinfo flag for A64 * Add instruction simplification for int->num->int conversion chains * Don't even load execdata for X64 calls * Treat opcode fallbacks the same as manually written fallbacks --------- Co-authored-by: Arseny Kapoulkine Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Constraint.h | 44 +- .../include/Luau/ConstraintGraphBuilder.h | 4 +- Analysis/include/Luau/ConstraintSolver.h | 28 +- Analysis/include/Luau/Module.h | 3 - Analysis/include/Luau/Normalize.h | 10 + Analysis/include/Luau/Simplify.h | 36 + Analysis/include/Luau/ToString.h | 5 +- Analysis/include/Luau/TxnLog.h | 6 + Analysis/include/Luau/Type.h | 6 +- Analysis/include/Luau/TypeFamily.h | 26 +- Analysis/include/Luau/TypeReduction.h | 85 - Analysis/src/Autocomplete.cpp | 1 - Analysis/src/Clone.cpp | 31 +- Analysis/src/ConstraintGraphBuilder.cpp | 212 ++- Analysis/src/ConstraintSolver.cpp | 256 ++- Analysis/src/Frontend.cpp | 18 +- Analysis/src/Instantiation.cpp | 4 +- Analysis/src/Module.cpp | 7 +- Analysis/src/Normalize.cpp | 57 +- Analysis/src/Quantify.cpp | 4 +- Analysis/src/Simplify.cpp | 1270 ++++++++++++++ Analysis/src/ToString.cpp | 10 + Analysis/src/TxnLog.cpp | 16 +- Analysis/src/TypeChecker2.cpp | 247 +-- Analysis/src/TypeFamily.cpp | 174 +- Analysis/src/TypeInfer.cpp | 4 +- Analysis/src/TypeReduction.cpp | 1200 ------------- Analysis/src/Unifier.cpp | 4 +- CLI/Reduce.cpp | 53 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 8 +- CodeGen/include/luacodegen.h | 2 +- CodeGen/src/AssemblyBuilderA64.cpp | 38 +- CodeGen/src/CodeBlockUnwind.cpp | 1 + CodeGen/src/CodeGen.cpp | 101 +- CodeGen/src/CodeGenA64.cpp | 27 +- CodeGen/src/CodeGenUtils.cpp | 671 +++++++- CodeGen/src/CodeGenUtils.h | 12 + CodeGen/src/CustomExecUtils.h | 24 - CodeGen/src/EmitCommon.h | 9 +- CodeGen/src/EmitCommonX64.cpp | 6 +- CodeGen/src/EmitCommonX64.h | 6 +- CodeGen/src/EmitInstructionX64.cpp | 23 +- CodeGen/src/Fallbacks.cpp | 639 ------- CodeGen/src/Fallbacks.h | 24 - CodeGen/src/FallbacksProlog.h | 56 - CodeGen/src/IrLoweringA64.cpp | 76 +- CodeGen/src/IrLoweringX64.cpp | 24 +- CodeGen/src/NativeState.cpp | 38 +- CodeGen/src/NativeState.h | 28 +- CodeGen/src/OptimizeConstProp.cpp | 13 + Makefile | 1 + Sources.cmake | 9 +- VM/src/ldo.cpp | 11 +- VM/src/ldo.h | 2 +- VM/src/lfunc.cpp | 3 +- VM/src/lobject.h | 3 +- VM/src/lstate.h | 1 + VM/src/lvmexecute.cpp | 58 +- tests/AssemblyBuilderA64.test.cpp | 16 + tests/Autocomplete.test.cpp | 34 - tests/ClassFixture.cpp | 11 +- tests/ConstraintGraphBuilderFixture.cpp | 3 - tests/IrBuilder.test.cpp | 30 + tests/Module.test.cpp | 29 + tests/Normalize.test.cpp | 6 +- tests/Simplify.test.cpp | 508 ++++++ tests/ToString.test.cpp | 24 +- tests/TxnLog.test.cpp | 11 + tests/TypeFamily.test.cpp | 37 +- tests/TypeInfer.annotations.test.cpp | 12 + tests/TypeInfer.cfa.test.cpp | 5 +- tests/TypeInfer.classes.test.cpp | 13 +- tests/TypeInfer.functions.test.cpp | 20 +- tests/TypeInfer.intersectionTypes.test.cpp | 138 +- tests/TypeInfer.operators.test.cpp | 102 +- tests/TypeInfer.provisional.test.cpp | 59 +- tests/TypeInfer.refinements.test.cpp | 37 +- tests/TypeInfer.tables.test.cpp | 4 +- tests/TypeInfer.test.cpp | 5 +- tests/TypeInfer.typePacks.cpp | 10 + tests/TypeInfer.unionTypes.test.cpp | 51 +- tests/TypeReduction.test.cpp | 1509 ----------------- tests/TypeVar.test.cpp | 1 - tools/faillist.txt | 30 +- tools/lvmexecute_split.py | 112 -- 85 files changed, 4058 insertions(+), 4494 deletions(-) create mode 100644 Analysis/include/Luau/Simplify.h delete mode 100644 Analysis/include/Luau/TypeReduction.h create mode 100644 Analysis/src/Simplify.cpp delete mode 100644 Analysis/src/TypeReduction.cpp delete mode 100644 CodeGen/src/Fallbacks.cpp delete mode 100644 CodeGen/src/Fallbacks.h delete mode 100644 CodeGen/src/FallbacksProlog.h create mode 100644 tests/Simplify.test.cpp delete mode 100644 tests/TypeReduction.test.cpp delete mode 100644 tools/lvmexecute_split.py diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 3aa3c86..c815bef 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -144,6 +144,24 @@ struct HasPropConstraint TypeId resultType; TypeId subjectType; std::string prop; + + // HACK: We presently need types like true|false or string|"hello" when + // deciding whether a particular literal expression should have a singleton + // type. This boolean is set to true when extracting the property type of a + // value that may be a union of tables. + // + // For example, in the following code fragment, we want the lookup of the + // success property to yield true|false when extracting an expectedType in + // this expression: + // + // type Result = {success:true, result: T} | {success:false, error: E} + // + // local r: Result = {success=true, result=9} + // + // If we naively simplify the expectedType to boolean, we will erroneously + // compute the type boolean for the success property of the table literal. + // This causes type checking to fail. + bool suppressSimplification = false; }; // result ~ setProp subjectType ["prop", "prop2", ...] propType @@ -198,6 +216,24 @@ struct UnpackConstraint TypePackId sourcePack; }; +// resultType ~ refine type mode discriminant +// +// Compute type & discriminant (or type | discriminant) as soon as possible (but +// no sooner), simplify, and bind resultType to that type. +struct RefineConstraint +{ + enum + { + Intersection, + Union + } mode; + + TypeId resultType; + + TypeId type; + TypeId discriminant; +}; + // ty ~ reduce ty // // Try to reduce ty, if it is a TypeFamilyInstanceType. Otherwise, do nothing. @@ -214,10 +250,10 @@ struct ReducePackConstraint TypePackId tp; }; -using ConstraintV = - Variant; +using ConstraintV = Variant; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 5800d14..ababe0a 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -188,6 +188,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); + Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); @@ -213,7 +214,8 @@ struct ConstraintGraphBuilder ScopePtr bodyScope; }; - FunctionSignature checkFunctionSignature(const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType = {}); + FunctionSignature checkFunctionSignature( + const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType = {}, std::optional originalName = {}); /** * Checks the body of a function expression. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index f6b1aed..1a43a25 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -8,7 +8,6 @@ #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/Type.h" -#include "Luau/TypeReduction.h" #include "Luau/Variant.h" #include @@ -121,6 +120,7 @@ struct ConstraintSolver bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); bool tryDispatch(const UnpackConstraint& c, NotNull constraint); + bool tryDispatch(const RefineConstraint& c, NotNull constraint, bool force); bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); @@ -132,8 +132,10 @@ struct ConstraintSolver bool tryDispatchIterableFunction( TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); - std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName); - std::pair, std::optional> lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); + std::pair, std::optional> lookupTableProp( + TypeId subjectType, const std::string& propName, bool suppressSimplification = false); + std::pair, std::optional> lookupTableProp( + TypeId subjectType, const std::string& propName, bool suppressSimplification, std::unordered_set& seen); void block(NotNull target, NotNull constraint); /** @@ -143,6 +145,16 @@ struct ConstraintSolver bool block(TypeId target, NotNull constraint); bool block(TypePackId target, NotNull constraint); + // Block on every target + template + bool block(const T& targets, NotNull constraint) + { + for (TypeId target : targets) + block(target, constraint); + + return false; + } + /** * For all constraints that are blocked on one constraint, make them block * on a new constraint. @@ -151,15 +163,15 @@ struct ConstraintSolver */ void inheritBlocks(NotNull source, NotNull addition); - // Traverse the type. If any blocked or pending types are found, block - // the constraint on them. + // Traverse the type. If any pending types are found, block the constraint + // on them. // // Returns false if a type blocks the constraint. // // FIXME: This use of a boolean for the return result is an appalling // interface. - bool recursiveBlock(TypeId target, NotNull constraint); - bool recursiveBlock(TypePackId target, NotNull constraint); + bool blockOnPendingTypes(TypeId target, NotNull constraint); + bool blockOnPendingTypes(TypePackId target, NotNull constraint); void unblock(NotNull progressed); void unblock(TypeId progressed); @@ -255,6 +267,8 @@ private: TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); + TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); + ToStringOptions opts; }; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index b9be820..1fa2e03 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -85,14 +85,11 @@ struct Module DenseHashMap astOverloadResolvedTypes{nullptr}; DenseHashMap astResolvedTypes{nullptr}; - DenseHashMap astOriginalResolvedTypes{nullptr}; DenseHashMap astResolvedTypePacks{nullptr}; // Map AST nodes to the scope they create. Cannot be NotNull because we need a sentinel value for the map. DenseHashMap astScopes{nullptr}; - std::unique_ptr reduction; - std::unordered_map declaredGlobals; ErrorVec errors; LintResult lintResult; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 2ec5406..978ddb4 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -267,8 +267,18 @@ struct NormalizedType NormalizedType(NormalizedType&&) = default; NormalizedType& operator=(NormalizedType&&) = default; + + // IsType functions + + /// Returns true if the type is a subtype of function. This includes any and unknown. + bool isFunction() const; + + /// Returns true if the type is a subtype of number. This includes any and unknown. + bool isNumber() const; }; + + class Normalizer { std::unordered_map> cachedNormals; diff --git a/Analysis/include/Luau/Simplify.h b/Analysis/include/Luau/Simplify.h new file mode 100644 index 0000000..27ed44f --- /dev/null +++ b/Analysis/include/Luau/Simplify.h @@ -0,0 +1,36 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Type.h" + +#include + +namespace Luau +{ + +struct TypeArena; +struct BuiltinTypes; + +struct SimplifyResult +{ + TypeId result; + + std::set blockedTypes; +}; + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); + +enum class Relation +{ + Disjoint, // No A is a B or vice versa + Coincident, // Every A is in B and vice versa + Intersects, // Some As are in B and some Bs are in A. ex (number | string) <-> (string | boolean) + Subset, // Every A is in B + Superset, // Every B is in A +}; + +Relation relate(TypeId left, TypeId right); + +} // namespace Luau diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 7758e8f..dec2c1f 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -99,10 +99,7 @@ inline std::string toString(const Constraint& c, ToStringOptions&& opts) return toString(c, opts); } -inline std::string toString(const Constraint& c) -{ - return toString(c, ToStringOptions{}); -} +std::string toString(const Constraint& c); std::string toString(const Type& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 907908d..951f89e 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -308,6 +308,12 @@ public: // used. Else we use the embedded Scope*. bool useScopes = false; + // It is sometimes the case under DCR that we speculatively rebind + // GenericTypes to other types as though they were free. We mark logs that + // contain these kinds of substitutions as radioactive so that we know that + // we must never commit one. + bool radioactive = false; + // Used to avoid infinite recursion when types are cyclic. // Shared with all the descendent TxnLogs. std::vector>* sharedSeen; diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 80a044c..d42f58b 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -349,7 +349,9 @@ struct FunctionType DcrMagicFunction dcrMagicFunction = nullptr; DcrMagicRefinement dcrMagicRefinement = nullptr; bool hasSelf; - bool hasNoGenerics = false; + // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. + // this flag is used as an optimization to exit early from procedures that manipulate free or generic types. + bool hasNoFreeOrGenericTypes = false; }; enum class TableState @@ -530,7 +532,7 @@ struct ClassType */ struct TypeFamilyInstanceType { - NotNull family; + NotNull family; std::vector typeArguments; std::vector packArguments; diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 4c04f52..bf47de3 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -21,6 +21,7 @@ using TypePackId = const TypePackVar*; struct TypeArena; struct BuiltinTypes; struct TxnLog; +class Normalizer; /// Represents a reduction result, which may have successfully reduced the type, /// may have concretely failed to reduce the type, or may simply be stuck @@ -52,8 +53,8 @@ struct TypeFamily std::string name; /// The reducer function for the type family. - std::function( - std::vector, std::vector, NotNull, NotNull, NotNull log)> + std::function(std::vector, std::vector, NotNull, NotNull, + NotNull, NotNull, NotNull)> reducer; }; @@ -66,8 +67,8 @@ struct TypePackFamily std::string name; /// The reducer function for the type pack family. - std::function( - std::vector, std::vector, NotNull, NotNull, NotNull log)> + std::function(std::vector, std::vector, NotNull, NotNull, + NotNull, NotNull, NotNull)> reducer; }; @@ -93,8 +94,8 @@ struct FamilyGraphReductionResult * against the TxnLog, otherwise substitutions will directly mutate the type * graph. Do not provide the empty TxnLog, as a result. */ -FamilyGraphReductionResult reduceFamilies( - TypeId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log = nullptr, bool force = false); +FamilyGraphReductionResult reduceFamilies(TypeId entrypoint, Location location, NotNull arena, NotNull builtins, + NotNull scope, NotNull normalizer, TxnLog* log = nullptr, bool force = false); /** * Attempt to reduce all instances of any type or type pack family in the type @@ -109,7 +110,16 @@ FamilyGraphReductionResult reduceFamilies( * against the TxnLog, otherwise substitutions will directly mutate the type * graph. Do not provide the empty TxnLog, as a result. */ -FamilyGraphReductionResult reduceFamilies( - TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log = nullptr, bool force = false); +FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, + NotNull scope, NotNull normalizer, TxnLog* log = nullptr, bool force = false); + +struct BuiltinTypeFamilies +{ + BuiltinTypeFamilies(); + + TypeFamily addFamily; +}; + +const BuiltinTypeFamilies kBuiltinTypeFamilies{}; } // namespace Luau diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h deleted file mode 100644 index 3f64870..0000000 --- a/Analysis/include/Luau/TypeReduction.h +++ /dev/null @@ -1,85 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "Luau/Type.h" -#include "Luau/TypeArena.h" -#include "Luau/TypePack.h" -#include "Luau/Variant.h" - -namespace Luau -{ - -namespace detail -{ -template -struct ReductionEdge -{ - T type = nullptr; - bool irreducible = false; -}; - -struct TypeReductionMemoization -{ - TypeReductionMemoization() = default; - - TypeReductionMemoization(const TypeReductionMemoization&) = delete; - TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete; - - TypeReductionMemoization(TypeReductionMemoization&&) = default; - TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default; - - DenseHashMap> types{nullptr}; - DenseHashMap> typePacks{nullptr}; - - bool isIrreducible(TypeId ty); - bool isIrreducible(TypePackId tp); - - TypeId memoize(TypeId ty, TypeId reducedTy); - TypePackId memoize(TypePackId tp, TypePackId reducedTp); - - // Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C. - // Because reduction should always be transitive, A should point to C if A points to B and B points to C. - std::optional> memoizedof(TypeId ty) const; - std::optional> memoizedof(TypePackId tp) const; -}; -} // namespace detail - -struct TypeReductionOptions -{ - /// If it's desirable for type reduction to allocate into a different arena than the TypeReduction instance you have, you will need - /// to create a temporary TypeReduction in that case, and set [`TypeReductionOptions::allowTypeReductionsFromOtherArenas`] to true. - /// This is because TypeReduction caches the reduced type. - bool allowTypeReductionsFromOtherArenas = false; -}; - -struct TypeReduction -{ - explicit TypeReduction(NotNull arena, NotNull builtinTypes, NotNull handle, - const TypeReductionOptions& opts = {}); - - TypeReduction(const TypeReduction&) = delete; - TypeReduction& operator=(const TypeReduction&) = delete; - - TypeReduction(TypeReduction&&) = default; - TypeReduction& operator=(TypeReduction&&) = default; - - std::optional reduce(TypeId ty); - std::optional reduce(TypePackId tp); - std::optional reduce(const TypeFun& fun); - -private: - NotNull arena; - NotNull builtinTypes; - NotNull handle; - - TypeReductionOptions options; - detail::TypeReductionMemoization memoization; - - // Computes an *estimated length* of the cartesian product of the given type. - size_t cartesianProductSize(TypeId ty) const; - - bool hasExceededCartesianProductLimit(TypeId ty) const; - bool hasExceededCartesianProductLimit(TypePackId tp) const; -}; - -} // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 4b66568..8dd7473 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -7,7 +7,6 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include #include diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 0c1b24a..1eb7854 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -11,6 +11,7 @@ LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAGVARIABLE(LuauCloneCyclicUnions, false) namespace Luau { @@ -282,7 +283,7 @@ void TypeCloner::operator()(const FunctionType& t) ftv->argTypes = clone(t.argTypes, dest, cloneState); ftv->argNames = t.argNames; ftv->retTypes = clone(t.retTypes, dest, cloneState); - ftv->hasNoGenerics = t.hasNoGenerics; + ftv->hasNoFreeOrGenericTypes = t.hasNoFreeOrGenericTypes; } void TypeCloner::operator()(const TableType& t) @@ -373,14 +374,30 @@ void TypeCloner::operator()(const AnyType& t) void TypeCloner::operator()(const UnionType& t) { - std::vector options; - options.reserve(t.options.size()); + if (FFlag::LuauCloneCyclicUnions) + { + TypeId result = dest.addType(FreeType{nullptr}); + seenTypes[typeId] = result; - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, cloneState)); + std::vector options; + options.reserve(t.options.size()); - TypeId result = dest.addType(UnionType{std::move(options)}); - seenTypes[typeId] = result; + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, cloneState)); + + asMutable(result)->ty.emplace(std::move(options)); + } + else + { + std::vector options; + options.reserve(t.options.size()); + + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, cloneState)); + + TypeId result = dest.addType(UnionType{std::move(options)}); + seenTypes[typeId] = result; + } } void TypeCloner::operator()(const IntersectionType& t) diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index c8d99ad..b190f4a 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -13,6 +13,9 @@ #include "Luau/Scope.h" #include "Luau/TypeUtils.h" #include "Luau/Type.h" +#include "Luau/TypeFamily.h" +#include "Luau/Simplify.h" +#include "Luau/VisitType.h" #include @@ -195,8 +198,23 @@ struct RefinementPartition using RefinementContext = std::unordered_map; -static void unionRefinements(const RefinementContext& lhs, const RefinementContext& rhs, RefinementContext& dest, NotNull arena) +static void unionRefinements(NotNull builtinTypes, NotNull arena, const RefinementContext& lhs, const RefinementContext& rhs, + RefinementContext& dest, std::vector* constraints) { + const auto intersect = [&](const std::vector& types) { + if (1 == types.size()) + return types[0]; + else if (2 == types.size()) + { + // TODO: It may be advantageous to create a RefineConstraint here when there are blockedTypes. + SimplifyResult sr = simplifyIntersection(builtinTypes, arena, types[0], types[1]); + if (sr.blockedTypes.empty()) + return sr.result; + } + + return arena->addType(IntersectionType{types}); + }; + for (auto& [def, partition] : lhs) { auto rhsIt = rhs.find(def); @@ -206,55 +224,54 @@ static void unionRefinements(const RefinementContext& lhs, const RefinementConte LUAU_ASSERT(!partition.discriminantTypes.empty()); LUAU_ASSERT(!rhsIt->second.discriminantTypes.empty()); - TypeId leftDiscriminantTy = - partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : arena->addType(IntersectionType{partition.discriminantTypes}); + TypeId leftDiscriminantTy = partition.discriminantTypes.size() == 1 ? partition.discriminantTypes[0] : intersect(partition.discriminantTypes); - TypeId rightDiscriminantTy = rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] - : arena->addType(IntersectionType{rhsIt->second.discriminantTypes}); + TypeId rightDiscriminantTy = + rhsIt->second.discriminantTypes.size() == 1 ? rhsIt->second.discriminantTypes[0] : intersect(rhsIt->second.discriminantTypes); - dest[def].discriminantTypes.push_back(arena->addType(UnionType{{leftDiscriminantTy, rightDiscriminantTy}})); + dest[def].discriminantTypes.push_back(simplifyUnion(builtinTypes, arena, leftDiscriminantTy, rightDiscriminantTy).result); dest[def].shouldAppendNilType |= partition.shouldAppendNilType || rhsIt->second.shouldAppendNilType; } } -static void computeRefinement(const ScopePtr& scope, RefinementId refinement, RefinementContext* refis, bool sense, NotNull arena, bool eq, - std::vector* constraints) +static void computeRefinement(NotNull builtinTypes, NotNull arena, const ScopePtr& scope, RefinementId refinement, + RefinementContext* refis, bool sense, bool eq, std::vector* constraints) { if (!refinement) return; else if (auto variadic = get(refinement)) { for (RefinementId refi : variadic->refinements) - computeRefinement(scope, refi, refis, sense, arena, eq, constraints); + computeRefinement(builtinTypes, arena, scope, refi, refis, sense, eq, constraints); } else if (auto negation = get(refinement)) - return computeRefinement(scope, negation->refinement, refis, !sense, arena, eq, constraints); + return computeRefinement(builtinTypes, arena, scope, negation->refinement, refis, !sense, eq, constraints); else if (auto conjunction = get(refinement)) { RefinementContext lhsRefis; RefinementContext rhsRefis; - computeRefinement(scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, arena, eq, constraints); - computeRefinement(scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, arena, eq, constraints); + computeRefinement(builtinTypes, arena, scope, conjunction->lhs, sense ? refis : &lhsRefis, sense, eq, constraints); + computeRefinement(builtinTypes, arena, scope, conjunction->rhs, sense ? refis : &rhsRefis, sense, eq, constraints); if (!sense) - unionRefinements(lhsRefis, rhsRefis, *refis, arena); + unionRefinements(builtinTypes, arena, lhsRefis, rhsRefis, *refis, constraints); } else if (auto disjunction = get(refinement)) { RefinementContext lhsRefis; RefinementContext rhsRefis; - computeRefinement(scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, arena, eq, constraints); - computeRefinement(scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, arena, eq, constraints); + computeRefinement(builtinTypes, arena, scope, disjunction->lhs, sense ? &lhsRefis : refis, sense, eq, constraints); + computeRefinement(builtinTypes, arena, scope, disjunction->rhs, sense ? &rhsRefis : refis, sense, eq, constraints); if (sense) - unionRefinements(lhsRefis, rhsRefis, *refis, arena); + unionRefinements(builtinTypes, arena, lhsRefis, rhsRefis, *refis, constraints); } else if (auto equivalence = get(refinement)) { - computeRefinement(scope, equivalence->lhs, refis, sense, arena, true, constraints); - computeRefinement(scope, equivalence->rhs, refis, sense, arena, true, constraints); + computeRefinement(builtinTypes, arena, scope, equivalence->lhs, refis, sense, true, constraints); + computeRefinement(builtinTypes, arena, scope, equivalence->rhs, refis, sense, true, constraints); } else if (auto proposition = get(refinement)) { @@ -300,6 +317,63 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, Re } } +namespace +{ + +/* + * Constraint generation may be called upon to simplify an intersection or union + * of types that are not sufficiently solved yet. We use + * FindSimplificationBlockers to recognize these types and defer the + * simplification until constraint solution. + */ +struct FindSimplificationBlockers : TypeOnceVisitor +{ + bool found = false; + + bool visit(TypeId) override + { + return !found; + } + + bool visit(TypeId, const BlockedType&) override + { + found = true; + return false; + } + + bool visit(TypeId, const FreeType&) override + { + found = true; + return false; + } + + bool visit(TypeId, const PendingExpansionType&) override + { + found = true; + return false; + } + + // We do not need to know anything at all about a function's argument or + // return types in order to simplify it in an intersection or union. + bool visit(TypeId, const FunctionType&) override + { + return false; + } + + bool visit(TypeId, const ClassType&) override + { + return false; + } +}; + +bool mustDeferIntersection(TypeId ty) +{ + FindSimplificationBlockers bts; + bts.traverse(ty); + return bts.found; +} +} // namespace + void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement) { if (!refinement) @@ -307,7 +381,7 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo RefinementContext refinements; std::vector constraints; - computeRefinement(scope, refinement, &refinements, /*sense*/ true, arena, /*eq*/ false, &constraints); + computeRefinement(builtinTypes, arena, scope, refinement, &refinements, /*sense*/ true, /*eq*/ false, &constraints); for (auto& [def, partition] : refinements) { @@ -317,8 +391,24 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo if (partition.shouldAppendNilType) ty = arena->addType(UnionType{{ty, builtinTypes->nilType}}); - partition.discriminantTypes.push_back(ty); - scope->dcrRefinements[def] = arena->addType(IntersectionType{std::move(partition.discriminantTypes)}); + // Intersect ty with every discriminant type. If either type is not + // sufficiently solved, we queue the intersection up via an + // IntersectConstraint. + + for (TypeId dt : partition.discriminantTypes) + { + if (mustDeferIntersection(ty) || mustDeferIntersection(dt)) + { + TypeId r = arena->addType(BlockedType{}); + addConstraint(scope, location, RefineConstraint{RefineConstraint::Intersection, r, ty, dt}); + + ty = r; + } + else + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + } + + scope->dcrRefinements[def] = ty; } } @@ -708,7 +798,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFun functionType = arena->addType(BlockedType{}); scope->bindings[function->name] = Binding{functionType, function->name->location}; - FunctionSignature sig = checkFunctionSignature(scope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; BreadcrumbId bc = dfg->getBreadcrumb(function->name); @@ -741,10 +831,12 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction TypeId generalizedType = arena->addType(BlockedType{}); Checkpoint start = checkpoint(this); - FunctionSignature sig = checkFunctionSignature(scope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); std::unordered_set excludeList; + const NullableBreadcrumbId functionBreadcrumb = dfg->getBreadcrumb(function->name); + if (AstExprLocal* localName = function->name->as()) { std::optional existingFunctionTy = scope->lookup(localName->local); @@ -759,6 +851,9 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction scope->bindings[localName->local] = Binding{generalizedType, localName->location}; sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; + + if (functionBreadcrumb) + sig.bodyScope->dcrRefinements[functionBreadcrumb->def] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { @@ -769,6 +864,9 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction generalizedType = *existingFunctionTy; sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; + + if (functionBreadcrumb) + sig.bodyScope->dcrRefinements[functionBreadcrumb->def] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { @@ -795,8 +893,8 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction if (generalizedType == nullptr) ice->ice("generalizedType == nullptr", function->location); - if (NullableBreadcrumbId bc = dfg->getBreadcrumb(function->name)) - scope->dcrRefinements[bc->def] = generalizedType; + if (functionBreadcrumb) + scope->dcrRefinements[functionBreadcrumb->def] = generalizedType; checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); @@ -1469,21 +1567,7 @@ Inference ConstraintGraphBuilder::check( else if (auto call = expr->as()) result = flattenPack(scope, expr->location, checkPack(scope, call)); // TODO: needs predicates too else if (auto a = expr->as()) - { - Checkpoint startCheckpoint = checkpoint(this); - FunctionSignature sig = checkFunctionSignature(scope, a, expectedType); - checkFunctionBody(sig.bodyScope, a); - Checkpoint endCheckpoint = checkpoint(this); - - TypeId generalizedTy = arena->addType(BlockedType{}); - NotNull gc = addConstraint(sig.signatureScope, expr->location, GeneralizationConstraint{generalizedTy, sig.signature}); - - forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { - gc->dependencies.emplace_back(constraint.get()); - }); - - result = Inference{generalizedTy}; - } + result = check(scope, a, expectedType); else if (auto indexName = expr->as()) result = check(scope, indexName); else if (auto indexExpr = expr->as()) @@ -1651,6 +1735,23 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexExpr* return Inference{result}; } +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType) +{ + Checkpoint startCheckpoint = checkpoint(this); + FunctionSignature sig = checkFunctionSignature(scope, func, expectedType); + checkFunctionBody(sig.bodyScope, func); + Checkpoint endCheckpoint = checkpoint(this); + + TypeId generalizedTy = arena->addType(BlockedType{}); + NotNull gc = addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature}); + + forEachConstraint(startCheckpoint, endCheckpoint, this, [gc](const ConstraintPtr& constraint) { + gc->dependencies.emplace_back(constraint.get()); + }); + + return Inference{generalizedTy}; +} + Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprUnary* unary) { auto [operandType, refinement] = check(scope, unary->expr); @@ -1667,6 +1768,17 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* bi { auto [leftType, rightType, refinement] = checkBinary(scope, binary, expectedType); + if (binary->op == AstExprBinary::Op::Add) + { + TypeId resultType = arena->addType(TypeFamilyInstanceType{ + NotNull{&kBuiltinTypeFamilies.addFamily}, + {leftType, rightType}, + {}, + }); + addConstraint(scope, binary->location, ReduceConstraint{resultType}); + return Inference{resultType, std::move(refinement)}; + } + TypeId resultType = arena->addType(BlockedType{}); addConstraint(scope, binary->location, BinaryConstraint{binary->op, leftType, rightType, resultType, binary, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); @@ -1686,7 +1798,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* if applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); TypeId elseType = check(elseScope, ifElse->falseExpr, ValueContext::RValue, expectedType).ty; - return Inference{expectedType ? *expectedType : arena->addType(UnionType{{thenType, elseType}})}; + return Inference{expectedType ? *expectedType : simplifyUnion(builtinTypes, arena, thenType, elseType).result}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) @@ -1902,6 +2014,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) } else if (auto indexExpr = e->as()) { + // We need to populate the type for the index value + check(scope, indexExpr->index, ValueContext::RValue); if (auto strIndex = indexExpr->index->as()) { segments.push_back(std::string(strIndex->value.data, strIndex->value.size)); @@ -2018,12 +2132,12 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp else { expectedValueType = arena->addType(BlockedType{}); - addConstraint(scope, item.value->location, HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data}); + addConstraint(scope, item.value->location, + HasPropConstraint{*expectedValueType, *expectedType, stringKey->value.data, /*suppressSimplification*/ true}); } } } - // We'll resolve the expected index result type here with the following priority: // 1. Record table types - in which key, value pairs must be handled on a k,v pair basis. // In this case, the above if-statement will populate expectedValueType @@ -2079,7 +2193,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp } ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature( - const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType) + const ScopePtr& parent, AstExprFunction* fn, std::optional expectedType, std::optional originalName) { ScopePtr signatureScope = nullptr; ScopePtr bodyScope = nullptr; @@ -2235,12 +2349,18 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS // TODO: Preserve argument names in the function's type. FunctionType actualFunction{TypeLevel{}, parent.get(), arena->addTypePack(argTypes, varargPack), returnType}; - actualFunction.hasNoGenerics = !hasGenerics; actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); actualFunction.argNames = std::move(argNames); actualFunction.hasSelf = fn->self != nullptr; + FunctionDefinition defn; + defn.definitionModuleName = module->name; + defn.definitionLocation = fn->location; + defn.varargLocation = fn->vararg ? std::make_optional(fn->varargLocation) : std::nullopt; + defn.originalNameLocation = originalName.value_or(Location(fn->location.begin, 0)); + actualFunction.definition = defn; + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); module->astTypes[fn] = actualFunctionType; @@ -2283,6 +2403,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b if (ref->parameters.size != 1 || !ref->parameters.data[0].type) { reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); + module->astResolvedTypes[ty] = builtinTypes->errorRecoveryType(); return builtinTypes->errorRecoveryType(); } else @@ -2420,7 +2541,6 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b // This replicates the behavior of the appropriate FunctionType // constructors. - ftv.hasNoGenerics = !hasGenerics; ftv.generics = std::move(genericTypes); ftv.genericPacks = std::move(genericTypePacks); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 488fd4b..9c688f4 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -11,12 +11,13 @@ #include "Luau/Metamethods.h" #include "Luau/ModuleResolver.h" #include "Luau/Quantify.h" +#include "Luau/Simplify.h" #include "Luau/ToString.h" -#include "Luau/TypeUtils.h" #include "Luau/Type.h" +#include "Luau/TypeFamily.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/VisitType.h" -#include "Luau/TypeFamily.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAG(LuauRequirePathTrueModuleName) @@ -73,7 +74,7 @@ static std::pair, std::vector> saturateArguments // mutually exclusive with the type pack -> type conversion we do below: // extraTypes will only have elements in it if we have more types than we // have parameter slots for them to go into. - if (!extraTypes.empty()) + if (!extraTypes.empty() && !fn.typePackParams.empty()) { saturatedPackArguments.push_back(arena->addTypePack(extraTypes)); } @@ -89,7 +90,7 @@ static std::pair, std::vector> saturateArguments { saturatedTypeArguments.push_back(*first(tp)); } - else + else if (saturatedPackArguments.size() < fn.typePackParams.size()) { saturatedPackArguments.push_back(tp); } @@ -426,7 +427,9 @@ void ConstraintSolver::finalizeModule() rootScope->returnType = builtinTypes->errorTypePack; } else - rootScope->returnType = *returnType; + { + rootScope->returnType = anyifyModuleReturnTypePackGenerics(*returnType); + } } bool ConstraintSolver::tryDispatch(NotNull constraint, bool force) @@ -468,6 +471,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*sottc, constraint); else if (auto uc = get(*constraint)) success = tryDispatch(*uc, constraint); + else if (auto rc = get(*constraint)) + success = tryDispatch(*rc, constraint, force); else if (auto rc = get(*constraint)) success = tryDispatch(*rc, constraint, force); else if (auto rpc = get(*constraint)) @@ -541,15 +546,25 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNullscope); std::optional instantiated = inst.substitute(c.superType); - LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS LUAU_ASSERT(get(c.subType)); + + if (!instantiated.has_value()) + { + reportError(UnificationTooComplex{}, constraint->location); + + asMutable(c.subType)->ty.emplace(errorRecoveryType()); + unblock(c.subType); + + return true; + } + asMutable(c.subType)->ty.emplace(*instantiated); InstantiationQueuer queuer{constraint->scope, constraint->location, this}; @@ -759,9 +774,11 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullnormalize(leftType); if (hasTypeInIntersection(leftType) && force) asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->numberType); - if (isNumber(leftType)) + if (normLeftTy && normLeftTy->isNumber()) { unify(leftType, rightType, constraint->scope); asMutable(resultType)->ty.emplace(anyPresent ? builtinTypes->anyType : leftType); @@ -770,6 +787,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulladdType(IntersectionType{{builtinTypes->falsyType, leftType}}); + TypeId leftFilteredTy = simplifyIntersection(builtinTypes, arena, leftType, builtinTypes->falsyType).result; - asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); + asMutable(resultType)->ty.emplace(simplifyUnion(builtinTypes, arena, rightType, leftFilteredTy).result); unblock(resultType); return true; } @@ -819,9 +837,9 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNulladdType(IntersectionType{{builtinTypes->truthyType, leftType}}); + TypeId leftFilteredTy = simplifyIntersection(builtinTypes, arena, leftType, builtinTypes->truthyType).result; - asMutable(resultType)->ty.emplace(arena->addType(UnionType{{leftFilteredTy, rightType}})); + asMutable(resultType)->ty.emplace(simplifyUnion(builtinTypes, arena, rightType, leftFilteredTy).result); unblock(resultType); return true; } @@ -1266,7 +1284,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull instantiated = inst.substitute(overload); - LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS + + if (!instantiated.has_value()) + { + reportError(UnificationTooComplex{}, constraint->location); + return true; + } Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; u.enableScopeTests(); @@ -1374,7 +1397,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(followed)) *asMutable(c.resultType) = BoundType{c.discriminantType}; else - *asMutable(c.resultType) = BoundType{builtinTypes->unknownType}; + *asMutable(c.resultType) = BoundType{builtinTypes->anyType}; + + unblock(c.resultType); return true; } @@ -1700,10 +1725,131 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull found; + bool visit(TypeId ty, const BlockedType&) override + { + found.insert(ty); + return false; + } + + bool visit(TypeId ty, const PendingExpansionType&) override + { + found.insert(ty); + return false; + } +}; + +} + +static bool isNegatedAny(TypeId ty) +{ + ty = follow(ty); + const NegationType* nt = get(ty); + if (!nt) + return false; + TypeId negatedTy = follow(nt->ty); + return bool(get(negatedTy)); +} + +bool ConstraintSolver::tryDispatch(const RefineConstraint& c, NotNull constraint, bool force) +{ + if (isBlocked(c.discriminant)) + return block(c.discriminant, constraint); + + FindRefineConstraintBlockers fbt; + fbt.traverse(c.discriminant); + + if (!fbt.found.empty()) + { + bool foundOne = false; + + for (TypeId blocked : fbt.found) + { + if (blocked == c.type) + continue; + + block(blocked, constraint); + foundOne = true; + } + + if (foundOne) + return false; + } + + /* HACK: Refinements sometimes produce a type T & ~any under the assumption + * that ~any is the same as any. This is so so weird, but refinements needs + * some way to say "I may refine this, but I'm not sure." + * + * It does this by refining on a blocked type and deferring the decision + * until it is unblocked. + * + * Refinements also get negated, so we wind up with types like T & ~*blocked* + * + * We need to treat T & ~any as T in this case. + */ + + if (c.mode == RefineConstraint::Intersection && isNegatedAny(c.discriminant)) + { + asMutable(c.resultType)->ty.emplace(c.type); + unblock(c.resultType); + return true; + } + + const TypeId type = follow(c.type); + + LUAU_ASSERT(get(c.resultType)); + + if (type == c.resultType) + { + /* + * Sometimes, we get a constraint of the form + * + * *blocked-N* ~ refine *blocked-N* & U + * + * The constraint essentially states that a particular type is a + * refinement of itself. This is weird and I think vacuous. + * + * I *believe* it is safe to replace the result with a fresh type that + * is constrained by U. We effect this by minting a fresh type for the + * result when U = any, else we bind the result to whatever discriminant + * was offered. + */ + if (get(follow(c.discriminant))) + asMutable(c.resultType)->ty.emplace(constraint->scope); + else + asMutable(c.resultType)->ty.emplace(c.discriminant); + + unblock(c.resultType); + return true; + } + + auto [result, blockedTypes] = c.mode == RefineConstraint::Intersection ? simplifyIntersection(builtinTypes, NotNull{arena}, type, c.discriminant) + : simplifyUnion(builtinTypes, NotNull{arena}, type, c.discriminant); + + if (!force && !blockedTypes.empty()) + return block(blockedTypes, constraint); + + asMutable(c.resultType)->ty.emplace(result); + + unblock(c.resultType); + + return true; +} + bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) { TypeId ty = follow(c.ty); - FamilyGraphReductionResult result = reduceFamilies(ty, constraint->location, NotNull{arena}, builtinTypes, nullptr, force); + FamilyGraphReductionResult result = + reduceFamilies(ty, constraint->location, NotNull{arena}, builtinTypes, constraint->scope, normalizer, nullptr, force); for (TypeId r : result.reducedTypes) unblock(r); @@ -1726,7 +1872,8 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) { TypePackId tp = follow(c.tp); - FamilyGraphReductionResult result = reduceFamilies(tp, constraint->location, NotNull{arena}, builtinTypes, nullptr, force); + FamilyGraphReductionResult result = + reduceFamilies(tp, constraint->location, NotNull{arena}, builtinTypes, constraint->scope, normalizer, nullptr, force); for (TypeId r : result.reducedTypes) unblock(r); @@ -1951,13 +2098,15 @@ bool ConstraintSolver::tryDispatchIterableFunction( return true; } -std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) +std::pair, std::optional> ConstraintSolver::lookupTableProp( + TypeId subjectType, const std::string& propName, bool suppressSimplification) { std::unordered_set seen; - return lookupTableProp(subjectType, propName, seen); + return lookupTableProp(subjectType, propName, suppressSimplification, seen); } -std::pair, std::optional> ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) +std::pair, std::optional> ConstraintSolver::lookupTableProp( + TypeId subjectType, const std::string& propName, bool suppressSimplification, std::unordered_set& seen) { if (!seen.insert(subjectType).second) return {}; @@ -1985,7 +2134,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa } else if (auto mt = get(subjectType)) { - auto [blocked, result] = lookupTableProp(mt->table, propName, seen); + auto [blocked, result] = lookupTableProp(mt->table, propName, suppressSimplification, seen); if (!blocked.empty() || result) return {blocked, result}; @@ -2016,13 +2165,17 @@ std::pair, std::optional> ConstraintSolver::lookupTa } } else - return lookupTableProp(indexType, propName, seen); + return lookupTableProp(indexType, propName, suppressSimplification, seen); } } else if (auto ct = get(subjectType)) { if (auto p = lookupClassProp(ct, propName)) return {{}, p->type()}; + if (ct->indexer) + { + return {{}, ct->indexer->indexResultType}; + } } else if (auto pt = get(subjectType); pt && pt->metatable) { @@ -2033,7 +2186,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (indexProp == metatable->props.end()) return {{}, std::nullopt}; - return lookupTableProp(indexProp->second.type(), propName, seen); + return lookupTableProp(indexProp->second.type(), propName, suppressSimplification, seen); } else if (auto ft = get(subjectType)) { @@ -2054,7 +2207,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa for (TypeId ty : utv) { - auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, suppressSimplification, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) options.insert(*innerResult); @@ -2067,6 +2220,12 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, std::nullopt}; else if (options.size() == 1) return {{}, *begin(options)}; + else if (options.size() == 2 && !suppressSimplification) + { + TypeId one = *begin(options); + TypeId two = *(++begin(options)); + return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; + } else return {{}, arena->addType(UnionType{std::vector(begin(options), end(options))})}; } @@ -2077,7 +2236,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa for (TypeId ty : itv) { - auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, seen); + auto [innerBlocked, innerResult] = lookupTableProp(ty, propName, suppressSimplification, seen); blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); if (innerResult) options.insert(*innerResult); @@ -2090,6 +2249,12 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, std::nullopt}; else if (options.size() == 1) return {{}, *begin(options)}; + else if (options.size() == 2 && !suppressSimplification) + { + TypeId one = *begin(options); + TypeId two = *(++begin(options)); + return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + } else return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; } @@ -2214,13 +2379,6 @@ struct Blocker : TypeOnceVisitor { } - bool visit(TypeId ty, const BlockedType&) - { - blocked = true; - solver->block(ty, constraint); - return false; - } - bool visit(TypeId ty, const PendingExpansionType&) { blocked = true; @@ -2229,14 +2387,14 @@ struct Blocker : TypeOnceVisitor } }; -bool ConstraintSolver::recursiveBlock(TypeId target, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; blocker.traverse(target); return !blocker.blocked; } -bool ConstraintSolver::recursiveBlock(TypePackId pack, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypePackId pack, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; blocker.traverse(pack); @@ -2482,4 +2640,34 @@ TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, return arena->addType(UnionType{types}); } +TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) +{ + tp = follow(tp); + + if (const VariadicTypePack* vtp = get(tp)) + { + TypeId ty = follow(vtp->ty); + return get(ty) ? builtinTypes->anyTypePack : tp; + } + + if (!get(follow(tp))) + return tp; + + std::vector resultTypes; + std::optional resultTail; + + TypePackIterator it = begin(tp); + + for (TypePackIterator e = end(tp); it != e; ++it) + { + TypeId ty = follow(*it); + resultTypes.push_back(get(ty) ? builtinTypes->anyType : ty); + } + + if (std::optional tail = it.tail()) + resultTail = anyifyModuleReturnTypePackGenerics(*tail); + + return arena->addTypePack(resultTypes, resultTail); +} + } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b16eda8..07393eb 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -16,7 +16,6 @@ #include "Luau/TimeTrace.h" #include "Luau/TypeChecker2.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeReduction.h" #include "Luau/Variant.h" #include @@ -622,7 +621,6 @@ CheckResult Frontend::check_DEPRECATED(const ModuleName& name, std::optionalastOriginalCallTypes.clear(); module->astOverloadResolvedTypes.clear(); module->astResolvedTypes.clear(); - module->astOriginalResolvedTypes.clear(); module->astResolvedTypePacks.clear(); module->astScopes.clear(); @@ -1138,7 +1136,6 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) module->astOriginalCallTypes.clear(); module->astOverloadResolvedTypes.clear(); module->astResolvedTypes.clear(); - module->astOriginalResolvedTypes.clear(); module->astResolvedTypePacks.clear(); module->astScopes.clear(); @@ -1311,7 +1308,6 @@ ModulePtr check(const SourceModule& sourceModule, const std::vector(); result->name = sourceModule.name; result->humanReadableName = sourceModule.humanReadableName; - result->reduction = std::make_unique(NotNull{&result->internalTypes}, builtinTypes, iceHandler); std::unique_ptr logger; if (recordJsonLog) @@ -1365,11 +1361,17 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectorinternalTypes); freeze(result->interfaceTypes); diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 7d0f0f7..1d6092f 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -13,7 +13,7 @@ bool Instantiation::isDirty(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) { - if (ftv->hasNoGenerics) + if (ftv->hasNoFreeOrGenericTypes) return false; return true; @@ -74,7 +74,7 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) { if (const FunctionType* ftv = log->getMutable(ty)) { - if (ftv->hasNoGenerics) + if (ftv->hasNoFreeOrGenericTypes) return true; // We aren't recursing in the case of a generic function which diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 830aaf7..0addaa3 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -10,7 +10,6 @@ #include "Luau/Type.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include @@ -20,7 +19,6 @@ LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess2, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); -LUAU_FASTFLAGVARIABLE(LuauCopyExportedTypes, false); namespace Luau { @@ -238,10 +236,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr // Copy external stuff over to Module itself this->returnType = moduleScope->returnType; - if (FFlag::DebugLuauDeferredConstraintResolution || FFlag::LuauCopyExportedTypes) - this->exportedTypeBindings = moduleScope->exportedTypeBindings; - else - this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); + this->exportedTypeBindings = moduleScope->exportedTypeBindings; } bool Module::hasModuleScope() const diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index cfc0ae1..24c31f7 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); -LUAU_FASTFLAGVARIABLE(LuauNormalizeMetatableFixes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(LuauTransitiveSubtyping) @@ -228,6 +227,16 @@ NormalizedType::NormalizedType(NotNull builtinTypes) { } +bool NormalizedType::isFunction() const +{ + return !get(tops) || !functions.parts.empty(); +} + +bool NormalizedType::isNumber() const +{ + return !get(tops) || !get(numbers); +} + static bool isShallowInhabited(const NormalizedType& norm) { // This test is just a shallow check, for example it returns `true` for `{ p : never }` @@ -516,7 +525,8 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty)) || get(ty)); + return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty)) || + get(ty) || get(ty)); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -1366,7 +1376,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) return true; else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || - get(there)) + get(there) || get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return true; @@ -1436,7 +1446,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor } else if (!FFlag::LuauNormalizeBlockedTypes && get(there)) LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); - else if (get(there)) + else if (get(there) || get(there)) { // nothing } @@ -1981,17 +1991,14 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there else if (isPrim(there, PrimitiveType::Table)) return here; - if (FFlag::LuauNormalizeMetatableFixes) - { - if (get(here)) - return there; - else if (get(there)) - return here; - else if (get(here)) - return there; - else if (get(there)) - return here; - } + if (get(here)) + return there; + else if (get(there)) + return here; + else if (get(here)) + return there; + else if (get(there)) + return here; TypeId htable = here; TypeId hmtable = nullptr; @@ -2009,22 +2016,12 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there } const TableType* httv = get(htable); - if (FFlag::LuauNormalizeMetatableFixes) - { - if (!httv) - return std::nullopt; - } - else - LUAU_ASSERT(httv); + if (!httv) + return std::nullopt; const TableType* tttv = get(ttable); - if (FFlag::LuauNormalizeMetatableFixes) - { - if (!tttv) - return std::nullopt; - } - else - LUAU_ASSERT(tttv); + if (!tttv) + return std::nullopt; if (httv->state == TableState::Free || tttv->state == TableState::Free) @@ -2471,7 +2468,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) return true; } else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there)) || - get(there)) + get(there) || get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 5a7a050..3528d53 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -136,7 +136,7 @@ void quantify(TypeId ty, TypeLevel level) ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) - ftv->hasNoGenerics = true; + ftv->hasNoFreeOrGenericTypes = true; } } else @@ -276,7 +276,7 @@ std::optional quantify(TypeArena* arena, TypeId ty, Scope* sco for (auto k : quantifier.insertedGenericPacks.keys) ftv->genericPacks.push_back(quantifier.insertedGenericPacks.pairings[k]); - ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; + ftv->hasNoFreeOrGenericTypes = ftv->generics.empty() && ftv->genericPacks.empty() && !quantifier.seenGenericType && !quantifier.seenMutableType; return std::optional({*result, std::move(quantifier.insertedGenerics), std::move(quantifier.insertedGenericPacks)}); } diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp new file mode 100644 index 0000000..8e9424a --- /dev/null +++ b/Analysis/src/Simplify.cpp @@ -0,0 +1,1270 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Simplify.h" + +#include "Luau/RecursionCounter.h" +#include "Luau/ToString.h" +#include "Luau/TypeArena.h" +#include "Luau/Normalize.h" // TypeIds + +LUAU_FASTINT(LuauTypeReductionRecursionLimit) + +namespace Luau +{ + +struct TypeSimplifier +{ + NotNull builtinTypes; + NotNull arena; + + std::set blockedTypes; + + int recursionDepth = 0; + + TypeId mkNegation(TypeId ty); + + TypeId intersectFromParts(std::set parts); + + TypeId intersectUnionWithType(TypeId unionTy, TypeId right); + TypeId intersectUnions(TypeId left, TypeId right); + TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); + + TypeId intersectTypeWithNegation(TypeId a, TypeId b); + TypeId intersectNegations(TypeId a, TypeId b); + + TypeId intersectIntersectionWithType(TypeId left, TypeId right); + + // Attempt to intersect the two types. Does not recurse. Does not handle + // unions, intersections, or negations. + std::optional basicIntersect(TypeId left, TypeId right); + + TypeId intersect(TypeId ty, TypeId discriminant); + TypeId union_(TypeId ty, TypeId discriminant); + + TypeId simplify(TypeId ty); + TypeId simplify(TypeId ty, DenseHashSet& seen); +}; + +template +static std::pair get2(TID one, TID two) +{ + const A* a = get(one); + const B* b = get(two); + return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr); +} + +// Match the exact type false|nil +static bool isFalsyType(TypeId ty) +{ + ty = follow(ty); + const UnionType* ut = get(ty); + if (!ut) + return false; + + bool hasFalse = false; + bool hasNil = false; + + auto it = begin(ut); + if (it == end(ut)) + return false; + + TypeId t = follow(*it); + + if (auto pt = get(t); pt && pt->type == PrimitiveType::NilType) + hasNil = true; + else if (auto st = get(t); st && st->variant == BooleanSingleton{false}) + hasFalse = true; + else + return false; + + ++it; + if (it == end(ut)) + return false; + + t = follow(*it); + + if (auto pt = get(t); pt && pt->type == PrimitiveType::NilType) + hasNil = true; + else if (auto st = get(t); st && st->variant == BooleanSingleton{false}) + hasFalse = true; + else + return false; + + ++it; + if (it != end(ut)) + return false; + + return hasFalse && hasNil; +} + +// Match the exact type ~(false|nil) +bool isTruthyType(TypeId ty) +{ + ty = follow(ty); + + const NegationType* nt = get(ty); + if (!nt) + return false; + + return isFalsyType(nt->ty); +} + +Relation flip(Relation rel) +{ + switch (rel) + { + case Relation::Subset: + return Relation::Superset; + case Relation::Superset: + return Relation::Subset; + default: + return rel; + } +} + +// FIXME: I'm not completely certain that this function is theoretically reasonable. +Relation combine(Relation a, Relation b) +{ + switch (a) + { + case Relation::Disjoint: + switch (b) + { + case Relation::Disjoint: + return Relation::Disjoint; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Coincident: + switch (b) + { + case Relation::Disjoint: + return Relation::Coincident; + case Relation::Coincident: + return Relation::Coincident; + case Relation::Intersects: + return Relation::Superset; + case Relation::Subset: + return Relation::Coincident; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Superset: + switch (b) + { + case Relation::Disjoint: + return Relation::Superset; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Superset; + } + case Relation::Subset: + switch (b) + { + case Relation::Disjoint: + return Relation::Subset; + case Relation::Coincident: + return Relation::Coincident; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Subset; + case Relation::Superset: + return Relation::Intersects; + } + case Relation::Intersects: + switch (b) + { + case Relation::Disjoint: + return Relation::Intersects; + case Relation::Coincident: + return Relation::Superset; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Intersects; + case Relation::Superset: + return Relation::Intersects; + } + } + + LUAU_UNREACHABLE(); + return Relation::Intersects; +} + +// Given A & B, what is A & ~B? +Relation invert(Relation r) +{ + switch (r) + { + case Relation::Disjoint: + return Relation::Subset; + case Relation::Coincident: + return Relation::Disjoint; + case Relation::Intersects: + return Relation::Intersects; + case Relation::Subset: + return Relation::Disjoint; + case Relation::Superset: + return Relation::Intersects; + } + + LUAU_UNREACHABLE(); + return Relation::Intersects; +} + +static bool isTypeVariable(TypeId ty) +{ + return get(ty) || get(ty) || get(ty) || get(ty); +} + +Relation relate(TypeId left, TypeId right); + +Relation relateTables(TypeId left, TypeId right) +{ + NotNull leftTable{get(left)}; + NotNull rightTable{get(right)}; + LUAU_ASSERT(1 == rightTable->props.size()); + + const auto [propName, rightProp] = *begin(rightTable->props); + + auto it = leftTable->props.find(propName); + if (it == leftTable->props.end()) + { + // Every table lacking a property is a supertype of a table having that + // property but the reverse is not true. + return Relation::Superset; + } + + const Property leftProp = it->second; + + Relation r = relate(leftProp.type(), rightProp.type()); + if (r == Relation::Coincident && 1 != leftTable->props.size()) + { + // eg {tag: "cat", prop: string} & {tag: "cat"} + return Relation::Subset; + } + else + return r; +} + +// A cheap and approximate subtype test +Relation relate(TypeId left, TypeId right) +{ + // TODO nice to have: Relate functions of equal argument and return arity + + left = follow(left); + right = follow(right); + + if (left == right) + return Relation::Coincident; + + if (get(left)) + { + if (get(right)) + return Relation::Subset; + else if (get(right)) + return Relation::Coincident; + else if (get(right)) + return Relation::Disjoint; + else + return Relation::Superset; + } + + if (get(right)) + return flip(relate(right, left)); + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else + return Relation::Superset; + } + + if (get(right)) + return flip(relate(right, left)); + + // Type variables + // * FreeType + // * GenericType + // * BlockedType + // * PendingExpansionType + + // Tops and bottoms + // * ErrorType + // * AnyType + // * NeverType + // * UnknownType + + // Concrete + // * PrimitiveType + // * SingletonType + // * FunctionType + // * TableType + // * MetatableType + // * ClassType + // * UnionType + // * IntersectionType + // * NegationType + + if (isTypeVariable(left) || isTypeVariable(right)) + return Relation::Intersects; + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else if (get(right)) + return Relation::Subset; + else + return Relation::Disjoint; + } + if (get(right)) + return flip(relate(right, left)); + + if (get(left)) + { + if (get(right)) + return Relation::Coincident; + else + return Relation::Subset; + } + if (get(right)) + return flip(relate(right, left)); + + if (auto ut = get(left)) + return Relation::Intersects; + else if (auto ut = get(right)) + return Relation::Intersects; + + if (auto ut = get(left)) + return Relation::Intersects; + else if (auto ut = get(right)) + return Relation::Intersects; + + if (auto rnt = get(right)) + { + Relation a = relate(left, rnt->ty); + switch (a) + { + case Relation::Coincident: + // number & ~number + return Relation::Disjoint; + case Relation::Disjoint: + if (get(left)) + { + // ~number & ~string + return Relation::Intersects; + } + else + { + // number & ~string + return Relation::Subset; + } + case Relation::Intersects: + // ~(false?) & ~boolean + return Relation::Intersects; + case Relation::Subset: + // "hello" & ~string + return Relation::Disjoint; + case Relation::Superset: + // ~function & ~(false?) -> ~function + // boolean & ~(false?) -> true + // string & ~"hello" -> string & ~"hello" + return Relation::Intersects; + } + } + else if (get(left)) + return flip(relate(right, left)); + + if (auto lp = get(left)) + { + if (auto rp = get(right)) + { + if (lp->type == rp->type) + return Relation::Coincident; + else + return Relation::Disjoint; + } + + if (auto rs = get(right)) + { + if (lp->type == PrimitiveType::String && rs->variant.get_if()) + return Relation::Superset; + else if (lp->type == PrimitiveType::Boolean && rs->variant.get_if()) + return Relation::Superset; + else + return Relation::Disjoint; + } + + if (lp->type == PrimitiveType::Function) + { + if (get(right)) + return Relation::Superset; + else + return Relation::Disjoint; + } + if (lp->type == PrimitiveType::Table) + { + if (get(right)) + return Relation::Superset; + else + return Relation::Disjoint; + } + + if (get(right) || get(right) || get(right) || get(right)) + return Relation::Disjoint; + } + + if (auto ls = get(left)) + { + if (get(right) || get(right) || get(right) || get(right)) + return Relation::Disjoint; + + if (get(right)) + return flip(relate(right, left)); + if (auto rs = get(right)) + { + if (ls->variant == rs->variant) + return Relation::Coincident; + else + return Relation::Disjoint; + } + } + + if (get(left)) + { + if (auto rp = get(right)) + { + if (rp->type == PrimitiveType::Function) + return Relation::Subset; + else + return Relation::Disjoint; + } + else + return Relation::Intersects; + } + + if (auto lt = get(left)) + { + if (auto rp = get(right)) + { + if (rp->type == PrimitiveType::Table) + return Relation::Subset; + else + return Relation::Disjoint; + } + else if (auto rt = get(right)) + { + // TODO PROBABLY indexers and metatables. + if (1 == rt->props.size()) + { + Relation r = relateTables(left, right); + /* + * A reduction of these intersections is certainly possible, but + * it would require minting new table types. Also, I don't think + * it's super likely for this to arise from a refinement. + * + * Time will tell! + * + * ex we simplify this + * {tag: string} & {tag: "cat"} + * but not this + * {tag: string, prop: number} & {tag: "cat"} + */ + if (lt->props.size() > 1 && r == Relation::Superset) + return Relation::Intersects; + else + return r; + } + else if (1 == lt->props.size()) + return flip(relate(right, left)); + else + return Relation::Intersects; + } + // TODO metatables + + return Relation::Disjoint; + } + + if (auto ct = get(left)) + { + if (auto rct = get(right)) + { + if (isSubclass(ct, rct)) + return Relation::Subset; + else if (isSubclass(rct, ct)) + return Relation::Superset; + else + return Relation::Disjoint; + } + + return Relation::Disjoint; + } + + return Relation::Intersects; +} + +TypeId TypeSimplifier::mkNegation(TypeId ty) +{ + TypeId result = nullptr; + + if (ty == builtinTypes->truthyType) + result = builtinTypes->falsyType; + else if (ty == builtinTypes->falsyType) + result = builtinTypes->truthyType; + else if (auto ntv = get(ty)) + result = follow(ntv->ty); + else + result = arena->addType(NegationType{ty}); + + return result; +} + +TypeId TypeSimplifier::intersectFromParts(std::set parts) +{ + if (0 == parts.size()) + return builtinTypes->neverType; + else if (1 == parts.size()) + return *begin(parts); + + { + auto it = begin(parts); + while (it != end(parts)) + { + TypeId t = follow(*it); + + auto copy = it; + ++it; + + if (auto ut = get(t)) + { + for (TypeId part : ut) + parts.insert(part); + parts.erase(copy); + } + } + } + + std::set newParts; + + /* + * It is possible that the parts of the passed intersection are themselves + * reducable. + * + * eg false & boolean + * + * We do a comparison between each pair of types and look for things that we + * can elide. + */ + for (TypeId part : parts) + { + if (newParts.empty()) + { + newParts.insert(part); + continue; + } + + auto it = begin(newParts); + while (it != end(newParts)) + { + TypeId p = *it; + + switch (relate(part, p)) + { + case Relation::Disjoint: + // eg boolean & string + return builtinTypes->neverType; + case Relation::Subset: + { + /* part is a subset of p. Remove p from the set and replace it + * with part. + * + * eg boolean & true + */ + auto saveIt = it; + ++it; + newParts.erase(saveIt); + continue; + } + case Relation::Coincident: + case Relation::Superset: + { + /* part is coincident or a superset of p. We do not need to + * include part in the final intersection. + * + * ex true & boolean + */ + ++it; + continue; + } + case Relation::Intersects: + { + /* It's complicated! A simplification may still be possible, + * but we have to pull the types apart to figure it out. + * + * ex boolean & ~false + */ + std::optional simplified = basicIntersect(part, p); + + auto saveIt = it; + ++it; + + if (simplified) + { + newParts.erase(saveIt); + newParts.insert(*simplified); + } + else + newParts.insert(part); + continue; + } + } + } + } + + if (0 == newParts.size()) + return builtinTypes->neverType; + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(IntersectionType{std::vector{begin(newParts), end(newParts)}}); +} + +TypeId TypeSimplifier::intersectUnionWithType(TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + LUAU_ASSERT(leftUnion); + + bool changed = false; + std::set newParts; + + for (TypeId part : leftUnion) + { + TypeId simplified = intersect(right, part); + changed |= simplified != part; + + if (get(simplified)) + { + changed = true; + continue; + } + + newParts.insert(simplified); + } + + if (!changed) + return left; + else if (newParts.empty()) + return builtinTypes->neverType; + else if (newParts.size() == 1) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector(begin(newParts), end(newParts))}); +} + +TypeId TypeSimplifier::intersectUnions(TypeId left, TypeId right) +{ + const UnionType* leftUnion = get(left); + LUAU_ASSERT(leftUnion); + + const UnionType* rightUnion = get(right); + LUAU_ASSERT(rightUnion); + + std::set newParts; + + for (TypeId leftPart : leftUnion) + { + for (TypeId rightPart : rightUnion) + { + TypeId simplified = intersect(leftPart, rightPart); + if (get(simplified)) + continue; + + newParts.insert(simplified); + } + } + + if (newParts.empty()) + return builtinTypes->neverType; + else if (newParts.size() == 1) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector(begin(newParts), end(newParts))}); +} + +TypeId TypeSimplifier::intersectNegatedUnion(TypeId left, TypeId right) +{ + // ~(A | B) & C + // (~A & C) & (~B & C) + + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + TypeId negatedTy = follow(leftNegation->ty); + + const UnionType* negatedUnion = get(negatedTy); + LUAU_ASSERT(negatedUnion); + + bool changed = false; + std::set newParts; + + for (TypeId part : negatedUnion) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Disjoint: + // If A is disjoint from B, then ~A & B is just B. + // + // ~(false?) & true + // (~false & true) & (~nil & true) + // true & true + newParts.insert(right); + break; + case Relation::Coincident: + // If A is coincident with or a superset of B, then ~A & B is never. + // + // ~(false?) & false + // (~false & false) & (~nil & false) + // never & false + // + // fallthrough + case Relation::Superset: + // If A is a superset of B, then ~A & B is never. + // + // ~(boolean | nil) & true + // (~boolean & true) & (~boolean & nil) + // never & nil + return builtinTypes->neverType; + case Relation::Subset: + case Relation::Intersects: + // If A is a subset of B, then ~A & B is a bit more complicated. We need to think harder. + // + // ~(false?) & boolean + // (~false & boolean) & (~nil & boolean) + // true & boolean + TypeId simplified = intersectTypeWithNegation(mkNegation(part), right); + changed |= simplified != right; + if (get(simplified)) + changed = true; + else + newParts.insert(simplified); + break; + } + } + + if (!changed) + return right; + else + return intersectFromParts(std::move(newParts)); +} + +TypeId TypeSimplifier::intersectTypeWithNegation(TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + TypeId negatedTy = follow(leftNegation->ty); + + if (negatedTy == right) + return builtinTypes->neverType; + + if (auto ut = get(negatedTy)) + { + // ~(A | B) & C + // (~A & C) & (~B & C) + + bool changed = false; + std::set newParts; + + for (TypeId part : ut) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Coincident: + // ~(false?) & nil + // (~false & nil) & (~nil & nil) + // nil & never + // + // fallthrough + case Relation::Superset: + // ~(boolean | string) & true + // (~boolean & true) & (~boolean & string) + // never & string + + return builtinTypes->neverType; + + case Relation::Disjoint: + // ~nil & boolean + newParts.insert(right); + break; + + case Relation::Subset: + // ~false & boolean + // fallthrough + case Relation::Intersects: + // FIXME: The mkNegation here is pretty unfortunate. + // Memoizing this will probably be important. + changed = true; + newParts.insert(right); + newParts.insert(mkNegation(part)); + } + } + + if (!changed) + return right; + else + return intersectFromParts(std::move(newParts)); + } + + if (auto rightUnion = get(right)) + { + // ~A & (B | C) + bool changed = false; + std::set newParts; + + for (TypeId part : rightUnion) + { + Relation r = relate(negatedTy, part); + switch (r) + { + case Relation::Coincident: + changed = true; + continue; + case Relation::Disjoint: + newParts.insert(part); + break; + case Relation::Superset: + changed = true; + continue; + case Relation::Subset: + // fallthrough + case Relation::Intersects: + changed = true; + newParts.insert(arena->addType(IntersectionType{{left, part}})); + } + } + + if (!changed) + return right; + else if (0 == newParts.size()) + return builtinTypes->neverType; + else if (1 == newParts.size()) + return *begin(newParts); + else + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + } + + if (auto pt = get(right); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(negatedTy)) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else if (st->variant == BooleanSingleton{false}) + return builtinTypes->trueType; + else + // boolean & ~"hello" + return builtinTypes->booleanType; + } + } + + Relation r = relate(negatedTy, right); + + switch (r) + { + case Relation::Disjoint: + // ~boolean & string + return right; + case Relation::Coincident: + // ~string & string + // fallthrough + case Relation::Superset: + // ~string & "hello" + return builtinTypes->neverType; + case Relation::Subset: + // ~string & unknown + // ~"hello" & string + // fallthrough + case Relation::Intersects: + // ~("hello" | boolean) & string + // fallthrough + default: + return arena->addType(IntersectionType{{left, right}}); + } +} + +TypeId TypeSimplifier::intersectNegations(TypeId left, TypeId right) +{ + const NegationType* leftNegation = get(left); + LUAU_ASSERT(leftNegation); + + if (get(follow(leftNegation->ty))) + return intersectNegatedUnion(left, right); + + const NegationType* rightNegation = get(right); + LUAU_ASSERT(rightNegation); + + if (get(follow(rightNegation->ty))) + return intersectNegatedUnion(right, left); + + Relation r = relate(leftNegation->ty, rightNegation->ty); + + switch (r) + { + case Relation::Coincident: + // ~true & ~true + return left; + case Relation::Subset: + // ~true & ~boolean + return right; + case Relation::Superset: + // ~boolean & ~true + return left; + case Relation::Intersects: + case Relation::Disjoint: + default: + // ~boolean & ~string + return arena->addType(IntersectionType{{left, right}}); + } +} + +TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right) +{ + const IntersectionType* leftIntersection = get(left); + LUAU_ASSERT(leftIntersection); + + bool changed = false; + std::set newParts; + + for (TypeId part : leftIntersection) + { + Relation r = relate(part, right); + switch (r) + { + case Relation::Disjoint: + return builtinTypes->neverType; + case Relation::Coincident: + newParts.insert(part); + continue; + case Relation::Subset: + newParts.insert(part); + continue; + case Relation::Superset: + newParts.insert(right); + changed = true; + continue; + default: + newParts.insert(part); + newParts.insert(right); + changed = true; + continue; + } + } + + // It is sometimes the case that an intersection operation will result in + // clipping a free type from the result. + // + // eg (number & 'a) & string --> never + // + // We want to only report the free types that are part of the result. + for (TypeId part : newParts) + { + if (isTypeVariable(part)) + blockedTypes.insert(part); + } + + if (!changed) + return left; + return intersectFromParts(std::move(newParts)); +} + +std::optional TypeSimplifier::basicIntersect(TypeId left, TypeId right) +{ + if (get(left)) + return right; + if (get(right)) + return left; + if (get(left)) + return left; + if (get(right)) + return right; + + if (auto pt = get(left); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(right); st && st->variant.get_if()) + return right; + if (auto nt = get(right)) + { + if (auto st = get(follow(nt->ty)); st && st->variant.get_if()) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else + return builtinTypes->trueType; + } + } + } + else if (auto pt = get(right); pt && pt->type == PrimitiveType::Boolean) + { + if (auto st = get(left); st && st->variant.get_if()) + return left; + if (auto nt = get(left)) + { + if (auto st = get(follow(nt->ty)); st && st->variant.get_if()) + { + if (st->variant == BooleanSingleton{true}) + return builtinTypes->falseType; + else + return builtinTypes->trueType; + } + } + } + + if (const auto [lt, rt] = get2(left, right); lt && rt) + { + if (1 == lt->props.size()) + { + const auto [propName, leftProp] = *begin(lt->props); + + auto it = rt->props.find(propName); + if (it != rt->props.end()) + { + Relation r = relate(leftProp.type(), it->second.type()); + + switch (r) + { + case Relation::Disjoint: + return builtinTypes->neverType; + case Relation::Coincident: + return right; + default: + break; + } + } + } + else if (1 == rt->props.size()) + return basicIntersect(right, left); + } + + Relation relation = relate(left, right); + if (left == right || Relation::Coincident == relation) + return left; + + if (relation == Relation::Disjoint) + return builtinTypes->neverType; + else if (relation == Relation::Subset) + return left; + else if (relation == Relation::Superset) + return right; + + return std::nullopt; +} + +TypeId TypeSimplifier::intersect(TypeId left, TypeId right) +{ + RecursionLimiter rl(&recursionDepth, 15); + + left = simplify(left); + right = simplify(right); + + if (get(left)) + return right; + if (get(right)) + return left; + if (get(left)) + return left; + if (get(right)) + return right; + + if (isTypeVariable(left)) + { + blockedTypes.insert(left); + return arena->addType(IntersectionType{{left, right}}); + } + + if (isTypeVariable(right)) + { + blockedTypes.insert(right); + return arena->addType(IntersectionType{{left, right}}); + } + + if (auto ut = get(left)) + { + if (get(right)) + return intersectUnions(left, right); + else + return intersectUnionWithType(left, right); + } + else if (auto ut = get(right)) + return intersectUnionWithType(right, left); + + if (auto it = get(left)) + return intersectIntersectionWithType(left, right); + else if (auto it = get(right)) + return intersectIntersectionWithType(right, left); + + if (get(left)) + { + if (get(right)) + return intersectNegations(left, right); + else + return intersectTypeWithNegation(left, right); + } + else if (get(right)) + return intersectTypeWithNegation(right, left); + + std::optional res = basicIntersect(left, right); + if (res) + return *res; + else + return arena->addType(IntersectionType{{left, right}}); +} + +TypeId TypeSimplifier::union_(TypeId left, TypeId right) +{ + RecursionLimiter rl(&recursionDepth, 15); + + left = simplify(left); + right = simplify(right); + + if (auto leftUnion = get(left)) + { + bool changed = false; + std::set newParts; + for (TypeId part : leftUnion) + { + if (get(part)) + { + changed = true; + continue; + } + + Relation r = relate(part, right); + switch (r) + { + case Relation::Coincident: + case Relation::Superset: + return left; + default: + newParts.insert(part); + newParts.insert(right); + changed = true; + break; + } + } + + if (!changed) + return left; + if (1 == newParts.size()) + return *begin(newParts); + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + } + else if (get(right)) + return union_(right, left); + + Relation r = relate(left, right); + if (left == right || r == Relation::Coincident || r == Relation::Superset) + return left; + + if (r == Relation::Subset) + return right; + + if (auto as = get(left)) + { + if (auto abs = as->variant.get_if()) + { + if (auto bs = get(right)) + { + if (auto bbs = bs->variant.get_if()) + { + if (abs->value != bbs->value) + return builtinTypes->booleanType; + } + } + } + } + + return arena->addType(UnionType{{left, right}}); +} + +TypeId TypeSimplifier::simplify(TypeId ty) +{ + DenseHashSet seen{nullptr}; + return simplify(ty, seen); +} + +TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet& seen) +{ + RecursionLimiter limiter(&recursionDepth, 60); + + ty = follow(ty); + + if (seen.find(ty)) + return ty; + seen.insert(ty); + + if (auto nt = get(ty)) + { + TypeId negatedTy = follow(nt->ty); + if (get(negatedTy)) + return builtinTypes->neverType; + else if (get(negatedTy)) + return builtinTypes->anyType; + if (auto nnt = get(negatedTy)) + return simplify(nnt->ty, seen); + } + + // Promote {x: never} to never + if (auto tt = get(ty)) + { + if (1 == tt->props.size()) + { + TypeId propTy = simplify(begin(tt->props)->second.type(), seen); + if (get(propTy)) + return builtinTypes->neverType; + } + } + + return ty; +} + +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) +{ + TypeSimplifier s{builtinTypes, arena}; + + // fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str()); + + TypeId res = s.intersect(left, right); + + // fprintf(stderr, "Intersect %s and %s -> %s\n", toString(left).c_str(), toString(right).c_str(), toString(res).c_str()); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) +{ + TypeSimplifier s{builtinTypes, arena}; + + TypeId res = s.union_(left, right); + + // fprintf(stderr, "Union %s and %s -> %s\n", toString(a).c_str(), toString(b).c_str(), toString(res).c_str()); + + return SimplifyResult{res, std::move(s.blockedTypes)}; +} + +} // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f5b908e..347380c 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1639,6 +1639,11 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) } else if constexpr (std::is_same_v) return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack); + else if constexpr (std::is_same_v) + { + const char* op = c.mode == RefineConstraint::Union ? "union" : "intersect"; + return tos(c.resultType) + " ~ refine " + tos(c.type) + " " + op + " " + tos(c.discriminant); + } else if constexpr (std::is_same_v) return "reduce " + tos(c.ty); else if constexpr (std::is_same_v) @@ -1652,6 +1657,11 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) return visit(go, constraint.c); } +std::string toString(const Constraint& constraint) +{ + return toString(constraint, ToStringOptions{}); +} + std::string dump(const Constraint& c) { ToStringOptions opts; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 53dd3b4..5d38f28 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -82,6 +82,8 @@ void TxnLog::concat(TxnLog rhs) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) @@ -103,6 +105,8 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) @@ -199,10 +203,14 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) for (auto& [tp, rep] : rhs.typePackChanges) typePackChanges[tp] = std::move(rep); + + radioactive |= rhs.radioactive; } void TxnLog::commit() { + LUAU_ASSERT(!radioactive); + for (auto& [ty, rep] : typeVarChanges) { if (!rep->dead) @@ -234,6 +242,8 @@ TxnLog TxnLog::inverse() for (auto& [tp, _rep] : typePackChanges) inversed.typePackChanges[tp] = std::make_unique(*tp); + inversed.radioactive = radioactive; + return inversed; } @@ -293,7 +303,8 @@ void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) PendingType* TxnLog::queue(TypeId ty) { - LUAU_ASSERT(!ty->persistent); + if (ty->persistent) + radioactive = true; // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. @@ -309,7 +320,8 @@ PendingType* TxnLog::queue(TypeId ty) PendingTypePack* TxnLog::queue(TypePackId tp) { - LUAU_ASSERT(!tp->persistent); + if (tp->persistent) + radioactive = true; // Explicitly don't look in ancestors. If we have discovered something new // about this type, we don't want to mutate the parent's state. diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index a1f764a..40376e3 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -13,7 +13,6 @@ #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" -#include "Luau/TypeReduction.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/TypeFamily.h" @@ -21,7 +20,6 @@ #include LUAU_FASTFLAG(DebugLuauMagicTypes) -LUAU_FASTFLAG(DebugLuauDontReduceTypes) namespace Luau { @@ -117,7 +115,7 @@ struct TypeChecker2 TypeId checkForFamilyInhabitance(TypeId instance, Location location) { TxnLog fake{}; - reportErrors(reduceFamilies(instance, location, NotNull{&testArena}, builtinTypes, &fake, true).errors); + reportErrors(reduceFamilies(instance, location, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true).errors); return instance; } @@ -1002,7 +1000,9 @@ struct TypeChecker2 LUAU_ASSERT(ftv); reportErrors(tryUnify(stack.back(), call->location, ftv->retTypes, expectedRetType, CountMismatch::Context::Return, /* genericsOkay */ true)); - reportErrors(reduceFamilies(ftv->retTypes, call->location, NotNull{&testArena}, builtinTypes, &fake, true).errors); + reportErrors( + reduceFamilies(ftv->retTypes, call->location, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true) + .errors); auto it = begin(expectedArgTypes); size_t i = 0; @@ -1020,7 +1020,7 @@ struct TypeChecker2 Location argLoc = argLocs.at(i >= argLocs.size() ? argLocs.size() - 1 : i); reportErrors(tryUnify(stack.back(), argLoc, expectedArg, arg, CountMismatch::Context::Arg, /* genericsOkay */ true)); - reportErrors(reduceFamilies(arg, argLoc, NotNull{&testArena}, builtinTypes, &fake, true).errors); + reportErrors(reduceFamilies(arg, argLoc, NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true).errors); ++it; ++i; @@ -1032,12 +1032,11 @@ struct TypeChecker2 { TypePackId remainingArgs = testArena.addTypePack(TypePack{std::move(slice), std::nullopt}); reportErrors(tryUnify(stack.back(), argLocs.back(), *tail, remainingArgs, CountMismatch::Context::Arg, /* genericsOkay */ true)); - reportErrors(reduceFamilies(remainingArgs, argLocs.back(), NotNull{&testArena}, builtinTypes, &fake, true).errors); + reportErrors(reduceFamilies( + remainingArgs, argLocs.back(), NotNull{&testArena}, builtinTypes, stack.back(), NotNull{&normalizer}, &fake, true) + .errors); } } - - // We do not need to do an arity test because this overload was - // selected based on its arity already matching. } else { @@ -1160,25 +1159,26 @@ struct TypeChecker2 return ty; } - void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy) { visit(expr, ValueContext::RValue); - TypeId leftType = stripFromNilAndReport(lookupType(expr), location); - checkIndexTypeFromType(leftType, propName, location, context); + checkIndexTypeFromType(leftType, propName, location, context, astIndexExprTy); } void visit(AstExprIndexName* indexName, ValueContext context) { - visitExprName(indexName->expr, indexName->location, indexName->index.value, context); + // If we're indexing like _.foo - foo could either be a prop or a string. + visitExprName(indexName->expr, indexName->location, indexName->index.value, context, builtinTypes->stringType); } void visit(AstExprIndexExpr* indexExpr, ValueContext context) { if (auto str = indexExpr->index->as()) { + TypeId astIndexExprType = lookupType(indexExpr->index); const std::string stringValue(str->value.data, str->value.size); - visitExprName(indexExpr->expr, indexExpr->location, stringValue, context); + visitExprName(indexExpr->expr, indexExpr->location, stringValue, context, astIndexExprType); return; } @@ -1198,6 +1198,8 @@ struct TypeChecker2 else reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); } + else if (auto cls = get(exprType); cls && cls->indexer) + reportErrors(tryUnify(scope, indexExpr->index->location, indexType, cls->indexer->indexType)); else if (get(exprType) && isOptional(exprType)) reportError(OptionalValueAccess{exprType}, indexExpr->location); } @@ -1209,32 +1211,52 @@ struct TypeChecker2 visitGenerics(fn->generics, fn->genericPacks); TypeId inferredFnTy = lookupType(fn); - const FunctionType* inferredFtv = get(inferredFnTy); - LUAU_ASSERT(inferredFtv); - // There is no way to write an annotation for the self argument, so we - // cannot do anything to check it. - auto argIt = begin(inferredFtv->argTypes); - if (fn->self) - ++argIt; - - for (const auto& arg : fn->args) + const NormalizedType* normalizedFnTy = normalizer.normalize(inferredFnTy); + if (!normalizedFnTy) { - if (argIt == end(inferredFtv->argTypes)) - break; + reportError(CodeTooComplex{}, fn->location); + } + else if (get(normalizedFnTy->errors)) + { + // Nothing + } + else if (!normalizedFnTy->isFunction()) + { + ice->ice("Internal error: Lambda has non-function type " + toString(inferredFnTy), fn->location); + } + else + { + if (1 != normalizedFnTy->functions.parts.size()) + ice->ice("Unexpected: Lambda has unexpected type " + toString(inferredFnTy), fn->location); - if (arg->annotation) + const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); + LUAU_ASSERT(inferredFtv); + + // There is no way to write an annotation for the self argument, so we + // cannot do anything to check it. + auto argIt = begin(inferredFtv->argTypes); + if (fn->self) + ++argIt; + + for (const auto& arg : fn->args) { - TypeId inferredArgTy = *argIt; - TypeId annotatedArgTy = lookupAnnotation(arg->annotation); + if (argIt == end(inferredFtv->argTypes)) + break; - if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) + if (arg->annotation) { - reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); - } - } + TypeId inferredArgTy = *argIt; + TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - ++argIt; + if (!isSubtype(inferredArgTy, annotatedArgTy, stack.back())) + { + reportError(TypeMismatch{inferredArgTy, annotatedArgTy}, arg->location); + } + } + + ++argIt; + } } visit(fn->body); @@ -1345,6 +1367,10 @@ struct TypeChecker2 TypeId leftType = lookupType(expr->left); TypeId rightType = lookupType(expr->right); + TypeId expectedResult = lookupType(expr); + + if (get(expectedResult)) + return expectedResult; if (expr->op == AstExprBinary::Op::Or) { @@ -1432,7 +1458,11 @@ struct TypeChecker2 TypeId instantiatedMm = module->astOverloadResolvedTypes[key]; if (!instantiatedMm) - reportError(CodeTooComplex{}, expr->location); + { + // reportError(CodeTooComplex{}, expr->location); + // was handled by a type family + return expectedResult; + } else if (const FunctionType* ftv = get(follow(instantiatedMm))) { @@ -1715,7 +1745,7 @@ struct TypeChecker2 { // No further validation is necessary in this case. The main logic for // _luau_print is contained in lookupAnnotation. - if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print" && ty->parameters.size > 0) + if (FFlag::DebugLuauMagicTypes && ty->name == "_luau_print") return; for (const AstTypeOrPack& param : ty->parameters) @@ -1764,6 +1794,7 @@ struct TypeChecker2 if (packsProvided != 0) { reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location); + continue; } if (typesProvided < typesRequired) @@ -1792,7 +1823,11 @@ struct TypeChecker2 if (extraTypes != 0 && packsProvided == 0) { - packsProvided += 1; + // Extra types are only collected into a pack if a pack is expected + if (packsRequired != 0) + packsProvided += 1; + else + typesProvided += extraTypes; } for (size_t i = typesProvided; i < typesRequired; ++i) @@ -1943,69 +1978,6 @@ struct TypeChecker2 } } - void reduceTypes() - { - if (FFlag::DebugLuauDontReduceTypes) - return; - - for (auto [_, scope] : module->scopes) - { - for (auto& [_, b] : scope->bindings) - { - if (auto reduced = module->reduction->reduce(b.typeId)) - b.typeId = *reduced; - } - - if (auto reduced = module->reduction->reduce(scope->returnType)) - scope->returnType = *reduced; - - if (scope->varargPack) - { - if (auto reduced = module->reduction->reduce(*scope->varargPack)) - scope->varargPack = *reduced; - } - - auto reduceMap = [this](auto& map) { - for (auto& [_, tf] : map) - { - if (auto reduced = module->reduction->reduce(tf)) - tf = *reduced; - } - }; - - reduceMap(scope->exportedTypeBindings); - reduceMap(scope->privateTypeBindings); - reduceMap(scope->privateTypePackBindings); - for (auto& [_, space] : scope->importedTypeBindings) - reduceMap(space); - } - - auto reduceOrError = [this](auto& map) { - for (auto [ast, t] : map) - { - if (!t) - continue; // Reminder: this implies that the recursion limit was exceeded. - else if (auto reduced = module->reduction->reduce(t)) - map[ast] = *reduced; - else - reportError(NormalizationTooComplex{}, ast->location); - } - }; - - module->astOriginalResolvedTypes = module->astResolvedTypes; - - // Both [`Module::returnType`] and [`Module::exportedTypeBindings`] are empty here, and - // is populated by [`Module::clonePublicInterface`] in the future, so by that point these - // two aforementioned fields will only contain types that are irreducible. - reduceOrError(module->astTypes); - reduceOrError(module->astTypePacks); - reduceOrError(module->astExpectedTypes); - reduceOrError(module->astOriginalCallTypes); - reduceOrError(module->astOverloadResolvedTypes); - reduceOrError(module->astResolvedTypes); - reduceOrError(module->astResolvedTypePacks); - } - template bool isSubtype(TID subTy, TID superTy, NotNull scope, bool genericsOkay = false) { @@ -2034,6 +2006,9 @@ struct TypeChecker2 void reportError(TypeErrorData data, const Location& location) { + if (auto utk = get_if(&data)) + diagnoseMissingTableKey(utk, data); + module->errors.emplace_back(location, module->name, std::move(data)); if (logger) @@ -2052,7 +2027,7 @@ struct TypeChecker2 } // If the provided type does not have the named property, report an error. - void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, const Location& location, ValueContext context) + void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, const Location& location, ValueContext context, TypeId astIndexExprType) { const NormalizedType* norm = normalizer.normalize(tableTy); if (!norm) @@ -2069,7 +2044,7 @@ struct TypeChecker2 return; std::unordered_set seen; - bool found = hasIndexTypeFromType(ty, prop, location, seen); + bool found = hasIndexTypeFromType(ty, prop, location, seen, astIndexExprType); foundOneProp |= found; if (!found) typesMissingTheProp.push_back(ty); @@ -2129,7 +2104,7 @@ struct TypeChecker2 } } - bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location, std::unordered_set& seen) + bool hasIndexTypeFromType(TypeId ty, const std::string& prop, const Location& location, std::unordered_set& seen, TypeId astIndexExprType) { // If we have already encountered this type, we must assume that some // other codepath will do the right thing and signal false if the @@ -2153,31 +2128,83 @@ struct TypeChecker2 if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) return true; - else if (tt->indexer && isPrim(tt->indexer->indexType, PrimitiveType::String)) - return true; + if (tt->indexer) + { + TypeId indexType = follow(tt->indexer->indexType); + if (isPrim(indexType, PrimitiveType::String)) + return true; + // If the indexer looks like { [any] : _} - the prop lookup should be allowed! + else if (get(indexType) || get(indexType)) + return true; + } - else - return false; + return false; } else if (const ClassType* cls = get(ty)) - return bool(lookupClassProp(cls, prop)); + { + // If the property doesn't exist on the class, we consult the indexer + // We need to check if the type of the index expression foo (x[foo]) + // is compatible with the indexer's indexType + // Construct the intersection and test inhabitedness! + if (auto property = lookupClassProp(cls, prop)) + return true; + if (cls->indexer) + { + TypeId inhabitatedTestType = testArena.addType(IntersectionType{{cls->indexer->indexType, astIndexExprType}}); + return normalizer.isInhabited(inhabitatedTestType); + } + return false; + } else if (const UnionType* utv = get(ty)) return std::all_of(begin(utv), end(utv), [&](TypeId part) { - return hasIndexTypeFromType(part, prop, location, seen); + return hasIndexTypeFromType(part, prop, location, seen, astIndexExprType); }); else if (const IntersectionType* itv = get(ty)) return std::any_of(begin(itv), end(itv), [&](TypeId part) { - return hasIndexTypeFromType(part, prop, location, seen); + return hasIndexTypeFromType(part, prop, location, seen, astIndexExprType); }); else return false; } + + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const + { + std::string_view sv(utk->key); + std::set candidates; + + auto accumulate = [&](const TableType::Props& props) { + for (const auto& [name, ty] : props) + { + if (sv != name && equalsLower(sv, name)) + candidates.insert(name); + } + }; + + if (auto ttv = getTableType(utk->table)) + accumulate(ttv->props); + else if (auto ctv = get(follow(utk->table))) + { + while (ctv) + { + accumulate(ctv->props); + + if (!ctv->parent) + break; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + + if (!candidates.empty()) + data = TypeErrorData(UnknownPropButFoundLikeProp{utk->table, utk->key, candidates}); + } }; void check(NotNull builtinTypes, NotNull unifierState, DcrLogger* logger, const SourceModule& sourceModule, Module* module) { TypeChecker2 typeChecker{builtinTypes, unifierState, logger, &sourceModule, module}; - typeChecker.reduceTypes(); + typeChecker.visit(sourceModule.root); unfreeze(module->interfaceTypes); diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index 1941573..e5a06c0 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -7,6 +7,10 @@ #include "Luau/TxnLog.h" #include "Luau/Substitution.h" #include "Luau/ToString.h" +#include "Luau/TypeUtils.h" +#include "Luau/Unifier.h" +#include "Luau/Instantiation.h" +#include "Luau/Normalize.h" LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); @@ -30,6 +34,11 @@ struct InstanceCollector : TypeOnceVisitor return true; } + bool visit(TypeId ty, const ClassType&) override + { + return false; + } + bool visit(TypePackId tp, const TypeFamilyInstanceTypePack&) override { // TypeOnceVisitor performs a depth-first traversal in the absence of @@ -52,20 +61,24 @@ struct FamilyReducer Location location; NotNull arena; NotNull builtins; - TxnLog* log = nullptr; - NotNull reducerLog; + TxnLog* parentLog = nullptr; + TxnLog log; bool force = false; + NotNull scope; + NotNull normalizer; FamilyReducer(std::deque queuedTys, std::deque queuedTps, Location location, NotNull arena, - NotNull builtins, TxnLog* log = nullptr, bool force = false) + NotNull builtins, NotNull scope, NotNull normalizer, TxnLog* parentLog = nullptr, bool force = false) : queuedTys(std::move(queuedTys)) , queuedTps(std::move(queuedTps)) , location(location) , arena(arena) , builtins(builtins) - , log(log) - , reducerLog(NotNull{log ? log : TxnLog::empty()}) + , parentLog(parentLog) + , log(parentLog) , force(force) + , scope(scope) + , normalizer(normalizer) { } @@ -78,16 +91,16 @@ struct FamilyReducer SkipTestResult testForSkippability(TypeId ty) { - ty = reducerLog->follow(ty); + ty = log.follow(ty); - if (reducerLog->is(ty)) + if (log.is(ty)) { if (!irreducible.contains(ty)) return SkipTestResult::Defer; else return SkipTestResult::Irreducible; } - else if (reducerLog->is(ty)) + else if (log.is(ty)) { return SkipTestResult::Irreducible; } @@ -97,16 +110,16 @@ struct FamilyReducer SkipTestResult testForSkippability(TypePackId ty) { - ty = reducerLog->follow(ty); + ty = log.follow(ty); - if (reducerLog->is(ty)) + if (log.is(ty)) { if (!irreducible.contains(ty)) return SkipTestResult::Defer; else return SkipTestResult::Irreducible; } - else if (reducerLog->is(ty)) + else if (log.is(ty)) { return SkipTestResult::Irreducible; } @@ -117,8 +130,8 @@ struct FamilyReducer template void replace(T subject, T replacement) { - if (log) - log->replace(subject, Unifiable::Bound{replacement}); + if (parentLog) + parentLog->replace(subject, Unifiable::Bound{replacement}); else asMutable(subject)->ty.template emplace>(replacement); @@ -208,37 +221,38 @@ struct FamilyReducer void stepType() { - TypeId subject = reducerLog->follow(queuedTys.front()); + TypeId subject = log.follow(queuedTys.front()); queuedTys.pop_front(); if (irreducible.contains(subject)) return; - if (const TypeFamilyInstanceType* tfit = reducerLog->get(subject)) + if (const TypeFamilyInstanceType* tfit = log.get(subject)) { if (!testParameters(subject, tfit)) return; - TypeFamilyReductionResult result = tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, reducerLog); + TypeFamilyReductionResult result = + tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, NotNull{&log}, scope, normalizer); handleFamilyReduction(subject, result); } } void stepPack() { - TypePackId subject = reducerLog->follow(queuedTps.front()); + TypePackId subject = log.follow(queuedTps.front()); queuedTps.pop_front(); if (irreducible.contains(subject)) return; - if (const TypeFamilyInstanceTypePack* tfit = reducerLog->get(subject)) + if (const TypeFamilyInstanceTypePack* tfit = log.get(subject)) { if (!testParameters(subject, tfit)) return; TypeFamilyReductionResult result = - tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, reducerLog); + tfit->family->reducer(tfit->typeArguments, tfit->packArguments, arena, builtins, NotNull{&log}, scope, normalizer); handleFamilyReduction(subject, result); } } @@ -253,9 +267,9 @@ struct FamilyReducer }; static FamilyGraphReductionResult reduceFamiliesInternal(std::deque queuedTys, std::deque queuedTps, Location location, - NotNull arena, NotNull builtins, TxnLog* log, bool force) + NotNull arena, NotNull builtins, NotNull scope, NotNull normalizer, TxnLog* log, bool force) { - FamilyReducer reducer{std::move(queuedTys), std::move(queuedTps), location, arena, builtins, log, force}; + FamilyReducer reducer{std::move(queuedTys), std::move(queuedTps), location, arena, builtins, scope, normalizer, log, force}; int iterationCount = 0; while (!reducer.done()) @@ -273,8 +287,8 @@ static FamilyGraphReductionResult reduceFamiliesInternal(std::deque queu return std::move(reducer.result); } -FamilyGraphReductionResult reduceFamilies( - TypeId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log, bool force) +FamilyGraphReductionResult reduceFamilies(TypeId entrypoint, Location location, NotNull arena, NotNull builtins, + NotNull scope, NotNull normalizer, TxnLog* log, bool force) { InstanceCollector collector; @@ -287,11 +301,11 @@ FamilyGraphReductionResult reduceFamilies( return FamilyGraphReductionResult{}; } - return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, log, force); + return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, scope, normalizer, log, force); } -FamilyGraphReductionResult reduceFamilies( - TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, TxnLog* log, bool force) +FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location location, NotNull arena, NotNull builtins, + NotNull scope, NotNull normalizer, TxnLog* log, bool force) { InstanceCollector collector; @@ -304,7 +318,113 @@ FamilyGraphReductionResult reduceFamilies( return FamilyGraphReductionResult{}; } - return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, log, force); + return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), location, arena, builtins, scope, normalizer, log, force); +} + +bool isPending(TypeId ty, NotNull log) +{ + return log->is(ty) || log->is(ty) || log->is(ty) || log->is(ty); +} + +TypeFamilyReductionResult addFamilyFn(std::vector typeParams, std::vector packParams, NotNull arena, + NotNull builtins, NotNull log, NotNull scope, NotNull normalizer) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + // TODO: ICE? + LUAU_ASSERT(false); + return {std::nullopt, true, {}, {}}; + } + + TypeId lhsTy = log->follow(typeParams.at(0)); + TypeId rhsTy = log->follow(typeParams.at(1)); + + if (isNumber(lhsTy) && isNumber(rhsTy)) + { + return {builtins->numberType, false, {}, {}}; + } + else if (log->is(lhsTy) || log->is(rhsTy)) + { + return {builtins->anyType, false, {}, {}}; + } + else if (log->is(lhsTy) || log->is(rhsTy)) + { + return {builtins->errorRecoveryType(), false, {}, {}}; + } + else if (log->is(lhsTy) || log->is(rhsTy)) + { + return {builtins->neverType, false, {}, {}}; + } + else if (isPending(lhsTy, log)) + { + return {std::nullopt, false, {lhsTy}, {}}; + } + else if (isPending(rhsTy, log)) + { + return {std::nullopt, false, {rhsTy}, {}}; + } + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional addMm = findMetatableEntry(builtins, dummy, lhsTy, "__add", Location{}); + bool reversed = false; + if (!addMm) + { + addMm = findMetatableEntry(builtins, dummy, rhsTy, "__add", Location{}); + reversed = true; + } + + if (!addMm) + return {std::nullopt, true, {}, {}}; + + if (isPending(log->follow(*addMm), log)) + return {std::nullopt, false, {log->follow(*addMm)}, {}}; + + const FunctionType* mmFtv = log->get(log->follow(*addMm)); + if (!mmFtv) + return {std::nullopt, true, {}, {}}; + + Instantiation instantiation{log.get(), arena.get(), TypeLevel{}, scope.get()}; + if (std::optional instantiatedAddMm = instantiation.substitute(log->follow(*addMm))) + { + if (const FunctionType* instantiatedMmFtv = get(*instantiatedAddMm)) + { + std::vector inferredArgs; + if (!reversed) + inferredArgs = {lhsTy, rhsTy}; + else + inferredArgs = {rhsTy, lhsTy}; + + TypePackId inferredArgPack = arena->addTypePack(std::move(inferredArgs)); + Unifier u{normalizer, Mode::Strict, scope, Location{}, Variance::Covariant, log.get()}; + u.tryUnify(inferredArgPack, instantiatedMmFtv->argTypes); + + if (std::optional ret = first(instantiatedMmFtv->retTypes); ret && u.errors.empty()) + { + return {u.log.follow(*ret), false, {}, {}}; + } + else + { + return {std::nullopt, true, {}, {}}; + } + } + else + { + return {builtins->errorRecoveryType(), false, {}, {}}; + } + } + else + { + // TODO: Not the nicest logic here. + return {std::nullopt, true, {}, {}}; + } +} + +BuiltinTypeFamilies::BuiltinTypeFamilies() + : addFamily{"Add", addFamilyFn} +{ } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 94c64ee..7e68039 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -18,7 +18,6 @@ #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypePack.h" -#include "Luau/TypeReduction.h" #include "Luau/TypeUtils.h" #include "Luau/VisitType.h" @@ -269,7 +268,6 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule.reset(new Module); currentModule->name = module.name; currentModule->humanReadableName = module.humanReadableName; - currentModule->reduction = std::make_unique(NotNull{¤tModule->internalTypes}, builtinTypes, NotNull{iceHandler}); currentModule->type = module.type; currentModule->allocator = module.allocator; currentModule->names = module.names; @@ -4842,7 +4840,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat ty = follow(ty); const FunctionType* ftv = get(ty); - if (ftv && ftv->hasNoGenerics) + if (ftv && ftv->hasNoFreeOrGenericTypes) return ty; Instantiation instantiation{log, ¤tModule->internalTypes, scope->level, /*scope*/ nullptr}; diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp deleted file mode 100644 index b81cca7..0000000 --- a/Analysis/src/TypeReduction.cpp +++ /dev/null @@ -1,1200 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeReduction.h" - -#include "Luau/Common.h" -#include "Luau/Error.h" -#include "Luau/RecursionCounter.h" -#include "Luau/VisitType.h" - -#include -#include - -LUAU_FASTINTVARIABLE(LuauTypeReductionCartesianProductLimit, 100'000) -LUAU_FASTINTVARIABLE(LuauTypeReductionRecursionLimit, 300) -LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) - -namespace Luau -{ - -namespace detail -{ -bool TypeReductionMemoization::isIrreducible(TypeId ty) -{ - ty = follow(ty); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto edge = types.find(ty); edge && edge->irreducible) - return true; - else if (get(ty) || get(ty) || get(ty)) - return false; - else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) - return false; - else - return true; -} - -bool TypeReductionMemoization::isIrreducible(TypePackId tp) -{ - tp = follow(tp); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto edge = typePacks.find(tp); edge && edge->irreducible) - return true; - else if (get(tp) || get(tp)) - return false; - else if (auto vtp = get(tp)) - return isIrreducible(vtp->ty); - else - return true; -} - -TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) -{ - ty = follow(ty); - reducedTy = follow(reducedTy); - - // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. - // We don't need to recurse much further than that, because we already record the irreducibility from - // the bottom up. - bool irreducible = isIrreducible(reducedTy); - if (auto it = get(reducedTy)) - { - for (TypeId part : it) - irreducible &= isIrreducible(part); - } - else if (auto ut = get(reducedTy)) - { - for (TypeId option : ut) - irreducible &= isIrreducible(option); - } - else if (auto tt = get(reducedTy)) - { - for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type()); - - if (tt->indexer) - { - irreducible &= isIrreducible(tt->indexer->indexType); - irreducible &= isIrreducible(tt->indexer->indexResultType); - } - - for (auto ta : tt->instantiatedTypeParams) - irreducible &= isIrreducible(ta); - - for (auto tpa : tt->instantiatedTypePackParams) - irreducible &= isIrreducible(tpa); - } - else if (auto mt = get(reducedTy)) - { - irreducible &= isIrreducible(mt->table); - irreducible &= isIrreducible(mt->metatable); - } - else if (auto ft = get(reducedTy)) - { - irreducible &= isIrreducible(ft->argTypes); - irreducible &= isIrreducible(ft->retTypes); - } - else if (auto nt = get(reducedTy)) - irreducible &= isIrreducible(nt->ty); - - types[ty] = {reducedTy, irreducible}; - types[reducedTy] = {reducedTy, irreducible}; - return reducedTy; -} - -TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp) -{ - tp = follow(tp); - reducedTp = follow(reducedTp); - - bool irreducible = isIrreducible(reducedTp); - TypePackIterator it = begin(tp); - while (it != end(tp)) - { - irreducible &= isIrreducible(*it); - ++it; - } - - if (it.tail()) - irreducible &= isIrreducible(*it.tail()); - - typePacks[tp] = {reducedTp, irreducible}; - typePacks[reducedTp] = {reducedTp, irreducible}; - return reducedTp; -} - -std::optional> TypeReductionMemoization::memoizedof(TypeId ty) const -{ - auto fetchContext = [this](TypeId ty) -> std::optional> { - if (auto edge = types.find(ty)) - return *edge; - else - return std::nullopt; - }; - - TypeId currentTy = ty; - std::optional> lastEdge; - while (auto edge = fetchContext(currentTy)) - { - lastEdge = edge; - if (edge->irreducible) - return edge; - else if (edge->type == currentTy) - return edge; - else - currentTy = edge->type; - } - - return lastEdge; -} - -std::optional> TypeReductionMemoization::memoizedof(TypePackId tp) const -{ - auto fetchContext = [this](TypePackId tp) -> std::optional> { - if (auto edge = typePacks.find(tp)) - return *edge; - else - return std::nullopt; - }; - - TypePackId currentTp = tp; - std::optional> lastEdge; - while (auto edge = fetchContext(currentTp)) - { - lastEdge = edge; - if (edge->irreducible) - return edge; - else if (edge->type == currentTp) - return edge; - else - currentTp = edge->type; - } - - return lastEdge; -} -} // namespace detail - -namespace -{ - -template -std::pair get2(const Thing& one, const Thing& two) -{ - const A* a = get(one); - const B* b = get(two); - return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr); -} - -struct TypeReducer -{ - NotNull arena; - NotNull builtinTypes; - NotNull handle; - NotNull memoization; - DenseHashSet* cyclics; - - int depth = 0; - - TypeId reduce(TypeId ty); - TypePackId reduce(TypePackId tp); - - std::optional intersectionType(TypeId left, TypeId right); - std::optional unionType(TypeId left, TypeId right); - TypeId tableType(TypeId ty); - TypeId functionType(TypeId ty); - TypeId negationType(TypeId ty); - - using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); - using UnaryFold = TypeId (TypeReducer::*)(TypeId); - - template - LUAU_NOINLINE std::pair copy(TypeId ty, const T* t) - { - ty = follow(ty); - - if (auto edge = memoization->memoizedof(ty)) - return {edge->type, getMutable(edge->type)}; - - // We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will - // potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references - // without attempting to recursively reduce it, causing copies of copies of copies of... - TypeId copiedTy = arena->addType(*t); - memoization->types[ty] = {copiedTy, true}; - memoization->types[copiedTy] = {copiedTy, true}; - return {copiedTy, getMutable(copiedTy)}; - } - - template - void foldl_impl(Iter it, Iter endIt, BinaryFold f, std::vector* result, bool* didReduce) - { - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - while (it != endIt) - { - TypeId right = reduce(*it); - *didReduce |= right != follow(*it); - - // We're hitting a case where the `currentTy` returned a type that's the same as `T`. - // e.g. `(string?) & ~(false | nil)` became `(string?) & (~false & ~nil)` but the current iterator we're consuming doesn't know this. - // We will need to recurse and traverse that first. - if (auto t = get(right)) - { - foldl_impl(begin(t), end(t), f, result, didReduce); - ++it; - continue; - } - - bool replaced = false; - auto resultIt = result->begin(); - while (resultIt != result->end()) - { - TypeId left = *resultIt; - if (left == right) - { - replaced = true; - ++resultIt; - continue; - } - - std::optional reduced = (this->*f)(left, right); - if (reduced) - { - *resultIt = *reduced; - ++resultIt; - replaced = true; - } - else - { - ++resultIt; - continue; - } - } - - if (!replaced) - result->push_back(right); - - *didReduce |= replaced; - ++it; - } - } - - template - TypeId flatten(std::vector&& types) - { - if (types.size() == 1) - return types[0]; - else - return arena->addType(T{std::move(types)}); - } - - template - TypeId foldl(Iter it, Iter endIt, std::optional ty, BinaryFold f) - { - std::vector result; - bool didReduce = false; - foldl_impl(it, endIt, f, &result, &didReduce); - - // If we've done any reduction, then we'll need to reduce it again, e.g. - // `"a" | "b" | string` is reduced into `string | string`, which is then reduced into `string`. - if (!didReduce) - return ty ? *ty : flatten(std::move(result)); - else - return reduce(flatten(std::move(result))); - } - - template - TypeId apply(BinaryFold f, TypeId left, TypeId right) - { - std::vector types{left, right}; - return foldl(begin(types), end(types), std::nullopt, f); - } - - template - TypeId distribute(TypeIterator it, TypeIterator endIt, BinaryFold f, TypeId ty) - { - std::vector result; - while (it != endIt) - { - result.push_back(apply(f, *it, ty)); - ++it; - } - return flatten(std::move(result)); - } -}; - -TypeId TypeReducer::reduce(TypeId ty) -{ - ty = follow(ty); - - if (auto edge = memoization->memoizedof(ty)) - { - if (edge->irreducible) - return edge->type; - else - ty = follow(edge->type); - } - else if (cyclics->contains(ty)) - return ty; - - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - TypeId result = nullptr; - if (auto i = get(ty)) - result = foldl(begin(i), end(i), ty, &TypeReducer::intersectionType); - else if (auto u = get(ty)) - result = foldl(begin(u), end(u), ty, &TypeReducer::unionType); - else if (get(ty) || get(ty)) - result = tableType(ty); - else if (get(ty)) - result = functionType(ty); - else if (get(ty)) - result = negationType(ty); - else - result = ty; - - return memoization->memoize(ty, result); -} - -TypePackId TypeReducer::reduce(TypePackId tp) -{ - tp = follow(tp); - - if (auto edge = memoization->memoizedof(tp)) - { - if (edge->irreducible) - return edge->type; - else - tp = edge->type; - } - else if (cyclics->contains(tp)) - return tp; - - RecursionLimiter rl{&depth, FInt::LuauTypeReductionRecursionLimit}; - - bool didReduce = false; - TypePackIterator it = begin(tp); - - std::vector head; - while (it != end(tp)) - { - TypeId reducedTy = reduce(*it); - head.push_back(reducedTy); - didReduce |= follow(*it) != follow(reducedTy); - ++it; - } - - std::optional tail = it.tail(); - if (tail) - { - if (auto vtp = get(follow(*it.tail()))) - { - TypeId reducedTy = reduce(vtp->ty); - if (follow(vtp->ty) != follow(reducedTy)) - { - tail = arena->addTypePack(VariadicTypePack{reducedTy, vtp->hidden}); - didReduce = true; - } - } - } - - if (!didReduce) - return memoization->memoize(tp, tp); - else if (head.empty() && tail) - return memoization->memoize(tp, *tail); - else - return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); -} - -std::optional TypeReducer::intersectionType(TypeId left, TypeId right) -{ - if (get(left)) - return left; // never & T ~ never - else if (get(right)) - return right; // T & never ~ never - else if (get(left)) - return right; // unknown & T ~ T - else if (get(right)) - return left; // T & unknown ~ T - else if (get(left)) - return right; // any & T ~ T - else if (get(right)) - return left; // T & any ~ T - else if (get(left)) - return std::nullopt; // 'a & T ~ 'a & T - else if (get(right)) - return std::nullopt; // T & 'a ~ T & 'a - else if (get(left)) - return std::nullopt; // G & T ~ G & T - else if (get(right)) - return std::nullopt; // T & G ~ T & G - else if (get(left)) - return std::nullopt; // error & T ~ error & T - else if (get(right)) - return std::nullopt; // T & error ~ T & error - else if (get(left)) - return std::nullopt; // *blocked* & T ~ *blocked* & T - else if (get(right)) - return std::nullopt; // T & *blocked* ~ T & *blocked* - else if (get(left)) - return std::nullopt; // *pending* & T ~ *pending* & T - else if (get(right)) - return std::nullopt; // T & *pending* ~ T & *pending* - else if (auto [utl, utr] = get2(left, right); utl && utr) - { - std::vector parts; - for (TypeId optionl : utl) - { - for (TypeId optionr : utr) - parts.push_back(apply(&TypeReducer::intersectionType, optionl, optionr)); - } - - return reduce(flatten(std::move(parts))); // (T | U) & (A | B) ~ (T & A) | (T & B) | (U & A) | (U & B) - } - else if (auto ut = get(left)) - return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) - else if (get(right)) - return intersectionType(right, left); // T & (A | B) ~ (A | B) & T - else if (auto [p1, p2] = get2(left, right); p1 && p2) - { - if (p1->type == p2->type) - return left; // P1 & P2 ~ P1 iff P1 == P2 - else - return builtinTypes->neverType; // P1 & P2 ~ never iff P1 != P2 - } - else if (auto [p, s] = get2(left, right); p && s) - { - if (p->type == PrimitiveType::String && get(s)) - return right; // string & "A" ~ "A" - else if (p->type == PrimitiveType::Boolean && get(s)) - return right; // boolean & true ~ true - else - return builtinTypes->neverType; // string & true ~ never - } - else if (auto [s, p] = get2(left, right); s && p) - return intersectionType(right, left); // S & P ~ P & S - else if (auto [p, f] = get2(left, right); p && f) - { - if (p->type == PrimitiveType::Function) - return right; // function & () -> () ~ () -> () - else - return builtinTypes->neverType; // string & () -> () ~ never - } - else if (auto [f, p] = get2(left, right); f && p) - return intersectionType(right, left); // () -> () & P ~ P & () -> () - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return right; // table & {} ~ {} - else - return builtinTypes->neverType; // string & {} ~ never - } - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return right; // table & {} ~ {} - else - return builtinTypes->neverType; // string & {} ~ never - } - else if (auto [t, p] = get2(left, right); t && p) - return intersectionType(right, left); // {} & P ~ P & {} - else if (auto [t, p] = get2(left, right); t && p) - return intersectionType(right, left); // M & P ~ P & M - else if (auto [s1, s2] = get2(left, right); s1 && s2) - { - if (*s1 == *s2) - return left; // "a" & "a" ~ "a" - else - return builtinTypes->neverType; // "a" & "b" ~ never - } - else if (auto [c1, c2] = get2(left, right); c1 && c2) - { - if (isSubclass(c1, c2)) - return left; // Derived & Base ~ Derived - else if (isSubclass(c2, c1)) - return right; // Base & Derived ~ Derived - else - return builtinTypes->neverType; // Base & Unrelated ~ never - } - else if (auto [f1, f2] = get2(left, right); f1 && f2) - return std::nullopt; // TODO - else if (auto [t1, t2] = get2(left, right); t1 && t2) - { - if (t1->state == TableState::Free || t2->state == TableState::Free) - return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - else if (t1->state == TableState::Generic || t2->state == TableState::Generic) - return std::nullopt; // '{ x: T } & { x: U } ~ '{ x: T } & { x: U } - - if (cyclics->contains(left)) - return std::nullopt; // (t1 where t1 = { p: t1 }) & {} ~ t1 & {} - else if (cyclics->contains(right)) - return std::nullopt; // {} & (t1 where t1 = { p: t1 }) ~ {} & t1 - - TypeId resultTy = arena->addType(TableType{}); - TableType* table = getMutable(resultTy); - table->state = t1->state == TableState::Sealed || t2->state == TableState::Sealed ? TableState::Sealed : TableState::Unsealed; - - for (const auto& [name, prop] : t1->props) - { - // TODO: when t1 has properties, we should also intersect that with the indexer in t2 if it exists, - // even if we have the corresponding property in the other one. - if (auto other = t2->props.find(name); other != t2->props.end()) - { - TypeId propTy = apply(&TypeReducer::intersectionType, prop.type(), other->second.type()); - if (get(propTy)) - return builtinTypes->neverType; // { p : string } & { p : number } ~ { p : string & number } ~ { p : never } ~ never - else - table->props[name] = {propTy}; // { p : string } & { p : ~"a" } ~ { p : string & ~"a" } - } - else - table->props[name] = prop; // { p : string } & {} ~ { p : string } - } - - for (const auto& [name, prop] : t2->props) - { - // TODO: And vice versa, t2 properties against t1 indexer if it exists, - // even if we have the corresponding property in the other one. - if (!t1->props.count(name)) - table->props[name] = {reduce(prop.type())}; // {} & { p : string & string } ~ { p : string } - } - - if (t1->indexer && t2->indexer) - { - TypeId keyTy = apply(&TypeReducer::intersectionType, t1->indexer->indexType, t2->indexer->indexType); - if (get(keyTy)) - return std::nullopt; // { [string]: _ } & { [number]: _ } ~ { [string]: _ } & { [number]: _ } - - TypeId valueTy = apply(&TypeReducer::intersectionType, t1->indexer->indexResultType, t2->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { [string]: number } & { [string]: string } ~ { [string]: never } - } - else if (t1->indexer) - { - TypeId keyTy = reduce(t1->indexer->indexType); - TypeId valueTy = reduce(t1->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { [number]: boolean } & { p : string } ~ { p : string, [number]: boolean } - } - else if (t2->indexer) - { - TypeId keyTy = reduce(t2->indexer->indexType); - TypeId valueTy = reduce(t2->indexer->indexResultType); - table->indexer = TableIndexer{keyTy, valueTy}; // { p : string } & { [number]: boolean } ~ { p : string, [number]: boolean } - } - - return resultTy; - } - else if (auto [mt, tt] = get2(left, right); mt && tt) - return std::nullopt; // TODO - else if (auto [tt, mt] = get2(left, right); tt && mt) - return intersectionType(right, left); // T & M ~ M & T - else if (auto [m1, m2] = get2(left, right); m1 && m2) - return std::nullopt; // TODO - else if (auto [nl, nr] = get2(left, right); nl && nr) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - TypeId nrTy = follow(nr->ty); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - - if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) - { - if (npl->type == npr->type) - return left; // ~P1 & ~P2 ~ ~P1 iff P1 == P2 - else - return std::nullopt; // ~P1 & ~P2 ~ ~P1 & ~P2 iff P1 != P2 - } - else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) - { - if (*nsl == *nsr) - return left; // ~"A" & ~"A" ~ ~"A" - else - return std::nullopt; // ~"A" & ~"B" ~ ~"A" & ~"B" - } - else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) - { - if (get(ns) && np->type == PrimitiveType::String) - return right; // ~"A" & ~string ~ ~string - else if (get(ns) && np->type == PrimitiveType::Boolean) - return right; // ~false & ~boolean ~ ~boolean - else - return std::nullopt; // ~"A" | ~P ~ ~"A" & ~P - } - else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) - return intersectionType(right, left); // ~P & ~S ~ ~S & ~P - else - return std::nullopt; // ~T & ~U ~ ~T & ~U - } - else if (auto nl = get(left)) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - - if (auto [np, p] = get2(nlTy, right); np && p) - { - if (np->type == p->type) - return builtinTypes->neverType; // ~P1 & P2 ~ never iff P1 == P2 - else - return right; // ~P1 & P2 ~ P2 iff P1 != P2 - } - else if (auto [ns, s] = get2(nlTy, right); ns && s) - { - if (*ns == *s) - return builtinTypes->neverType; // ~"A" & "A" ~ never - else - return right; // ~"A" & "B" ~ "B" - } - else if (auto [ns, p] = get2(nlTy, right); ns && p) - { - if (get(ns) && p->type == PrimitiveType::String) - return std::nullopt; // ~"A" & string ~ ~"A" & string - else if (get(ns) && p->type == PrimitiveType::Boolean) - { - // Because booleans contain a fixed amount of values (2), we can do something cooler with this one. - const BooleanSingleton* b = get(ns); - return arena->addType(SingletonType{BooleanSingleton{!b->value}}); // ~false & boolean ~ true - } - else - return right; // ~"A" & number ~ number - } - else if (auto [np, s] = get2(nlTy, right); np && s) - { - if (np->type == PrimitiveType::String && get(s)) - return builtinTypes->neverType; // ~string & "A" ~ never - else if (np->type == PrimitiveType::Boolean && get(s)) - return builtinTypes->neverType; // ~boolean & true ~ never - else - return right; // ~P & "A" ~ "A" - } - else if (auto [np, f] = get2(nlTy, right); np && f) - { - if (np->type == PrimitiveType::Function) - return builtinTypes->neverType; // ~function & () -> () ~ never - else - return right; // ~string & () -> () ~ () -> () - } - else if (auto [nc, c] = get2(nlTy, right); nc && c) - { - if (isSubclass(c, nc)) - return builtinTypes->neverType; // ~Base & Derived ~ never - else if (isSubclass(nc, c)) - return std::nullopt; // ~Derived & Base ~ ~Derived & Base - else - return right; // ~Base & Unrelated ~ Unrelated - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return builtinTypes->neverType; // ~table & {} ~ never - else - return right; // ~string & {} ~ {} - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return builtinTypes->neverType; // ~table & {} ~ never - else - return right; // ~string & {} ~ {} - } - else - return right; // ~T & U ~ U - } - else if (get(right)) - return intersectionType(right, left); // T & ~U ~ ~U & T - else - return builtinTypes->neverType; // for all T and U except the ones handled above, T & U ~ never -} - -std::optional TypeReducer::unionType(TypeId left, TypeId right) -{ - LUAU_ASSERT(!get(left)); - LUAU_ASSERT(!get(right)); - - if (get(left)) - return right; // never | T ~ T - else if (get(right)) - return left; // T | never ~ T - else if (get(left)) - return left; // unknown | T ~ unknown - else if (get(right)) - return right; // T | unknown ~ unknown - else if (get(left)) - return left; // any | T ~ any - else if (get(right)) - return right; // T | any ~ any - else if (get(left)) - return std::nullopt; // error | T ~ error | T - else if (get(right)) - return std::nullopt; // T | error ~ T | error - else if (auto [p1, p2] = get2(left, right); p1 && p2) - { - if (p1->type == p2->type) - return left; // P1 | P2 ~ P1 iff P1 == P2 - else - return std::nullopt; // P1 | P2 ~ P1 | P2 iff P1 != P2 - } - else if (auto [p, s] = get2(left, right); p && s) - { - if (p->type == PrimitiveType::String && get(s)) - return left; // string | "A" ~ string - else if (p->type == PrimitiveType::Boolean && get(s)) - return left; // boolean | true ~ boolean - else - return std::nullopt; // string | true ~ string | true - } - else if (auto [s, p] = get2(left, right); s && p) - return unionType(right, left); // S | P ~ P | S - else if (auto [p, f] = get2(left, right); p && f) - { - if (p->type == PrimitiveType::Function) - return left; // function | () -> () ~ function - else - return std::nullopt; // P | () -> () ~ P | () -> () - } - else if (auto [f, p] = get2(left, right); f && p) - return unionType(right, left); // () -> () | P ~ P | () -> () - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return left; // table | {} ~ table - else - return std::nullopt; // P | {} ~ P | {} - } - else if (auto [p, t] = get2(left, right); p && t) - { - if (p->type == PrimitiveType::Table) - return left; // table | {} ~ table - else - return std::nullopt; // P | {} ~ P | {} - } - else if (auto [t, p] = get2(left, right); t && p) - return unionType(right, left); // {} | P ~ P | {} - else if (auto [t, p] = get2(left, right); t && p) - return unionType(right, left); // M | P ~ P | M - else if (auto [s1, s2] = get2(left, right); s1 && s2) - { - if (*s1 == *s2) - return left; // "a" | "a" ~ "a" - else - return std::nullopt; // "a" | "b" ~ "a" | "b" - } - else if (auto [c1, c2] = get2(left, right); c1 && c2) - { - if (isSubclass(c1, c2)) - return right; // Derived | Base ~ Base - else if (isSubclass(c2, c1)) - return left; // Base | Derived ~ Base - else - return std::nullopt; // Base | Unrelated ~ Base | Unrelated - } - else if (auto [nt, it] = get2(left, right); nt && it) - return reduce(distribute(begin(it), end(it), &TypeReducer::unionType, left)); // ~T | (A & B) ~ (~T | A) & (~T | B) - else if (auto [it, nt] = get2(left, right); it && nt) - return unionType(right, left); // (A & B) | ~T ~ ~T | (A & B) - else if (auto it = get(left)) - { - bool didReduce = false; - std::vector parts; - for (TypeId part : it) - { - auto nt = get(part); - if (!nt) - { - parts.push_back(part); - continue; - } - - auto redex = unionType(part, right); - if (redex && get(*redex)) - { - didReduce = true; - continue; - } - - parts.push_back(part); - } - - if (didReduce) - return flatten(std::move(parts)); // (T & ~nil) | nil ~ T - else - return std::nullopt; // (T & ~nil) | U - } - else if (get(right)) - return unionType(right, left); // A | (T & U) ~ (T & U) | A - else if (auto [nl, nr] = get2(left, right); nl && nr) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - TypeId nrTy = follow(nr->ty); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - LUAU_ASSERT(!get(nlTy) && !get(nrTy)); - - if (auto [npl, npr] = get2(nlTy, nrTy); npl && npr) - { - if (npl->type == npr->type) - return left; // ~P1 | ~P2 ~ ~P1 iff P1 == P2 - else - return builtinTypes->unknownType; // ~P1 | ~P2 ~ ~P1 iff P1 != P2 - } - else if (auto [nsl, nsr] = get2(nlTy, nrTy); nsl && nsr) - { - if (*nsl == *nsr) - return left; // ~"A" | ~"A" ~ ~"A" - else - return builtinTypes->unknownType; // ~"A" | ~"B" ~ unknown - } - else if (auto [ns, np] = get2(nlTy, nrTy); ns && np) - { - if (get(ns) && np->type == PrimitiveType::String) - return left; // ~"A" | ~string ~ ~"A" - else if (get(ns) && np->type == PrimitiveType::Boolean) - return left; // ~false | ~boolean ~ ~false - else - return builtinTypes->unknownType; // ~"A" | ~P ~ unknown - } - else if (auto [np, ns] = get2(nlTy, nrTy); np && ns) - return unionType(right, left); // ~P | ~S ~ ~S | ~P - else - return std::nullopt; // TODO! - } - else if (auto nl = get(left)) - { - // These should've been reduced already. - TypeId nlTy = follow(nl->ty); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - LUAU_ASSERT(!get(nlTy)); - - if (auto [np, p] = get2(nlTy, right); np && p) - { - if (np->type == p->type) - return builtinTypes->unknownType; // ~P1 | P2 ~ unknown iff P1 == P2 - else - return left; // ~P1 | P2 ~ ~P1 iff P1 != P2 - } - else if (auto [ns, s] = get2(nlTy, right); ns && s) - { - if (*ns == *s) - return builtinTypes->unknownType; // ~"A" | "A" ~ unknown - else - return left; // ~"A" | "B" ~ ~"A" - } - else if (auto [ns, p] = get2(nlTy, right); ns && p) - { - if (get(ns) && p->type == PrimitiveType::String) - return builtinTypes->unknownType; // ~"A" | string ~ unknown - else if (get(ns) && p->type == PrimitiveType::Boolean) - return builtinTypes->unknownType; // ~false | boolean ~ unknown - else - return left; // ~"A" | T ~ ~"A" - } - else if (auto [np, s] = get2(nlTy, right); np && s) - { - if (np->type == PrimitiveType::String && get(s)) - return std::nullopt; // ~string | "A" ~ ~string | "A" - else if (np->type == PrimitiveType::Boolean && get(s)) - { - const BooleanSingleton* b = get(s); - return negationType(arena->addType(SingletonType{BooleanSingleton{!b->value}})); // ~boolean | false ~ ~true - } - else - return left; // ~P | "A" ~ ~P - } - else if (auto [nc, c] = get2(nlTy, right); nc && c) - { - if (isSubclass(c, nc)) - return std::nullopt; // ~Base | Derived ~ ~Base | Derived - else if (isSubclass(nc, c)) - return builtinTypes->unknownType; // ~Derived | Base ~ unknown - else - return left; // ~Base | Unrelated ~ ~Base - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return std::nullopt; // ~table | {} ~ ~table | {} - else - return right; // ~P | {} ~ ~P | {} - } - else if (auto [np, t] = get2(nlTy, right); np && t) - { - if (np->type == PrimitiveType::Table) - return std::nullopt; // ~table | {} ~ ~table | {} - else - return right; // ~P | M ~ ~P | M - } - else - return std::nullopt; // TODO - } - else if (get(right)) - return unionType(right, left); // T | ~U ~ ~U | T - else - return std::nullopt; // for all T and U except the ones handled above, T | U ~ T | U -} - -TypeId TypeReducer::tableType(TypeId ty) -{ - if (auto mt = get(ty)) - { - auto [copiedTy, copied] = copy(ty, mt); - copied->table = reduce(mt->table); - copied->metatable = reduce(mt->metatable); - return copiedTy; - } - else if (auto tt = get(ty)) - { - // Because of `typeof()`, we need to preserve pointer identity of free/unsealed tables so that - // all mutations that occurs on this will be applied without leaking the implementation details. - // As a result, we'll just use the type instead of cloning it if it's free/unsealed. - // - // We could choose to do in-place reductions here, but to be on the safer side, I propose that we do not. - if (tt->state == TableState::Free || tt->state == TableState::Unsealed) - return ty; - - auto [copiedTy, copied] = copy(ty, tt); - - for (auto& [name, prop] : copied->props) - { - TypeId propTy = reduce(prop.type()); - if (get(propTy)) - return builtinTypes->neverType; - else - prop.setType(propTy); - } - - if (copied->indexer) - { - TypeId keyTy = reduce(copied->indexer->indexType); - TypeId valueTy = reduce(copied->indexer->indexResultType); - copied->indexer = TableIndexer{keyTy, valueTy}; - } - - for (TypeId& ty : copied->instantiatedTypeParams) - ty = reduce(ty); - - for (TypePackId& tp : copied->instantiatedTypePackParams) - tp = reduce(tp); - - return copiedTy; - } - else - handle->ice("TypeReducer::tableType expects a TableType or MetatableType"); -} - -TypeId TypeReducer::functionType(TypeId ty) -{ - const FunctionType* f = get(ty); - if (!f) - handle->ice("TypeReducer::functionType expects a FunctionType"); - - // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. - auto [copiedTy, copied] = copy(ty, f); - copied->argTypes = reduce(f->argTypes); - copied->retTypes = reduce(f->retTypes); - return copiedTy; -} - -TypeId TypeReducer::negationType(TypeId ty) -{ - const NegationType* n = get(ty); - if (!n) - return arena->addType(NegationType{ty}); - - TypeId negatedTy = follow(n->ty); - - if (auto nn = get(negatedTy)) - return nn->ty; // ~~T ~ T - else if (get(negatedTy)) - return builtinTypes->unknownType; // ~never ~ unknown - else if (get(negatedTy)) - return builtinTypes->neverType; // ~unknown ~ never - else if (get(negatedTy)) - return builtinTypes->anyType; // ~any ~ any - else if (auto ni = get(negatedTy)) - { - std::vector options; - for (TypeId part : ni) - options.push_back(negationType(arena->addType(NegationType{part}))); - return reduce(flatten(std::move(options))); // ~(T & U) ~ (~T | ~U) - } - else if (auto nu = get(negatedTy)) - { - std::vector parts; - for (TypeId option : nu) - parts.push_back(negationType(arena->addType(NegationType{option}))); - return reduce(flatten(std::move(parts))); // ~(T | U) ~ (~T & ~U) - } - else - return ty; // for all T except the ones handled above, ~T ~ ~T -} - -struct MarkCycles : TypeVisitor -{ - DenseHashSet cyclics{nullptr}; - - void cycle(TypeId ty) override - { - cyclics.insert(follow(ty)); - } - - void cycle(TypePackId tp) override - { - cyclics.insert(follow(tp)); - } - - bool visit(TypeId ty) override - { - return !cyclics.find(follow(ty)); - } - - bool visit(TypePackId tp) override - { - return !cyclics.find(follow(tp)); - } -}; -} // namespace - -TypeReduction::TypeReduction( - NotNull arena, NotNull builtinTypes, NotNull handle, const TypeReductionOptions& opts) - : arena(arena) - , builtinTypes(builtinTypes) - , handle(handle) - , options(opts) -{ -} - -std::optional TypeReduction::reduce(TypeId ty) -{ - ty = follow(ty); - - if (FFlag::DebugLuauDontReduceTypes) - return ty; - else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena) - return ty; - else if (auto edge = memoization.memoizedof(ty)) - { - if (edge->irreducible) - return edge->type; - else - ty = edge->type; - } - else if (hasExceededCartesianProductLimit(ty)) - return std::nullopt; - - try - { - MarkCycles finder; - finder.traverse(ty); - - TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; - return reducer.reduce(ty); - } - catch (const RecursionLimitException&) - { - return std::nullopt; - } -} - -std::optional TypeReduction::reduce(TypePackId tp) -{ - tp = follow(tp); - - if (FFlag::DebugLuauDontReduceTypes) - return tp; - else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena) - return tp; - else if (auto edge = memoization.memoizedof(tp)) - { - if (edge->irreducible) - return edge->type; - else - tp = edge->type; - } - else if (hasExceededCartesianProductLimit(tp)) - return std::nullopt; - - try - { - MarkCycles finder; - finder.traverse(tp); - - TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; - return reducer.reduce(tp); - } - catch (const RecursionLimitException&) - { - return std::nullopt; - } -} - -std::optional TypeReduction::reduce(const TypeFun& fun) -{ - if (FFlag::DebugLuauDontReduceTypes) - return fun; - - // TODO: once we have bounded quantification, we need to be able to reduce the generic bounds. - if (auto reducedTy = reduce(fun.type)) - return TypeFun{fun.typeParams, fun.typePackParams, *reducedTy}; - - return std::nullopt; -} - -size_t TypeReduction::cartesianProductSize(TypeId ty) const -{ - ty = follow(ty); - - auto it = get(follow(ty)); - if (!it) - return 1; - - return std::accumulate(begin(it), end(it), size_t(1), [](size_t acc, TypeId ty) { - if (auto ut = get(ty)) - return acc * std::distance(begin(ut), end(ut)); - else if (get(ty)) - return acc * 0; - else - return acc * 1; - }); -} - -bool TypeReduction::hasExceededCartesianProductLimit(TypeId ty) const -{ - return cartesianProductSize(ty) >= size_t(FInt::LuauTypeReductionCartesianProductLimit); -} - -bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const -{ - TypePackIterator it = begin(tp); - - while (it != end(tp)) - { - if (hasExceededCartesianProductLimit(*it)) - return true; - - ++it; - } - - if (auto tail = it.tail()) - { - if (auto vtp = get(follow(*tail))) - { - if (hasExceededCartesianProductLimit(vtp->ty)) - return true; - } - } - - return false; -} - -} // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 56be404..76428cf 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -447,13 +447,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // "double-report" errors in some cases, like when trying to unify // identical type family instantiations like Add with // Add. - reduceFamilies(superTy, location, NotNull(types), builtinTypes, &log); + reduceFamilies(superTy, location, NotNull(types), builtinTypes, scope, normalizer, &log); superTy = log.follow(superTy); } if (log.get(subTy)) { - reduceFamilies(subTy, location, NotNull(types), builtinTypes, &log); + reduceFamilies(subTy, location, NotNull(types), builtinTypes, scope, normalizer, &log); subTy = log.follow(subTy); } diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp index b7c7801..ffe670b 100644 --- a/CLI/Reduce.cpp +++ b/CLI/Reduce.cpp @@ -56,10 +56,9 @@ struct Reducer ParseResult parseResult; AstStatBlock* root; - std::string tempScriptName; + std::string scriptName; - std::string appName; - std::vector appArgs; + std::string command; std::string_view searchText; Reducer() @@ -99,10 +98,10 @@ struct Reducer } while (true); } - FILE* f = fopen(tempScriptName.c_str(), "w"); + FILE* f = fopen(scriptName.c_str(), "w"); if (!f) { - printf("Unable to open temp script to %s\n", tempScriptName.c_str()); + printf("Unable to open temp script to %s\n", scriptName.c_str()); exit(2); } @@ -113,7 +112,7 @@ struct Reducer if (written != source.size()) { printf("??? %zu %zu\n", written, source.size()); - printf("Unable to write to temp script %s\n", tempScriptName.c_str()); + printf("Unable to write to temp script %s\n", scriptName.c_str()); exit(3); } @@ -142,9 +141,15 @@ struct Reducer { writeTempScript(); - std::string command = appName + " " + escape(tempScriptName); - for (const auto& arg : appArgs) - command += " " + escape(arg); + std::string cmd = command; + while (true) + { + auto pos = cmd.find("{}"); + if (std::string::npos == pos) + break; + + cmd = cmd.substr(0, pos) + escape(scriptName) + cmd.substr(pos + 2); + } #if VERBOSE >= 1 printf("running %s\n", command.c_str()); @@ -424,30 +429,20 @@ struct Reducer } } - void run(const std::string scriptName, const std::string appName, const std::vector& appArgs, std::string_view source, + void run(const std::string scriptName, const std::string command, std::string_view source, std::string_view searchText) { - tempScriptName = scriptName; - if (tempScriptName.substr(tempScriptName.size() - 4) == ".lua") - { - tempScriptName.erase(tempScriptName.size() - 4); - tempScriptName += "-reduced.lua"; - } - else - { - this->tempScriptName = scriptName + "-reduced"; - } + this->scriptName = scriptName; #if 0 // Handy debugging trick: VS Code will update its view of the file in realtime as it is edited. - std::string wheee = "code " + tempScriptName; + std::string wheee = "code " + scriptName; system(wheee.c_str()); #endif - printf("Temp script: %s\n", tempScriptName.c_str()); + printf("Script: %s\n", scriptName.c_str()); - this->appName = appName; - this->appArgs = appArgs; + this->command = command; this->searchText = searchText; parseResult = Parser::parse(source.data(), source.size(), nameTable, allocator, parseOptions); @@ -470,13 +465,14 @@ struct Reducer writeTempScript(/* minify */ true); - printf("Done! Check %s\n", tempScriptName.c_str()); + printf("Done! Check %s\n", scriptName.c_str()); } }; [[noreturn]] void help(const std::vector& args) { - printf("Syntax: %s script application \"search text\" [arguments]\n", args[0].data()); + printf("Syntax: %s script command \"search text\"\n", args[0].data()); + printf(" Within command, use {} as a stand-in for the script being reduced\n"); exit(1); } @@ -484,7 +480,7 @@ int main(int argc, char** argv) { const std::vector args(argv, argv + argc); - if (args.size() < 4) + if (args.size() != 4) help(args); for (size_t i = 1; i < args.size(); ++i) @@ -496,7 +492,6 @@ int main(int argc, char** argv) const std::string scriptName = argv[1]; const std::string appName = argv[2]; const std::string searchText = argv[3]; - const std::vector appArgs(begin(args) + 4, end(args)); std::optional source = readFile(scriptName); @@ -507,5 +502,5 @@ int main(int argc, char** argv) } Reducer reducer; - reducer.run(scriptName, appName, appArgs, *source, searchText); + reducer.run(scriptName, appName, *source, searchText); } diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index e7733cd..09acfb4 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -80,6 +80,12 @@ public: void asr(RegisterA64 dst, RegisterA64 src1, uint8_t src2); void ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2); + // Bitfields + void ubfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void ubfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void sbfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + void sbfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w); + // Load // Note: paired loads are currently omitted for simplicity void ldr(RegisterA64 dst, AddressA64 src); @@ -212,7 +218,7 @@ private: void placeFCMP(const char* name, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t opc); void placeFMOV(const char* name, RegisterA64 dst, double src, uint32_t op); void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op); - void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, uint8_t src2, uint8_t op, int immr, int imms); + void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms); void place(uint32_t word); diff --git a/CodeGen/include/luacodegen.h b/CodeGen/include/luacodegen.h index 1eb185d..654fc2c 100644 --- a/CodeGen/include/luacodegen.h +++ b/CodeGen/include/luacodegen.h @@ -9,7 +9,7 @@ struct lua_State; // returns 1 if Luau code generator is supported, 0 otherwise -LUACODEGEN_API int luau_codegen_supported(); +LUACODEGEN_API int luau_codegen_supported(void); // create an instance of Luau code generator. you must check that this feature is supported using luau_codegen_supported(). LUACODEGEN_API void luau_codegen_create(lua_State* L); diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 23b5b9f..000dc85 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -280,6 +280,42 @@ void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2) placeBFM("ror", dst, src1, src2, 0b00'100111, src1.index, src2); } +void AssemblyBuilderA64::ubfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w) +{ + int size = dst.kind == KindA64::x ? 64 : 32; + LUAU_ASSERT(w > 0 && f + w <= size); + + // f * 100 + w is only used for disassembly printout; in the future we might replace it with two separate fields for readability + placeBFM("ubfiz", dst, src, f * 100 + w, 0b10'100110, (-f) & (size - 1), w - 1); +} + +void AssemblyBuilderA64::ubfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w) +{ + int size = dst.kind == KindA64::x ? 64 : 32; + LUAU_ASSERT(w > 0 && f + w <= size); + + // f * 100 + w is only used for disassembly printout; in the future we might replace it with two separate fields for readability + placeBFM("ubfx", dst, src, f * 100 + w, 0b10'100110, f, f + w - 1); +} + +void AssemblyBuilderA64::sbfiz(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w) +{ + int size = dst.kind == KindA64::x ? 64 : 32; + LUAU_ASSERT(w > 0 && f + w <= size); + + // f * 100 + w is only used for disassembly printout; in the future we might replace it with two separate fields for readability + placeBFM("sbfiz", dst, src, f * 100 + w, 0b00'100110, (-f) & (size - 1), w - 1); +} + +void AssemblyBuilderA64::sbfx(RegisterA64 dst, RegisterA64 src, uint8_t f, uint8_t w) +{ + int size = dst.kind == KindA64::x ? 64 : 32; + LUAU_ASSERT(w > 0 && f + w <= size); + + // f * 100 + w is only used for disassembly printout; in the future we might replace it with two separate fields for readability + placeBFM("sbfx", dst, src, f * 100 + w, 0b00'100110, f, f + w - 1); +} + void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src) { LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w || dst.kind == KindA64::s || dst.kind == KindA64::d || dst.kind == KindA64::q); @@ -1010,7 +1046,7 @@ void AssemblyBuilderA64::placeBM(const char* name, RegisterA64 dst, RegisterA64 commit(); } -void AssemblyBuilderA64::placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, uint8_t src2, uint8_t op, int immr, int imms) +void AssemblyBuilderA64::placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms) { if (logText) log(name, dst, src1, src2); diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index 59ee6f1..3c2a3f8 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -5,6 +5,7 @@ #include "Luau/UnwindBuilder.h" #include +#include #if defined(_WIN32) && defined(_M_X64) diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 714ddad..6460383 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -34,6 +34,7 @@ #include #include +#include #if defined(__x86_64__) || defined(_M_X64) #ifdef _MSC_VER @@ -61,33 +62,34 @@ namespace CodeGen static void* gPerfLogContext = nullptr; static PerfLogFn gPerfLogFn = nullptr; -static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) +struct NativeProto +{ + Proto* p; + void* execdata; + uintptr_t exectarget; +}; + +static NativeProto createNativeProto(Proto* proto, const IrBuilder& ir) { int sizecode = proto->sizecode; - int sizecodeAlloc = (sizecode + 1) & ~1; // align uint32_t array to 8 bytes so that NativeProto is aligned to 8 bytes - void* memory = ::operator new(sizeof(NativeProto) + sizecodeAlloc * sizeof(uint32_t)); - NativeProto* result = new (static_cast(memory) + sizecodeAlloc * sizeof(uint32_t)) NativeProto; - result->proto = proto; - - uint32_t* instOffsets = result->instOffsets; + uint32_t* instOffsets = new uint32_t[sizecode]; + uint32_t instTarget = ir.function.bcMapping[0].asmLocation; for (int i = 0; i < sizecode; i++) { - // instOffsets uses negative indexing for optimal codegen for RETURN opcode - instOffsets[-i] = ir.function.bcMapping[i].asmLocation; + LUAU_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget); + + instOffsets[i] = ir.function.bcMapping[i].asmLocation - instTarget; } - return result; + // entry target will be relocated when assembly is finalized + return {proto, instOffsets, instTarget}; } -static void destroyNativeProto(NativeProto* nativeProto) +static void destroyExecData(void* execdata) { - int sizecode = nativeProto->proto->sizecode; - int sizecodeAlloc = (sizecode + 1) & ~1; // align uint32_t array to 8 bytes so that NativeProto is aligned to 8 bytes - void* memory = reinterpret_cast(nativeProto) - sizecodeAlloc * sizeof(uint32_t); - - ::operator delete(memory); + delete[] static_cast(execdata); } static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) @@ -271,7 +273,7 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& } template -static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +static std::optional assembleFunction(AssemblyBuilder& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { if (options.includeAssembly || options.includeIr) { @@ -321,7 +323,7 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, if (build.logText) build.logAppend("; skipping (can't lower)\n\n"); - return nullptr; + return std::nullopt; } if (build.logText) @@ -337,23 +339,19 @@ static void onCloseState(lua_State* L) static void onDestroyFunction(lua_State* L, Proto* proto) { - NativeProto* nativeProto = getProtoExecData(proto); - LUAU_ASSERT(nativeProto->proto == proto); - - setProtoExecData(proto, nullptr); - destroyNativeProto(nativeProto); + destroyExecData(proto->execdata); + proto->execdata = nullptr; + proto->exectarget = 0; } static int onEnter(lua_State* L, Proto* proto) { NativeState* data = getNativeState(L); - NativeProto* nativeProto = getProtoExecData(proto); - LUAU_ASSERT(nativeProto); - LUAU_ASSERT(L->ci->savedpc); + LUAU_ASSERT(proto->execdata); + LUAU_ASSERT(L->ci->savedpc >= proto->code && L->ci->savedpc < proto->code + proto->sizecode); - // instOffsets uses negative indexing for optimal codegen for RETURN opcode - uintptr_t target = nativeProto->instBase + nativeProto->instOffsets[proto->code - L->ci->savedpc]; + uintptr_t target = proto->exectarget + static_cast(proto->execdata)[L->ci->savedpc - proto->code]; // Returns 1 to finish the function in the VM return GateFn(data->context.gateEntry)(L, proto, target, &data->context); @@ -361,7 +359,7 @@ static int onEnter(lua_State* L, Proto* proto) static void onSetBreakpoint(lua_State* L, Proto* proto, int instruction) { - if (!getProtoExecData(proto)) + if (!proto->execdata) return; LUAU_ASSERT(!"native breakpoints are not implemented"); @@ -444,8 +442,7 @@ void create(lua_State* L) data.codeAllocator.createBlockUnwindInfo = createBlockUnwindInfo; data.codeAllocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; - initFallbackTable(data); - initHelperFunctions(data); + initFunctions(data); #if defined(__x86_64__) || defined(_M_X64) if (!X64::initHeaderFunctions(data)) @@ -514,20 +511,20 @@ void compile(lua_State* L, int idx) X64::assembleHelpers(build, helpers); #endif - std::vector results; + std::vector results; results.reserve(protos.size()); // Skip protos that have been compiled during previous invocations of CodeGen::compile for (Proto* p : protos) - if (p && getProtoExecData(p) == nullptr) - if (NativeProto* np = assembleFunction(build, *data, helpers, p, {})) - results.push_back(np); + if (p && p->execdata == nullptr) + if (std::optional np = assembleFunction(build, *data, helpers, p, {})) + results.push_back(*np); // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module if (!build.finalize()) { - for (NativeProto* result : results) - destroyNativeProto(result); + for (NativeProto result : results) + destroyExecData(result.execdata); return; } @@ -542,36 +539,32 @@ void compile(lua_State* L, int idx) if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart)) { - for (NativeProto* result : results) - destroyNativeProto(result); + for (NativeProto result : results) + destroyExecData(result.execdata); return; } if (gPerfLogFn && results.size() > 0) { - gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), results[0]->instOffsets[0], ""); + gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); for (size_t i = 0; i < results.size(); ++i) { - uint32_t begin = results[i]->instOffsets[0]; - uint32_t end = i + 1 < results.size() ? results[i + 1]->instOffsets[0] : uint32_t(build.code.size() * sizeof(build.code[0])); + uint32_t begin = uint32_t(results[i].exectarget); + uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); LUAU_ASSERT(begin < end); - logPerfFunction(results[i]->proto, uintptr_t(codeStart) + begin, end - begin); + logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); } } - // Record instruction base address; at runtime, instOffsets[] will be used as offsets from instBase - for (NativeProto* result : results) + for (NativeProto result : results) { - result->instBase = uintptr_t(codeStart); - result->entryTarget = uintptr_t(codeStart) + result->instOffsets[0]; + // the memory is now managed by VM and will be freed via onDestroyFunction + result.p->execdata = result.execdata; + result.p->exectarget = uintptr_t(codeStart) + result.exectarget; } - - // Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction - for (NativeProto* result : results) - setProtoExecData(result->proto, result); } std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) @@ -586,7 +579,7 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) #endif NativeState data; - initFallbackTable(data); + initFunctions(data); std::vector protos; gatherFunctions(protos, clvalue(func)->l.p); @@ -600,8 +593,8 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) for (Proto* p : protos) if (p) - if (NativeProto* np = assembleFunction(build, data, helpers, p, options)) - destroyNativeProto(np); + if (std::optional np = assembleFunction(build, data, helpers, p, options)) + destroyExecData(np->execdata); if (!build.finalize()) return std::string(); diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index fbe44e2..f6e9152 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -4,6 +4,7 @@ #include "Luau/AssemblyBuilderA64.h" #include "Luau/UnwindBuilder.h" +#include "BitUtils.h" #include "CustomExecUtils.h" #include "NativeState.h" #include "EmitCommonA64.h" @@ -91,6 +92,13 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) // Need to update state of the current function before we jump away build.ldr(x1, mem(x0, offsetof(Closure, l.p))); // cl->l.p aka proto + build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci + + // We need to check if the new frame can be executed natively + // TOOD: .flags and .savedpc load below can be fused with ldp + build.ldr(w3, mem(x2, offsetof(CallInfo, flags))); + build.tbz(x3, countrz(LUA_CALLINFO_CUSTOM), helpers.exitContinueVm); + build.mov(rClosure, x0); build.ldr(rConstants, mem(x1, offsetof(Proto, k))); // proto->k build.ldr(rCode, mem(x1, offsetof(Proto, code))); // proto->code @@ -98,22 +106,15 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) // Get instruction index from instruction pointer // To get instruction index from instruction pointer, we need to divide byte offset by 4 // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out - // Note that we're computing negative offset here (code-savedpc) so that we can add it to NativeProto address, as we use reverse indexing - build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // L->ci->savedpc - build.sub(x2, rCode, x2); - - // We need to check if the new function can be executed natively - // TODO: This can be done earlier in the function flow, to reduce the JIT->VM transition penalty - build.ldr(x1, mem(x1, offsetofProtoExecData)); - build.cbz(x1, helpers.exitContinueVm); + build.sub(x2, x2, rCode); // Get new instruction location and jump to it - LUAU_ASSERT(offsetof(NativeProto, instOffsets) == 0); - build.ldr(w2, mem(x1, x2)); - build.ldr(x1, mem(x1, offsetof(NativeProto, instBase))); - build.add(x1, x1, x2); - build.br(x1); + LUAU_ASSERT(offsetof(Proto, exectarget) == offsetof(Proto, execdata) + 8); + build.ldp(x3, x4, mem(x1, offsetof(Proto, execdata))); + build.ldr(w2, mem(x3, x2)); + build.add(x4, x4, x2); + build.br(x4); } static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilder& unwind) diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index 37dfa11..4ad67d8 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -1,13 +1,64 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "CodeGenUtils.h" +#include "CustomExecUtils.h" + +#include "lvm.h" + +#include "lbuiltins.h" +#include "lbytecode.h" +#include "ldebug.h" #include "ldo.h" +#include "lfunc.h" +#include "lgc.h" +#include "lmem.h" +#include "lnumutils.h" +#include "lstate.h" +#include "lstring.h" #include "ltable.h" -#include "FallbacksProlog.h" - #include +LUAU_FASTFLAG(LuauUniformTopHandling) + +// All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT +// This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, +// and restores the stack pointer after in case stack gets reallocated +// Should only be used on the slow paths. +#define VM_PROTECT(x) \ + { \ + L->ci->savedpc = pc; \ + { \ + x; \ + }; \ + base = L->base; \ + } + +// Some external functions can cause an error, but never reallocate the stack; for these, VM_PROTECT_PC() is +// a cheaper version of VM_PROTECT that can be called before the external call. +#define VM_PROTECT_PC() L->ci->savedpc = pc + +#define VM_REG(i) (LUAU_ASSERT(unsigned(i) < unsigned(L->top - base)), &base[i]) +#define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) +#define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) + +#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) +#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) + +#define VM_INTERRUPT() \ + { \ + void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; \ + if (LUAU_UNLIKELY(!!interrupt)) \ + { /* the interrupt hook is called right before we advance pc */ \ + VM_PROTECT(L->ci->savedpc++; interrupt(L, -1)); \ + if (L->status != 0) \ + { \ + L->ci->savedpc--; \ + return NULL; \ + } \ + } \ + } + namespace Luau { namespace CodeGen @@ -215,6 +266,10 @@ Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) // keep executing new function ci->savedpc = p->code; + + if (LUAU_LIKELY(p->execdata != NULL)) + ci->flags = LUA_CALLINFO_CUSTOM; + return ccl; } else @@ -281,7 +336,8 @@ Closure* returnFallback(lua_State* L, StkId ra, StkId valend) // we're done! if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) { - L->top = res; + if (!FFlag::LuauUniformTopHandling) + L->top = res; return NULL; } @@ -290,5 +346,614 @@ Closure* returnFallback(lua_State* L, StkId ra, StkId valend) return clvalue(cip->func); } +const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path should already have been checked, so we skip checking for it here + Table* h = cl->env; + int slot = LUAU_INSN_C(insn) & h->nodemask8; + + // slow-path, may invoke Lua calls via __index metamethod + TValue g; + sethvalue(L, &g, h); + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, &g, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; +} + +const Instruction* executeSETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path should already have been checked, so we skip checking for it here + Table* h = cl->env; + int slot = LUAU_INSN_C(insn) & h->nodemask8; + + // slow-path, may invoke Lua calls via __newindex metamethod + TValue g; + sethvalue(L, &g, h); + L->cachedslot = slot; + VM_PROTECT(luaV_settable(L, &g, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; +} + +const Instruction* executeGETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: built-in table + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: value is in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) + { + setobj2s(L, ra, gval(n)); + return pc; + } + else if (!h->metatable) + { + // fast-path: value is not in expected slot, but the table lookup doesn't involve metatable + const TValue* res = luaH_getstr(h, tsvalue(kv)); + + if (res != luaO_nilobject) + { + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); + } + + setobj2s(L, ra, res); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + } + else + { + // fast-path: user data with C __index TM + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_INDEX)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + else if (ttisvector(rb)) + { + // fast-path: quick case-insensitive comparison with "X"/"Y"/"Z" + const char* name = getstr(tsvalue(kv)); + int ic = (name[0] | ' ') - 'x'; + +#if LUA_VECTOR_SIZE == 4 + // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' + if (ic == -1) + ic = 3; +#endif + + if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') + { + const float* v = rb->value.v; // silences ubsan when indexing v[] + setnvalue(ra, v[ic]); + return pc; + } + + fn = fasttm(L, L->global->mt[LUA_TVECTOR], TM_INDEX); + + if (fn && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + + // fall through to slow path + } + + // fall through to slow path + } + + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + return pc; +} + +const Instruction* executeSETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: built-in table + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: value is in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) + { + setobj2t(L, gval(n), ra); + luaC_barriert(L, h, ra); + return pc; + } + else if (fastnotm(h->metatable, TM_NEWINDEX) && !h->readonly) + { + VM_PROTECT_PC(); // set may fail + + TValue* res = luaH_setstr(L, h, tsvalue(kv)); + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); + setobj2t(L, res, ra); + luaC_barriert(L, h, ra); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __newindex metamethod + L->cachedslot = slot; + VM_PROTECT(luaV_settable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + } + else + { + // fast-path: user data with C __newindex TM + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_NEWINDEX)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 4 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + setobj2s(L, top + 3, ra); + L->top = top + 4; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_callTM(L, 3, -1)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + return pc; + } + else + { + // slow-path, may invoke Lua calls via __newindex metamethod + VM_PROTECT(luaV_settable(L, rb, kv, ra)); + return pc; + } + } +} + +const Instruction* executeNEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)]; + LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep)); + + VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM + + // note: we save closure to stack early in case the code below wants to capture it by value + Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv); + setclvalue(L, ra, ncl); + + for (int ui = 0; ui < pv->nups; ++ui) + { + Instruction uinsn = *pc++; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + + switch (LUAU_INSN_A(uinsn)) + { + case LCT_VAL: + setobj(L, &ncl->l.uprefs[ui], VM_REG(LUAU_INSN_B(uinsn))); + break; + + case LCT_REF: + setupvalue(L, &ncl->l.uprefs[ui], luaF_findupval(L, VM_REG(LUAU_INSN_B(uinsn)))); + break; + + case LCT_UPVAL: + setobj(L, &ncl->l.uprefs[ui], VM_UV(LUAU_INSN_B(uinsn))); + break; + + default: + LUAU_ASSERT(!"Unknown upvalue capture type"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks + } + } + + VM_PROTECT(luaC_checkGC(L)); + return pc; +} + +const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + if (ttistable(rb)) + { + Table* h = hvalue(rb); + // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works + // for predictive lookups + LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; + + const TValue* mt = 0; + const LuaNode* mtn = 0; + + // fast-path: key is in the table in expected slot + if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(n)); + } + // fast-path: key is absent from the base, table has an __index table, and it has the result in the expected slot + else if (gnext(n) == 0 && (mt = fasttm(L, hvalue(rb)->metatable, TM_INDEX)) && ttistable(mt) && + (mtn = &hvalue(mt)->node[LUAU_INSN_C(insn) & hvalue(mt)->nodemask8]) && ttisstring(gkey(mtn)) && tsvalue(gkey(mtn)) == tsvalue(kv) && + !ttisnil(gval(mtn))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(mtn)); + } + else + { + // slow-path: handles full table lookup + setobj2s(L, ra + 1, rb); + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + else + { + Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; + const TValue* tmi = 0; + + // fast-path: metatable with __namecall + if (const TValue* fn = fasttm(L, mt, TM_NAMECALL)) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, fn); + + L->namecall = tsvalue(kv); + } + else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) + { + Table* h = hvalue(tmi); + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: metatable with __index that has method in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(n)); + } + else + { + // slow-path: handles slot mismatch + setobj2s(L, ra + 1, rb); + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + else + { + // slow-path: handles non-table __index + setobj2s(L, ra + 1, rb); + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + if (ttisnil(ra)) + luaG_methoderror(L, ra + 1, tsvalue(kv)); + } + } + + // intentional fallthrough to CALL + LUAU_ASSERT(LUAU_INSN_OP(*pc) == LOP_CALL); + return pc; +} + +const Instruction* executeSETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = &base[LUAU_INSN_B(insn)]; // note: this can point to L->top if c == LUA_MULTRET making VM_REG unsafe to use + int c = LUAU_INSN_C(insn) - 1; + uint32_t index = *pc++; + + if (c == LUA_MULTRET) + { + c = int(L->top - rb); + L->top = L->ci->top; + } + + Table* h = hvalue(ra); + + // TODO: we really don't need this anymore + if (!ttistable(ra)) + return NULL; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode + + int last = index + c - 1; + if (last > h->sizearray) + { + VM_PROTECT_PC(); // luaH_resizearray may fail due to OOM + + luaH_resizearray(L, h, last); + } + + TValue* array = h->array; + + for (int i = 0; i < c; ++i) + setobj2t(L, &array[index + i - 1], rb + i); + + luaC_barrierfast(L, h); + return pc; +} + +const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (ttisfunction(ra)) + { + // will be called during FORGLOOP + } + else + { + Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + + if (const TValue* fn = fasttm(L, mt, TM_ITER)) + { + setobj2s(L, ra + 1, ra); + setobj2s(L, ra, fn); + + L->top = ra + 2; // func + self arg + LUAU_ASSERT(L->top <= L->stack_last); + + VM_PROTECT(luaD_call(L, ra, 3)); + L->top = L->ci->top; + + // recompute ra since stack might have been reallocated + ra = VM_REG(LUAU_INSN_A(insn)); + + // protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP + if (ttisnil(ra)) + { + VM_PROTECT_PC(); // next call always errors + luaG_typeerror(L, ra, "call"); + } + } + else if (fasttm(L, mt, TM_CALL)) + { + // table or userdata with __call, will be called during FORGLOOP + // TODO: we might be able to stop supporting this depending on whether it's used in practice + } + else if (ttistable(ra)) + { + // set up registers for builtin iteration + setobj2s(L, ra + 1, ra); + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + setnilvalue(ra); + } + else + { + VM_PROTECT_PC(); // next call always errors + luaG_typeerror(L, ra, "iterate over"); + } + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + return pc; +} + +const Instruction* executeGETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + int b = LUAU_INSN_B(insn) - 1; + int n = cast_int(base - L->ci->func) - cl->l.p->numparams - 1; + + if (b == LUA_MULTRET) + { + VM_PROTECT(luaD_checkstack(L, n)); + StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack + + for (int j = 0; j < n; j++) + setobj2s(L, ra + j, base - n + j); + + L->top = ra + n; + return pc; + } + else + { + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + for (int j = 0; j < b && j < n; j++) + setobj2s(L, ra + j, base - n + j); + for (int j = n; j < b; j++) + setnilvalue(ra + j); + return pc; + } +} + +const Instruction* executeDUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + Closure* kcl = clvalue(kv); + + VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM + + // clone closure if the environment is not shared + // note: we save closure to stack early in case the code below wants to capture it by value + Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); + setclvalue(L, ra, ncl); + + // this loop does three things: + // - if the closure was created anew, it just fills it with upvalues + // - if the closure from the constant table is used, it fills it with upvalues so that it can be shared in the future + // - if the closure is reused, it checks if the reuse is safe via rawequal, and falls back to duplicating the closure + // normally this would use two separate loops, for reuse check and upvalue setup, but MSVC codegen goes crazy if you do that + for (int ui = 0; ui < kcl->nupvalues; ++ui) + { + Instruction uinsn = pc[ui]; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + LUAU_ASSERT(LUAU_INSN_A(uinsn) == LCT_VAL || LUAU_INSN_A(uinsn) == LCT_UPVAL); + + TValue* uv = (LUAU_INSN_A(uinsn) == LCT_VAL) ? VM_REG(LUAU_INSN_B(uinsn)) : VM_UV(LUAU_INSN_B(uinsn)); + + // check if the existing closure is safe to reuse + if (ncl == kcl && luaO_rawequalObj(&ncl->l.uprefs[ui], uv)) + continue; + + // lazily clone the closure and update the upvalues + if (ncl == kcl && kcl->preload == 0) + { + ncl = luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); + setclvalue(L, ra, ncl); + + ui = -1; // restart the loop to fill all upvalues + continue; + } + + // this updates a newly created closure, or an existing closure created during preload, in which case we need a barrier + setobj(L, &ncl->l.uprefs[ui], uv); + luaC_barrier(L, ncl, uv); + } + + // this is a noop if ncl is newly created or shared successfully, but it has to run after the closure is preloaded for the first time + ncl->preload = 0; + + if (kcl != ncl) + VM_PROTECT(luaC_checkGC(L)); + + pc += kcl->nupvalues; + return pc; +} + +const Instruction* executePREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) +{ + [[maybe_unused]] Closure* cl = clvalue(L->ci->func); + Instruction insn = *pc++; + int numparams = LUAU_INSN_A(insn); + + // all fixed parameters are copied after the top so we need more stack space + VM_PROTECT(luaD_checkstack(L, cl->stacksize + numparams)); + + // the caller must have filled extra fixed arguments with nil + LUAU_ASSERT(cast_int(L->top - base) >= numparams); + + // move fixed parameters to final position + StkId fixed = base; // first fixed argument + base = L->top; // final position of first argument + + for (int i = 0; i < numparams; ++i) + { + setobj2s(L, base + i, fixed + i); + setnilvalue(fixed + i); + } + + // rewire our stack frame to point to the new base + L->ci->base = base; + L->ci->top = base + cl->stacksize; + + L->base = base; + L->top = L->ci->top; + return pc; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 4ce3566..87b6ec4 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -20,5 +20,17 @@ void callEpilogC(lua_State* L, int nresults, int n); Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults); Closure* returnFallback(lua_State* L, StkId ra, StkId valend); +const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeSETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeGETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeSETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeNEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeSETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeGETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executeDUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); +const Instruction* executePREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CustomExecUtils.h b/CodeGen/src/CustomExecUtils.h index 9526d6d..9c99966 100644 --- a/CodeGen/src/CustomExecUtils.h +++ b/CodeGen/src/CustomExecUtils.h @@ -46,21 +46,6 @@ inline void destroyNativeState(lua_State* L) delete state; } -inline NativeProto* getProtoExecData(Proto* proto) -{ - return (NativeProto*)proto->execdata; -} - -inline void setProtoExecData(Proto* proto, NativeProto* nativeProto) -{ - if (nativeProto) - LUAU_ASSERT(proto->execdata == nullptr); - - proto->execdata = nativeProto; -} - -#define offsetofProtoExecData offsetof(Proto, execdata) - #else inline lua_ExecutionCallbacks* getExecutionCallbacks(lua_State* L) @@ -82,15 +67,6 @@ inline NativeState* createNativeState(lua_State* L) inline void destroyNativeState(lua_State* L) {} -inline NativeProto* getProtoExecData(Proto* proto) -{ - return nullptr; -} - -inline void setProtoExecData(Proto* proto, NativeProto* nativeProto) {} - -#define offsetofProtoExecData 0 - #endif inline int getOpLength(LuauOpcode op) diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index 6a74966..6b19912 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -10,11 +10,12 @@ namespace CodeGen constexpr unsigned kTValueSizeLog2 = 4; constexpr unsigned kLuaNodeSizeLog2 = 5; -constexpr unsigned kLuaNodeTagMask = 0xf; -constexpr unsigned kNextBitOffset = 4; -constexpr unsigned kOffsetOfTKeyTag = 12; // offsetof cannot be used on a bit field -constexpr unsigned kOffsetOfTKeyNext = 12; // offsetof cannot be used on a bit field +// TKey.tt and TKey.next are packed together in a bitfield +constexpr unsigned kOffsetOfTKeyTagNext = 12; // offsetof cannot be used on a bit field +constexpr unsigned kTKeyTagBits = 4; +constexpr unsigned kTKeyTagMask = (1 << kTKeyTagBits) - 1; + constexpr unsigned kOffsetOfInstructionC = 3; // Leaf functions that are placed in every module to perform common instruction sequences diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index b6ef957..ce95e74 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -325,10 +325,8 @@ void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos) build.setLabel(skip); } -void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, NativeState& data, int op, int pcpos) +void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos) { - LUAU_ASSERT(data.context.fallback[op]); - // fallback(L, instruction, base, k) IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::qword, rState); @@ -339,7 +337,7 @@ void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, NativeState& d callWrap.addArgument(SizeX64::qword, rBase); callWrap.addArgument(SizeX64::qword, rConstants); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, fallback) + op * sizeof(FallbackFn)]); + callWrap.call(qword[rNativeContext + offset]); emitUpdateBase(build); } diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index d4684fe..ddc4048 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -136,7 +136,7 @@ inline OperandX64 luauNodeKeyValue(RegisterX64 node) // Note: tag has dirty upper bits inline OperandX64 luauNodeKeyTag(RegisterX64 node) { - return dword[node + offsetof(LuaNode, key) + kOffsetOfTKeyTag]; + return dword[node + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]; } inline OperandX64 luauNodeValue(RegisterX64 node) @@ -189,7 +189,7 @@ inline void jumpIfNodeKeyTagIsNot(AssemblyBuilderX64& build, RegisterX64 tmp, Re tmp.size = SizeX64::dword; build.mov(tmp, luauNodeKeyTag(node)); - build.and_(tmp, kLuaNodeTagMask); + build.and_(tmp, kTKeyTagMask); build.cmp(tmp, tag); build.jcc(ConditionX64::NotEqual, label); } @@ -230,7 +230,7 @@ void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); void emitExit(AssemblyBuilderX64& build, bool continueInVm); void emitUpdateBase(AssemblyBuilderX64& build); void emitInterrupt(IrRegAllocX64& regs, AssemblyBuilderX64& build, int pcpos); -void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, NativeState& data, int op, int pcpos); +void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos); void emitContinueCallInVm(AssemblyBuilderX64& build); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 19f0cb8..b2db7d1 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -73,8 +73,6 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int build.mov(rax, qword[ci + offsetof(CallInfo, top)]); build.mov(qword[rState + offsetof(lua_State, top)], rax); - build.mov(rax, qword[proto + offsetofProtoExecData]); // We'll need this value later - // But if it is vararg, update it to 'argi' Label skipVararg; @@ -84,10 +82,14 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int build.mov(qword[rState + offsetof(lua_State, top)], argi); build.setLabel(skipVararg); - // Check native function data + // Get native function entry + build.mov(rax, qword[proto + offsetof(Proto, exectarget)]); build.test(rax, rax); build.jcc(ConditionX64::Zero, helpers.continueCallInVm); + // Mark call frame as custom + build.mov(dword[ci + offsetof(CallInfo, flags)], LUA_CALLINFO_CUSTOM); + // Switch current constants build.mov(rConstants, qword[proto + offsetof(Proto, k)]); @@ -95,7 +97,7 @@ void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int build.mov(rdx, qword[proto + offsetof(Proto, code)]); build.mov(sCode, rdx); - build.jmp(qword[rax + offsetof(NativeProto, entryTarget)]); + build.jmp(rax); } build.setLabel(cFuncCall); @@ -294,8 +296,9 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i build.mov(proto, qword[rax + offsetof(Closure, l.p)]); - build.mov(execdata, qword[proto + offsetofProtoExecData]); - build.test(execdata, execdata); + build.mov(execdata, qword[proto + offsetof(Proto, execdata)]); + + build.test(byte[cip + offsetof(CallInfo, flags)], LUA_CALLINFO_CUSTOM); build.jcc(ConditionX64::Zero, helpers.exitContinueVm); // Continue in interpreter if function has no native data // Change constants @@ -309,13 +312,11 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i // To get instruction index from instruction pointer, we need to divide byte offset by 4 // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out - // Note that we're computing negative offset here (code-savedpc) so that we can add it to NativeProto address, as we use reverse indexing - build.sub(rdx, rax); + build.sub(rax, rdx); // Get new instruction location and jump to it - LUAU_ASSERT(offsetof(NativeProto, instOffsets) == 0); - build.mov(edx, dword[execdata + rdx]); - build.add(rdx, qword[execdata + offsetof(NativeProto, instBase)]); + build.mov(edx, dword[execdata + rax]); + build.add(rdx, qword[proto + offsetof(Proto, exectarget)]); build.jmp(rdx); } diff --git a/CodeGen/src/Fallbacks.cpp b/CodeGen/src/Fallbacks.cpp deleted file mode 100644 index 1c0dce5..0000000 --- a/CodeGen/src/Fallbacks.cpp +++ /dev/null @@ -1,639 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#include "Fallbacks.h" -#include "FallbacksProlog.h" - -const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - // fast-path: value is in expected slot - Table* h = cl->env; - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv)) && !ttisnil(gval(n))) - { - setobj2s(L, ra, gval(n)); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __index metamethod - TValue g; - sethvalue(L, &g, h); - L->cachedslot = slot; - VM_PROTECT(luaV_gettable(L, &g, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } -} - -const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - // fast-path: value is in expected slot - Table* h = cl->env; - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) - { - setobj2t(L, gval(n), ra); - luaC_barriert(L, h, ra); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __newindex metamethod - TValue g; - sethvalue(L, &g, h); - L->cachedslot = slot; - VM_PROTECT(luaV_settable(L, &g, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } -} - -const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - StkId rb = VM_REG(LUAU_INSN_B(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - // fast-path: built-in table - if (ttistable(rb)) - { - Table* h = hvalue(rb); - - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - // fast-path: value is in expected slot - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) - { - setobj2s(L, ra, gval(n)); - return pc; - } - else if (!h->metatable) - { - // fast-path: value is not in expected slot, but the table lookup doesn't involve metatable - const TValue* res = luaH_getstr(h, tsvalue(kv)); - - if (res != luaO_nilobject) - { - int cachedslot = gval2slot(h, res); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, cachedslot); - } - - setobj2s(L, ra, res); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __index metamethod - L->cachedslot = slot; - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - } - else - { - // fast-path: user data with C __index TM - const TValue* fn = 0; - if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_INDEX)) && ttisfunction(fn) && clvalue(fn)->isC) - { - // note: it's safe to push arguments past top for complicated reasons (see top of the file) - LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); - StkId top = L->top; - setobj2s(L, top + 0, fn); - setobj2s(L, top + 1, rb); - setobj2s(L, top + 2, kv); - L->top = top + 3; - - L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - else if (ttisvector(rb)) - { - // fast-path: quick case-insensitive comparison with "X"/"Y"/"Z" - const char* name = getstr(tsvalue(kv)); - int ic = (name[0] | ' ') - 'x'; - -#if LUA_VECTOR_SIZE == 4 - // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' - if (ic == -1) - ic = 3; -#endif - - if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') - { - const float* v = rb->value.v; // silences ubsan when indexing v[] - setnvalue(ra, v[ic]); - return pc; - } - - fn = fasttm(L, L->global->mt[LUA_TVECTOR], TM_INDEX); - - if (fn && ttisfunction(fn) && clvalue(fn)->isC) - { - // note: it's safe to push arguments past top for complicated reasons (see top of the file) - LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); - StkId top = L->top; - setobj2s(L, top + 0, fn); - setobj2s(L, top + 1, rb); - setobj2s(L, top + 2, kv); - L->top = top + 3; - - L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - - // fall through to slow path - } - - // fall through to slow path - } - - // slow-path, may invoke Lua calls via __index metamethod - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - return pc; -} - -const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - StkId rb = VM_REG(LUAU_INSN_B(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - // fast-path: built-in table - if (ttistable(rb)) - { - Table* h = hvalue(rb); - - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - // fast-path: value is in expected slot - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) - { - setobj2t(L, gval(n), ra); - luaC_barriert(L, h, ra); - return pc; - } - else if (fastnotm(h->metatable, TM_NEWINDEX) && !h->readonly) - { - VM_PROTECT_PC(); // set may fail - - TValue* res = luaH_setstr(L, h, tsvalue(kv)); - int cachedslot = gval2slot(h, res); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, cachedslot); - setobj2t(L, res, ra); - luaC_barriert(L, h, ra); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __newindex metamethod - L->cachedslot = slot; - VM_PROTECT(luaV_settable(L, rb, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - } - else - { - // fast-path: user data with C __newindex TM - const TValue* fn = 0; - if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_NEWINDEX)) && ttisfunction(fn) && clvalue(fn)->isC) - { - // note: it's safe to push arguments past top for complicated reasons (see top of the file) - LUAU_ASSERT(L->top + 4 < L->stack + L->stacksize); - StkId top = L->top; - setobj2s(L, top + 0, fn); - setobj2s(L, top + 1, rb); - setobj2s(L, top + 2, kv); - setobj2s(L, top + 3, ra); - L->top = top + 4; - - L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luaV_callTM(L, 3, -1)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - return pc; - } - else - { - // slow-path, may invoke Lua calls via __newindex metamethod - VM_PROTECT(luaV_settable(L, rb, kv, ra)); - return pc; - } - } -} - -const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)]; - LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep)); - - VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM - - // note: we save closure to stack early in case the code below wants to capture it by value - Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv); - setclvalue(L, ra, ncl); - - for (int ui = 0; ui < pv->nups; ++ui) - { - Instruction uinsn = *pc++; - LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); - - switch (LUAU_INSN_A(uinsn)) - { - case LCT_VAL: - setobj(L, &ncl->l.uprefs[ui], VM_REG(LUAU_INSN_B(uinsn))); - break; - - case LCT_REF: - setupvalue(L, &ncl->l.uprefs[ui], luaF_findupval(L, VM_REG(LUAU_INSN_B(uinsn)))); - break; - - case LCT_UPVAL: - setobj(L, &ncl->l.uprefs[ui], VM_UV(LUAU_INSN_B(uinsn))); - break; - - default: - LUAU_ASSERT(!"Unknown upvalue capture type"); - LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks - } - } - - VM_PROTECT(luaC_checkGC(L)); - return pc; -} - -const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - StkId rb = VM_REG(LUAU_INSN_B(insn)); - uint32_t aux = *pc++; - TValue* kv = VM_KV(aux); - LUAU_ASSERT(ttisstring(kv)); - - if (ttistable(rb)) - { - Table* h = hvalue(rb); - // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works - // for predictive lookups - LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; - - const TValue* mt = 0; - const LuaNode* mtn = 0; - - // fast-path: key is in the table in expected slot - if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n))) - { - // note: order of copies allows rb to alias ra+1 or ra - setobj2s(L, ra + 1, rb); - setobj2s(L, ra, gval(n)); - } - // fast-path: key is absent from the base, table has an __index table, and it has the result in the expected slot - else if (gnext(n) == 0 && (mt = fasttm(L, hvalue(rb)->metatable, TM_INDEX)) && ttistable(mt) && - (mtn = &hvalue(mt)->node[LUAU_INSN_C(insn) & hvalue(mt)->nodemask8]) && ttisstring(gkey(mtn)) && tsvalue(gkey(mtn)) == tsvalue(kv) && - !ttisnil(gval(mtn))) - { - // note: order of copies allows rb to alias ra+1 or ra - setobj2s(L, ra + 1, rb); - setobj2s(L, ra, gval(mtn)); - } - else - { - // slow-path: handles full table lookup - setobj2s(L, ra + 1, rb); - L->cachedslot = LUAU_INSN_C(insn); - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - // recompute ra since stack might have been reallocated - ra = VM_REG(LUAU_INSN_A(insn)); - if (ttisnil(ra)) - luaG_methoderror(L, ra + 1, tsvalue(kv)); - } - } - else - { - Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; - const TValue* tmi = 0; - - // fast-path: metatable with __namecall - if (const TValue* fn = fasttm(L, mt, TM_NAMECALL)) - { - // note: order of copies allows rb to alias ra+1 or ra - setobj2s(L, ra + 1, rb); - setobj2s(L, ra, fn); - - L->namecall = tsvalue(kv); - } - else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) - { - Table* h = hvalue(tmi); - int slot = LUAU_INSN_C(insn) & h->nodemask8; - LuaNode* n = &h->node[slot]; - - // fast-path: metatable with __index that has method in expected slot - if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) - { - // note: order of copies allows rb to alias ra+1 or ra - setobj2s(L, ra + 1, rb); - setobj2s(L, ra, gval(n)); - } - else - { - // slow-path: handles slot mismatch - setobj2s(L, ra + 1, rb); - L->cachedslot = slot; - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, L->cachedslot); - // recompute ra since stack might have been reallocated - ra = VM_REG(LUAU_INSN_A(insn)); - if (ttisnil(ra)) - luaG_methoderror(L, ra + 1, tsvalue(kv)); - } - } - else - { - // slow-path: handles non-table __index - setobj2s(L, ra + 1, rb); - VM_PROTECT(luaV_gettable(L, rb, kv, ra)); - // recompute ra since stack might have been reallocated - ra = VM_REG(LUAU_INSN_A(insn)); - if (ttisnil(ra)) - luaG_methoderror(L, ra + 1, tsvalue(kv)); - } - } - - // intentional fallthrough to CALL - LUAU_ASSERT(LUAU_INSN_OP(*pc) == LOP_CALL); - return pc; -} - -const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - StkId rb = &base[LUAU_INSN_B(insn)]; // note: this can point to L->top if c == LUA_MULTRET making VM_REG unsafe to use - int c = LUAU_INSN_C(insn) - 1; - uint32_t index = *pc++; - - if (c == LUA_MULTRET) - { - c = int(L->top - rb); - L->top = L->ci->top; - } - - Table* h = hvalue(ra); - - // TODO: we really don't need this anymore - if (!ttistable(ra)) - return NULL; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode - - int last = index + c - 1; - if (last > h->sizearray) - { - VM_PROTECT_PC(); // luaH_resizearray may fail due to OOM - - luaH_resizearray(L, h, last); - } - - TValue* array = h->array; - - for (int i = 0; i < c; ++i) - setobj2t(L, &array[index + i - 1], rb + i); - - luaC_barrierfast(L, h); - return pc; -} - -const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - if (ttisfunction(ra)) - { - // will be called during FORGLOOP - } - else - { - Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); - - if (const TValue* fn = fasttm(L, mt, TM_ITER)) - { - setobj2s(L, ra + 1, ra); - setobj2s(L, ra, fn); - - L->top = ra + 2; // func + self arg - LUAU_ASSERT(L->top <= L->stack_last); - - VM_PROTECT(luaD_call(L, ra, 3)); - L->top = L->ci->top; - - // recompute ra since stack might have been reallocated - ra = VM_REG(LUAU_INSN_A(insn)); - - // protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP - if (ttisnil(ra)) - { - VM_PROTECT_PC(); // next call always errors - luaG_typeerror(L, ra, "call"); - } - } - else if (fasttm(L, mt, TM_CALL)) - { - // table or userdata with __call, will be called during FORGLOOP - // TODO: we might be able to stop supporting this depending on whether it's used in practice - } - else if (ttistable(ra)) - { - // set up registers for builtin iteration - setobj2s(L, ra + 1, ra); - setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); - setnilvalue(ra); - } - else - { - VM_PROTECT_PC(); // next call always errors - luaG_typeerror(L, ra, "iterate over"); - } - } - - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - return pc; -} - -const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - int b = LUAU_INSN_B(insn) - 1; - int n = cast_int(base - L->ci->func) - cl->l.p->numparams - 1; - - if (b == LUA_MULTRET) - { - VM_PROTECT(luaD_checkstack(L, n)); - StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack - - for (int j = 0; j < n; j++) - setobj2s(L, ra + j, base - n + j); - - L->top = ra + n; - return pc; - } - else - { - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - for (int j = 0; j < b && j < n; j++) - setobj2s(L, ra + j, base - n + j); - for (int j = n; j < b; j++) - setnilvalue(ra + j); - return pc; - } -} - -const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - TValue* kv = VM_KV(LUAU_INSN_D(insn)); - - Closure* kcl = clvalue(kv); - - VM_PROTECT_PC(); // luaF_newLclosure may fail due to OOM - - // clone closure if the environment is not shared - // note: we save closure to stack early in case the code below wants to capture it by value - Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); - setclvalue(L, ra, ncl); - - // this loop does three things: - // - if the closure was created anew, it just fills it with upvalues - // - if the closure from the constant table is used, it fills it with upvalues so that it can be shared in the future - // - if the closure is reused, it checks if the reuse is safe via rawequal, and falls back to duplicating the closure - // normally this would use two separate loops, for reuse check and upvalue setup, but MSVC codegen goes crazy if you do that - for (int ui = 0; ui < kcl->nupvalues; ++ui) - { - Instruction uinsn = pc[ui]; - LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); - LUAU_ASSERT(LUAU_INSN_A(uinsn) == LCT_VAL || LUAU_INSN_A(uinsn) == LCT_UPVAL); - - TValue* uv = (LUAU_INSN_A(uinsn) == LCT_VAL) ? VM_REG(LUAU_INSN_B(uinsn)) : VM_UV(LUAU_INSN_B(uinsn)); - - // check if the existing closure is safe to reuse - if (ncl == kcl && luaO_rawequalObj(&ncl->l.uprefs[ui], uv)) - continue; - - // lazily clone the closure and update the upvalues - if (ncl == kcl && kcl->preload == 0) - { - ncl = luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); - setclvalue(L, ra, ncl); - - ui = -1; // restart the loop to fill all upvalues - continue; - } - - // this updates a newly created closure, or an existing closure created during preload, in which case we need a barrier - setobj(L, &ncl->l.uprefs[ui], uv); - luaC_barrier(L, ncl, uv); - } - - // this is a noop if ncl is newly created or shared successfully, but it has to run after the closure is preloaded for the first time - ncl->preload = 0; - - if (kcl != ncl) - VM_PROTECT(luaC_checkGC(L)); - - pc += kcl->nupvalues; - return pc; -} - -const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - [[maybe_unused]] Closure* cl = clvalue(L->ci->func); - Instruction insn = *pc++; - int numparams = LUAU_INSN_A(insn); - - // all fixed parameters are copied after the top so we need more stack space - VM_PROTECT(luaD_checkstack(L, cl->stacksize + numparams)); - - // the caller must have filled extra fixed arguments with nil - LUAU_ASSERT(cast_int(L->top - base) >= numparams); - - // move fixed parameters to final position - StkId fixed = base; // first fixed argument - base = L->top; // final position of first argument - - for (int i = 0; i < numparams; ++i) - { - setobj2s(L, base + i, fixed + i); - setnilvalue(fixed + i); - } - - // rewire our stack frame to point to the new base - L->ci->base = base; - L->ci->top = base + cl->stacksize; - - L->base = base; - L->top = L->ci->top; - return pc; -} - -const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k) -{ - LUAU_ASSERT(!"Unsupported deprecated opcode"); - LUAU_UNREACHABLE(); -} diff --git a/CodeGen/src/Fallbacks.h b/CodeGen/src/Fallbacks.h deleted file mode 100644 index 0d2d218..0000000 --- a/CodeGen/src/Fallbacks.h +++ /dev/null @@ -1,24 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#pragma once - -#include - -struct lua_State; -struct Closure; -typedef uint32_t Instruction; -typedef struct lua_TValue TValue; -typedef TValue* StkId; - -const Instruction* execute_LOP_GETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_SETGLOBAL(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_GETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_SETTABLEKS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_NEWCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_NAMECALL(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_SETLIST(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_FORGPREP(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_GETVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_DUPCLOSURE(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_PREPVARARGS(lua_State* L, const Instruction* pc, StkId base, TValue* k); -const Instruction* execute_LOP_BREAK(lua_State* L, const Instruction* pc, StkId base, TValue* k); diff --git a/CodeGen/src/FallbacksProlog.h b/CodeGen/src/FallbacksProlog.h deleted file mode 100644 index bbb06b8..0000000 --- a/CodeGen/src/FallbacksProlog.h +++ /dev/null @@ -1,56 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "lvm.h" - -#include "lbuiltins.h" -#include "lbytecode.h" -#include "ldebug.h" -#include "ldo.h" -#include "lfunc.h" -#include "lgc.h" -#include "lmem.h" -#include "lnumutils.h" -#include "lstate.h" -#include "lstring.h" -#include "ltable.h" - -#include - -// All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT -// This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, -// and restores the stack pointer after in case stack gets reallocated -// Should only be used on the slow paths. -#define VM_PROTECT(x) \ - { \ - L->ci->savedpc = pc; \ - { \ - x; \ - }; \ - base = L->base; \ - } - -// Some external functions can cause an error, but never reallocate the stack; for these, VM_PROTECT_PC() is -// a cheaper version of VM_PROTECT that can be called before the external call. -#define VM_PROTECT_PC() L->ci->savedpc = pc - -#define VM_REG(i) (LUAU_ASSERT(unsigned(i) < unsigned(L->top - base)), &base[i]) -#define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) -#define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) - -#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) -#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) - -#define VM_INTERRUPT() \ - { \ - void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; \ - if (LUAU_UNLIKELY(!!interrupt)) \ - { /* the interrupt hook is called right before we advance pc */ \ - VM_PROTECT(L->ci->savedpc++; interrupt(L, -1)); \ - if (L->status != 0) \ - { \ - L->ci->savedpc--; \ - return NULL; \ - } \ - } \ - } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 3ac37ef..711baba 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -96,14 +96,14 @@ static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA6 } } -static void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) +static void emitFallback(AssemblyBuilderA64& build, int offset, int pcpos) { // fallback(L, instruction, base, k) build.mov(x0, rState); emitAddOffset(build, x1, rCode, pcpos * sizeof(Instruction)); build.mov(x2, rBase); build.mov(x3, rConstants); - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, fallback) + op * sizeof(FallbackFn))); + build.ldr(x4, mem(rNativeContext, offset)); build.blr(x4); emitUpdateBase(build); @@ -658,30 +658,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } - case IrCmd::JUMP_SLOT_MATCH: - { - // TODO: share code with CHECK_SLOT_MATCH - RegisterA64 temp1 = regs.allocTemp(KindA64::x); - RegisterA64 temp1w = castReg(KindA64::w, temp1); - RegisterA64 temp2 = regs.allocTemp(KindA64::x); - - build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTag)); - build.and_(temp1w, temp1w, kLuaNodeTagMask); - build.cmp(temp1w, LUA_TSTRING); - build.b(ConditionA64::NotEqual, labelOp(inst.d)); - - AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); - build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaNode, key.value))); - build.ldr(temp2, addr); - build.cmp(temp1, temp2); - build.b(ConditionA64::NotEqual, labelOp(inst.d)); - - build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); - LUAU_ASSERT(LUA_TNIL == 0); - build.cbz(temp1w, labelOp(inst.d)); - jumpOrFallthrough(blockOp(inst.c), next); - break; - } + // IrCmd::JUMP_SLOT_MATCH implemented below case IrCmd::TABLE_LEN: { RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads @@ -1078,34 +1055,40 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.b(ConditionA64::UnsignedLessEqual, labelOp(inst.c)); break; } + case IrCmd::JUMP_SLOT_MATCH: case IrCmd::CHECK_SLOT_MATCH: { + Label& mismatch = inst.cmd == IrCmd::JUMP_SLOT_MATCH ? labelOp(inst.d) : labelOp(inst.c); + RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp1w = castReg(KindA64::w, temp1); RegisterA64 temp2 = regs.allocTemp(KindA64::x); - build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTag)); - build.and_(temp1w, temp1w, kLuaNodeTagMask); - build.cmp(temp1w, LUA_TSTRING); - build.b(ConditionA64::NotEqual, labelOp(inst.c)); + LUAU_ASSERT(offsetof(LuaNode, key.value) == offsetof(LuaNode, key) && kOffsetOfTKeyTagNext >= 8 && kOffsetOfTKeyTagNext < 16); + build.ldp(temp1, temp2, mem(regOp(inst.a), offsetof(LuaNode, key))); // load key.value into temp1 and key.tt (alongside other bits) into temp2 + build.ubfx(temp2, temp2, (kOffsetOfTKeyTagNext - 8) * 8, kTKeyTagBits); // .tt is right before .next, and 8 bytes are skipped by ldp + build.cmp(temp2, LUA_TSTRING); + build.b(ConditionA64::NotEqual, mismatch); AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); - build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaNode, key.value))); build.ldr(temp2, addr); build.cmp(temp1, temp2); - build.b(ConditionA64::NotEqual, labelOp(inst.c)); + build.b(ConditionA64::NotEqual, mismatch); build.ldr(temp1w, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); LUAU_ASSERT(LUA_TNIL == 0); - build.cbz(temp1w, labelOp(inst.c)); + build.cbz(temp1w, mismatch); + + if (inst.cmd == IrCmd::JUMP_SLOT_MATCH) + jumpOrFallthrough(blockOp(inst.c), next); break; } case IrCmd::CHECK_NODE_NO_NEXT: { RegisterA64 temp = regs.allocTemp(KindA64::w); - build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyNext)); - build.lsr(temp, temp, kNextBitOffset); + build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, key) + kOffsetOfTKeyTagNext)); + build.lsr(temp, temp, kTKeyTagBits); build.cbnz(temp, labelOp(inst.b)); break; } @@ -1139,6 +1122,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; build.ldr(temp1, mem(rState, offsetof(lua_State, global))); + // TODO: totalbytes and GCthreshold loads can be fused with ldp build.ldr(temp2, mem(temp1, offsetof(global_State, totalbytes))); build.ldr(temp1, mem(temp1, offsetof(global_State, GCthreshold))); build.cmp(temp1, temp2); @@ -1265,7 +1249,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::SETLIST: regs.spill(build, index); - emitFallback(build, LOP_SETLIST, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeSETLIST), uintOp(inst.a)); break; case IrCmd::CALL: regs.spill(build, index); @@ -1368,14 +1352,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_GETGLOBAL, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeGETGLOBAL), uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_SETGLOBAL, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeSETGLOBAL), uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -1383,7 +1367,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_GETTABLEKS, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeGETTABLEKS), uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -1391,7 +1375,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_SETTABLEKS, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeSETTABLEKS), uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -1399,38 +1383,38 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_NAMECALL, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeNAMECALL), uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); regs.spill(build, index); - emitFallback(build, LOP_PREPVARARGS, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executePREPVARARGS), uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); regs.spill(build, index); - emitFallback(build, LOP_GETVARARGS, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeGETVARARGS), uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); regs.spill(build, index); - emitFallback(build, LOP_NEWCLOSURE, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeNEWCLOSURE), uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); regs.spill(build, index); - emitFallback(build, LOP_DUPCLOSURE, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeDUPCLOSURE), uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: regs.spill(build, index); - emitFallback(build, LOP_FORGPREP, uintOp(inst.a)); + emitFallback(build, offsetof(NativeContext, executeFORGPREP), uintOp(inst.a)); jumpOrFallthrough(blockOp(inst.c), next); break; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 8c1f2b0..035cc05 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -938,8 +938,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { ScopedRegX64 tmp{regs, SizeX64::dword}; - build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyNext]); - build.shr(tmp.reg, kNextBitOffset); + build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]); + build.shr(tmp.reg, kTKeyTagBits); build.jcc(ConditionX64::NotZero, labelOp(inst.b)); break; } @@ -1098,60 +1098,60 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_GETGLOBAL, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeGETGLOBAL), uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_SETGLOBAL, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeSETGLOBAL), uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_GETTABLEKS, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeGETTABLEKS), uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_SETTABLEKS, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeSETTABLEKS), uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_NAMECALL, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeNAMECALL), uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); - emitFallback(regs, build, data, LOP_PREPVARARGS, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executePREPVARARGS), uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - emitFallback(regs, build, data, LOP_GETVARARGS, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeGETVARARGS), uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); - emitFallback(regs, build, data, LOP_NEWCLOSURE, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeNEWCLOSURE), uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); - emitFallback(regs, build, data, LOP_DUPCLOSURE, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeDUPCLOSURE), uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: - emitFallback(regs, build, data, LOP_FORGPREP, uintOp(inst.a)); + emitFallback(regs, build, offsetof(NativeContext, executeFORGPREP), uintOp(inst.a)); jumpOrFallthrough(blockOp(inst.c), next); break; case IrCmd::BITAND_UINT: diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index cb128de..bda4688 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -5,7 +5,6 @@ #include "CodeGenUtils.h" #include "CustomExecUtils.h" -#include "Fallbacks.h" #include "lbuiltins.h" #include "lgc.h" @@ -16,8 +15,6 @@ #include #include -#define CODEGEN_SET_FALLBACK(op) data.context.fallback[op] = {execute_##op} - namespace Luau { namespace CodeGen @@ -33,27 +30,7 @@ NativeState::NativeState() NativeState::~NativeState() = default; -void initFallbackTable(NativeState& data) -{ - // When fallback is completely removed, remove it from includeInsts list in lvmexecute_split.py - CODEGEN_SET_FALLBACK(LOP_NEWCLOSURE); - CODEGEN_SET_FALLBACK(LOP_NAMECALL); - CODEGEN_SET_FALLBACK(LOP_FORGPREP); - CODEGEN_SET_FALLBACK(LOP_GETVARARGS); - CODEGEN_SET_FALLBACK(LOP_DUPCLOSURE); - CODEGEN_SET_FALLBACK(LOP_PREPVARARGS); - CODEGEN_SET_FALLBACK(LOP_BREAK); - CODEGEN_SET_FALLBACK(LOP_SETLIST); - - // Fallbacks that are called from partial implementation of an instruction - // TODO: these fallbacks should be replaced with special functions that exclude the (redundantly executed) fast path from the fallback - CODEGEN_SET_FALLBACK(LOP_GETGLOBAL); - CODEGEN_SET_FALLBACK(LOP_SETGLOBAL); - CODEGEN_SET_FALLBACK(LOP_GETTABLEKS); - CODEGEN_SET_FALLBACK(LOP_SETTABLEKS); -} - -void initHelperFunctions(NativeState& data) +void initFunctions(NativeState& data) { static_assert(sizeof(data.context.luauF_table) == sizeof(luauF_table), "fastcall tables are not of the same length"); memcpy(data.context.luauF_table, luauF_table, sizeof(luauF_table)); @@ -115,6 +92,19 @@ void initHelperFunctions(NativeState& data) data.context.callFallback = callFallback; data.context.returnFallback = returnFallback; + + data.context.executeGETGLOBAL = executeGETGLOBAL; + data.context.executeSETGLOBAL = executeSETGLOBAL; + data.context.executeGETTABLEKS = executeGETTABLEKS; + data.context.executeSETTABLEKS = executeSETTABLEKS; + + data.context.executeNEWCLOSURE = executeNEWCLOSURE; + data.context.executeNAMECALL = executeNAMECALL; + data.context.executeFORGPREP = executeFORGPREP; + data.context.executeGETVARARGS = executeGETVARARGS; + data.context.executeDUPCLOSURE = executeDUPCLOSURE; + data.context.executePREPVARARGS = executePREPVARARGS; + data.context.executeSETLIST = executeSETLIST; } } // namespace CodeGen diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index eb1d97a..40017e3 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -23,19 +23,6 @@ namespace CodeGen class UnwindBuilder; -using FallbackFn = const Instruction* (*)(lua_State* L, const Instruction* pc, StkId base, TValue* k); - -struct NativeProto -{ - // This array is stored before NativeProto in reverse order, so to get offset of instruction i you need to index instOffsets[-i] - // This awkward layout is helpful for maximally efficient address computation on X64/A64 - uint32_t instOffsets[1]; - - uintptr_t instBase = 0; - uintptr_t entryTarget = 0; // = instOffsets[0] + instBase - Proto* proto = nullptr; -}; - struct NativeContext { // Gateway (C => native transition) entry & exit, compiled at runtime @@ -102,7 +89,17 @@ struct NativeContext Closure* (*returnFallback)(lua_State* L, StkId ra, StkId valend) = nullptr; // Opcode fallbacks, implemented in C - FallbackFn fallback[LOP__COUNT] = {}; + const Instruction* (*executeGETGLOBAL)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeSETGLOBAL)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeGETTABLEKS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeSETTABLEKS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeNEWCLOSURE)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeNAMECALL)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeSETLIST)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeFORGPREP)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeGETVARARGS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executeDUPCLOSURE)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; + const Instruction* (*executePREPVARARGS)(lua_State* L, const Instruction* pc, StkId base, TValue* k) = nullptr; // Fast call methods, implemented in C luau_FastFunction luauF_table[256] = {}; @@ -124,8 +121,7 @@ struct NativeState NativeContext context; }; -void initFallbackTable(NativeState& data); -void initHelperFunctions(NativeState& data); +void initFunctions(NativeState& data); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 926ead3..8bb3cd7 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -714,10 +714,23 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::DUP_TABLE: case IrCmd::TRY_NUM_TO_INDEX: case IrCmd::TRY_CALL_FASTGETTM: + break; case IrCmd::INT_TO_NUM: case IrCmd::UINT_TO_NUM: + state.substituteOrRecord(inst, index); + break; case IrCmd::NUM_TO_INT: + if (IrInst* src = function.asInstOp(inst.a); src && src->cmd == IrCmd::INT_TO_NUM) + substitute(function, inst, src->a); + else + state.substituteOrRecord(inst, index); + break; case IrCmd::NUM_TO_UINT: + if (IrInst* src = function.asInstOp(inst.a); src && src->cmd == IrCmd::UINT_TO_NUM) + substitute(function, inst, src->a); + else + state.substituteOrRecord(inst, index); + break; case IrCmd::CHECK_ARRAY_SIZE: case IrCmd::CHECK_SLOT_MATCH: case IrCmd::CHECK_NODE_NO_NEXT: diff --git a/Makefile b/Makefile index aead3d3..99eb93e 100644 --- a/Makefile +++ b/Makefile @@ -136,6 +136,7 @@ $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/ $(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread +$(ANALYZE_CLI_TARGET): LDFLAGS+=-lpthread fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a $(LPROTOBUF) # pseudo targets diff --git a/Sources.cmake b/Sources.cmake index 2b51721..892b889 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -92,7 +92,6 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitBuiltinsX64.cpp CodeGen/src/EmitCommonX64.cpp CodeGen/src/EmitInstructionX64.cpp - CodeGen/src/Fallbacks.cpp CodeGen/src/IrAnalysis.cpp CodeGen/src/IrBuilder.cpp CodeGen/src/IrCallWrapperX64.cpp @@ -123,8 +122,6 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/EmitCommonA64.h CodeGen/src/EmitCommonX64.h CodeGen/src/EmitInstructionX64.h - CodeGen/src/Fallbacks.h - CodeGen/src/FallbacksProlog.h CodeGen/src/IrLoweringA64.h CodeGen/src/IrLoweringX64.h CodeGen/src/IrRegAllocA64.h @@ -171,6 +168,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h Analysis/include/Luau/Scope.h + Analysis/include/Luau/Simplify.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h Analysis/include/Luau/ToDot.h @@ -185,7 +183,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeFamily.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypePack.h - Analysis/include/Luau/TypeReduction.h Analysis/include/Luau/TypeUtils.h Analysis/include/Luau/Type.h Analysis/include/Luau/Unifiable.h @@ -222,6 +219,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp + Analysis/src/Simplify.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp Analysis/src/ToDot.cpp @@ -236,7 +234,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypeFamily.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypePack.cpp - Analysis/src/TypeReduction.cpp Analysis/src/TypeUtils.cpp Analysis/src/Type.cpp Analysis/src/Unifiable.cpp @@ -380,6 +377,7 @@ if(TARGET Luau.UnitTest) tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/RuntimeLimits.test.cpp + tests/Simplify.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp tests/ToDot.test.cpp @@ -414,7 +412,6 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.unionTypes.test.cpp tests/TypeInfer.unknownnever.test.cpp tests/TypePack.test.cpp - tests/TypeReduction.test.cpp tests/TypeVar.test.cpp tests/Variant.test.cpp tests/VisitType.test.cpp diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 0f4df67..7f58d96 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauUniformTopHandling, false) + /* ** {====================================================== ** Error-recovery functions @@ -229,12 +231,14 @@ void luaD_checkCstack(lua_State* L) ** When returns, all the results are on the stack, starting at the original ** function position. */ -void luaD_call(lua_State* L, StkId func, int nResults) +void luaD_call(lua_State* L, StkId func, int nresults) { if (++L->nCcalls >= LUAI_MAXCCALLS) luaD_checkCstack(L); - if (luau_precall(L, func, nResults) == PCRLUA) + ptrdiff_t old_func = savestack(L, func); + + if (luau_precall(L, func, nresults) == PCRLUA) { // is a Lua function? L->ci->flags |= LUA_CALLINFO_RETURN; // luau_execute will stop after returning from the stack frame @@ -248,6 +252,9 @@ void luaD_call(lua_State* L, StkId func, int nResults) L->isactive = false; } + if (FFlag::LuauUniformTopHandling && nresults != LUA_MULTRET) + L->top = restorestack(L, old_func) + nresults; + L->nCcalls--; luaC_checkGC(L); } diff --git a/VM/src/ldo.h b/VM/src/ldo.h index eac9927..0f7b42a 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -44,7 +44,7 @@ typedef void (*Pfunc)(lua_State* L, void* ud); LUAI_FUNC CallInfo* luaD_growCI(lua_State* L); -LUAI_FUNC void luaD_call(lua_State* L, StkId func, int nResults); +LUAI_FUNC void luaD_call(lua_State* L, StkId func, int nresults); LUAI_FUNC int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t oldtop, ptrdiff_t ef); LUAI_FUNC void luaD_reallocCI(lua_State* L, int newsize); LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 2230a74..569c1b4 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -32,9 +32,8 @@ Proto* luaF_newproto(lua_State* L) f->debugname = NULL; f->debuginsn = NULL; -#if LUA_CUSTOM_EXECUTION f->execdata = NULL; -#endif + f->exectarget = 0; return f; } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index f0471c2..21b8de0 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -275,9 +275,8 @@ typedef struct Proto TString* debugname; uint8_t* debuginsn; // a copy of code[] array with just opcodes -#if LUA_CUSTOM_EXECUTION void* execdata; -#endif + uintptr_t exectarget; GCObject* gclist; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 32a240b..ae1e186 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -69,6 +69,7 @@ typedef struct CallInfo #define LUA_CALLINFO_RETURN (1 << 0) // should the interpreter return after returning from this callinfo? first frame must have this set #define LUA_CALLINFO_HANDLE (1 << 1) // should the error thrown during execution get handled by continuation from this callinfo? func must be C +#define LUA_CALLINFO_CUSTOM (1 << 2) // should this function be executed using custom execution callback #define curr_func(L) (clvalue(L->ci->func)) #define ci_func(ci) (clvalue((ci)->func)) diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 5565bfe..454a4e1 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_FASTFLAG(LuauUniformTopHandling) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -208,10 +210,11 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black #if LUA_CUSTOM_EXECUTION - Proto* p = clvalue(L->ci->func)->l.p; - - if (p->execdata && !SingleStep) + if ((L->ci->flags & LUA_CALLINFO_CUSTOM) && !SingleStep) { + Proto* p = clvalue(L->ci->func)->l.p; + LUAU_ASSERT(p->execdata); + if (L->global->ecb.enter(L, p) == 0) return; } @@ -448,7 +451,7 @@ reentry: LUAU_ASSERT(ttisstring(kv)); // fast-path: built-in table - if (ttistable(rb)) + if (LUAU_LIKELY(ttistable(rb))) { Table* h = hvalue(rb); @@ -565,7 +568,7 @@ reentry: LUAU_ASSERT(ttisstring(kv)); // fast-path: built-in table - if (ttistable(rb)) + if (LUAU_LIKELY(ttistable(rb))) { Table* h = hvalue(rb); @@ -801,7 +804,7 @@ reentry: TValue* kv = VM_KV(aux); LUAU_ASSERT(ttisstring(kv)); - if (ttistable(rb)) + if (LUAU_LIKELY(ttistable(rb))) { Table* h = hvalue(rb); // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works @@ -954,6 +957,7 @@ reentry: #if LUA_CUSTOM_EXECUTION if (LUAU_UNLIKELY(p->execdata && !SingleStep)) { + ci->flags = LUA_CALLINFO_CUSTOM; ci->savedpc = p->code; if (L->global->ecb.enter(L, p) == 1) @@ -1040,7 +1044,8 @@ reentry: // we're done! if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) { - L->top = res; + if (!FFlag::LuauUniformTopHandling) + L->top = res; goto exit; } @@ -1050,7 +1055,7 @@ reentry: Proto* nextproto = nextcl->l.p; #if LUA_CUSTOM_EXECUTION - if (LUAU_UNLIKELY(nextproto->execdata && !SingleStep)) + if (LUAU_UNLIKELY((cip->flags & LUA_CALLINFO_CUSTOM) && !SingleStep)) { if (L->global->ecb.enter(L, nextproto) == 1) goto reentry; @@ -1333,7 +1338,7 @@ reentry: // fast-path: number // Note that all jumps below jump by 1 in the "false" case to skip over aux - if (ttisnumber(ra) && ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(ra) && ttisnumber(rb))) { pc += nvalue(ra) <= nvalue(rb) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -1366,7 +1371,7 @@ reentry: // fast-path: number // Note that all jumps below jump by 1 in the "true" case to skip over aux - if (ttisnumber(ra) && ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(ra) && ttisnumber(rb))) { pc += !(nvalue(ra) <= nvalue(rb)) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -1399,7 +1404,7 @@ reentry: // fast-path: number // Note that all jumps below jump by 1 in the "false" case to skip over aux - if (ttisnumber(ra) && ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(ra) && ttisnumber(rb))) { pc += nvalue(ra) < nvalue(rb) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -1432,7 +1437,7 @@ reentry: // fast-path: number // Note that all jumps below jump by 1 in the "true" case to skip over aux - if (ttisnumber(ra) && ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(ra) && ttisnumber(rb))) { pc += !(nvalue(ra) < nvalue(rb)) ? LUAU_INSN_D(insn) : 1; LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -1464,7 +1469,7 @@ reentry: StkId rc = VM_REG(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb) && ttisnumber(rc)) + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) { setnvalue(ra, nvalue(rb) + nvalue(rc)); VM_NEXT(); @@ -1510,7 +1515,7 @@ reentry: StkId rc = VM_REG(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb) && ttisnumber(rc)) + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) { setnvalue(ra, nvalue(rb) - nvalue(rc)); VM_NEXT(); @@ -1556,7 +1561,7 @@ reentry: StkId rc = VM_REG(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb) && ttisnumber(rc)) + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) { setnvalue(ra, nvalue(rb) * nvalue(rc)); VM_NEXT(); @@ -1617,7 +1622,7 @@ reentry: StkId rc = VM_REG(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb) && ttisnumber(rc)) + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) { setnvalue(ra, nvalue(rb) / nvalue(rc)); VM_NEXT(); @@ -1764,7 +1769,7 @@ reentry: TValue* kv = VM_KV(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(rb))) { setnvalue(ra, nvalue(rb) * nvalue(kv)); VM_NEXT(); @@ -1810,7 +1815,7 @@ reentry: TValue* kv = VM_KV(LUAU_INSN_C(insn)); // fast-path - if (ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(rb))) { setnvalue(ra, nvalue(rb) / nvalue(kv)); VM_NEXT(); @@ -1976,7 +1981,7 @@ reentry: StkId rb = VM_REG(LUAU_INSN_B(insn)); // fast-path - if (ttisnumber(rb)) + if (LUAU_LIKELY(ttisnumber(rb))) { setnvalue(ra, -nvalue(rb)); VM_NEXT(); @@ -2019,7 +2024,7 @@ reentry: StkId rb = VM_REG(LUAU_INSN_B(insn)); // fast-path #1: tables - if (ttistable(rb)) + if (LUAU_LIKELY(ttistable(rb))) { Table* h = hvalue(rb); @@ -2878,14 +2883,21 @@ int luau_precall(lua_State* L, StkId func, int nresults) if (!ccl->isC) { + Proto* p = ccl->l.p; + // fill unused parameters with nil StkId argi = L->top; - StkId argend = L->base + ccl->l.p->numparams; + StkId argend = L->base + p->numparams; while (argi < argend) setnilvalue(argi++); // complete missing arguments - L->top = ccl->l.p->is_vararg ? argi : ci->top; + L->top = p->is_vararg ? argi : ci->top; - L->ci->savedpc = ccl->l.p->code; + ci->savedpc = p->code; + +#if LUA_CUSTOM_EXECUTION + if (p->execdata) + ci->flags = LUA_CALLINFO_CUSTOM; +#endif return PCRLUA; } diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 3827681..cdadfd7 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -135,6 +135,19 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "BinaryImm") SINGLE_COMPARE(ror(x1, x2, 1), 0x93C20441); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Bitfield") +{ + SINGLE_COMPARE(ubfiz(x1, x2, 37, 5), 0xD35B1041); + SINGLE_COMPARE(ubfx(x1, x2, 37, 5), 0xD365A441); + SINGLE_COMPARE(sbfiz(x1, x2, 37, 5), 0x935B1041); + SINGLE_COMPARE(sbfx(x1, x2, 37, 5), 0x9365A441); + + SINGLE_COMPARE(ubfiz(w1, w2, 17, 5), 0x530F1041); + SINGLE_COMPARE(ubfx(w1, w2, 17, 5), 0x53115441); + SINGLE_COMPARE(sbfiz(w1, w2, 17, 5), 0x130F1041); + SINGLE_COMPARE(sbfx(w1, w2, 17, 5), 0x13115441); +} + TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") { // address forms @@ -481,6 +494,8 @@ TEST_CASE("LogTest") build.fcvt(s1, d2); + build.ubfx(x1, x2, 37, 5); + build.setLabel(l); build.ret(); @@ -513,6 +528,7 @@ TEST_CASE("LogTest") fmov d0,#0.25 tbz x0,#5,.L1 fcvt s1,d2 + ubfx x1,x2,#3705 .L1: ret )"; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index cf92843..d66eb18 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -3388,38 +3388,6 @@ TEST_CASE_FIXTURE(ACFixture, "globals_are_order_independent") CHECK(ac.entryMap.count("abc1")); } -TEST_CASE_FIXTURE(ACFixture, "type_reduction_is_hooked_up_to_autocomplete") -{ - ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; - - check(R"( - type T = { x: (number & string)? } - - function f(thingamabob: T) - thingamabob.@1 - end - - function g(thingamabob: T) - thingama@2 - end - )"); - - ToStringOptions opts; - opts.exhaustive = true; - - auto ac1 = autocomplete('1'); - REQUIRE(ac1.entryMap.count("x")); - std::optional ty1 = ac1.entryMap.at("x").type; - REQUIRE(ty1); - CHECK("nil" == toString(*ty1, opts)); - - auto ac2 = autocomplete('2'); - REQUIRE(ac2.entryMap.count("thingamabob")); - std::optional ty2 = ac2.entryMap.at("thingamabob").type; - REQUIRE(ty2); - CHECK("{| x: nil |}" == toString(*ty2, opts)); -} - TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") { loadDefinition(R"( @@ -3490,8 +3458,6 @@ local c = b.@1 TEST_CASE_FIXTURE(ACFixture, "suggest_exported_types") { - ScopedFastFlag luauCopyExportedTypes{"LuauCopyExportedTypes", true}; - check(R"( export type Type = {a: number} local a: T@1 diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 9174051..5e28e8d 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -14,6 +14,7 @@ ClassFixture::ClassFixture() GlobalTypes& globals = frontend.globals; TypeArena& arena = globals.globalTypes; TypeId numberType = builtinTypes->numberType; + TypeId stringType = builtinTypes->stringType; unfreeze(arena); @@ -35,7 +36,7 @@ ClassFixture::ClassFixture() TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { - {"Method", {makeFunction(arena, childClassInstanceType, {}, {builtinTypes->stringType})}}, + {"Method", {makeFunction(arena, childClassInstanceType, {}, {stringType})}}, }; TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); @@ -48,7 +49,7 @@ ClassFixture::ClassFixture() TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(grandChildInstanceType)->props = { - {"Method", {makeFunction(arena, grandChildInstanceType, {}, {builtinTypes->stringType})}}, + {"Method", {makeFunction(arena, grandChildInstanceType, {}, {stringType})}}, }; TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); @@ -61,7 +62,7 @@ ClassFixture::ClassFixture() TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(anotherChildInstanceType)->props = { - {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {builtinTypes->stringType})}}, + {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {stringType})}}, }; TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); @@ -101,7 +102,7 @@ ClassFixture::ClassFixture() TypeId callableClassMetaType = arena.addType(TableType{}); TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); getMutable(callableClassMetaType)->props = { - {"__call", {makeFunction(arena, nullopt, {callableClassType, builtinTypes->stringType}, {builtinTypes->numberType})}}, + {"__call", {makeFunction(arena, nullopt, {callableClassType, stringType}, {numberType})}}, }; globals.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; @@ -114,7 +115,7 @@ ClassFixture::ClassFixture() }; // IndexableClass has a table indexer with a key type of 'number | string' and a return type of 'number' - addIndexableClass("IndexableClass", arena.addType(Luau::UnionType{{builtinTypes->stringType, numberType}}), numberType); + addIndexableClass("IndexableClass", arena.addType(Luau::UnionType{{stringType, numberType}}), numberType); // IndexableNumericKeyClass has a table indexer with a key type of 'number' and a return type of 'number' addIndexableClass("IndexableNumericKeyClass", numberType, numberType); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 7b93398..6bfb159 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "ConstraintGraphBuilderFixture.h" -#include "Luau/TypeReduction.h" - namespace Luau { @@ -13,7 +11,6 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() { mainModule->name = "MainModule"; mainModule->humanReadableName = "MainModule"; - mainModule->reduction = std::make_unique(NotNull{&mainModule->internalTypes}, builtinTypes, NotNull{&ice}); BlockedType::DEPRECATED_nextIndex = 0; BlockedTypePack::nextIndex = 0; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index e1213b9..2f5fbf1 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -1521,6 +1521,36 @@ bb_3: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "IntNumIntPeepholes") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp i1 = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + IrOp u1 = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); + IrOp ni1 = build.inst(IrCmd::INT_TO_NUM, i1); + IrOp nu1 = build.inst(IrCmd::UINT_TO_NUM, u1); + IrOp i2 = build.inst(IrCmd::NUM_TO_INT, ni1); + IrOp u2 = build.inst(IrCmd::NUM_TO_UINT, nu1); + build.inst(IrCmd::STORE_INT, build.vmReg(0), i2); + build.inst(IrCmd::STORE_INT, build.vmReg(1), u2); + build.inst(IrCmd::RETURN, build.constUint(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_INT R0 + %1 = LOAD_INT R1 + STORE_INT R0, %0 + STORE_INT R1, %1 + RETURN 2u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 22530a2..abdfea7 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -350,6 +350,35 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } +// Unions should never be cyclic, but we should clone them correctly even if +// they are. +TEST_CASE_FIXTURE(Fixture, "clone_cyclic_union") +{ + ScopedFastFlag sff{"LuauCloneCyclicUnions", true}; + + TypeArena src; + + TypeId u = src.addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + UnionType* uu = getMutable(u); + REQUIRE(uu); + + uu->options.push_back(u); + + TypeArena dest; + CloneState cloneState; + + TypeId cloned = clone(u, dest, cloneState); + REQUIRE(cloned); + + const UnionType* clonedUnion = get(cloned); + REQUIRE(clonedUnion); + REQUIRE(3 == clonedUnion->options.size()); + + CHECK(builtinTypes->numberType == clonedUnion->options[0]); + CHECK(builtinTypes->stringType == clonedUnion->options[1]); + CHECK(cloned == clonedUnion->options[2]); +} + TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { ScopedFastFlag flags[] = { diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 26b3b00..93ea751 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -494,7 +494,7 @@ struct NormalizeFixture : Fixture REQUIRE(node); AstStatTypeAlias* alias = node->as(); REQUIRE(alias); - TypeId* originalTy = getMainModule()->astOriginalResolvedTypes.find(alias->type); + TypeId* originalTy = getMainModule()->astResolvedTypes.find(alias->type); REQUIRE(originalTy); return normalizer.normalize(*originalTy); } @@ -732,15 +732,11 @@ TEST_CASE_FIXTURE(NormalizeFixture, "narrow_union_of_classes_with_intersection") TEST_CASE_FIXTURE(NormalizeFixture, "intersection_of_metatables_where_the_metatable_is_top_or_bottom") { - ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; - CHECK("{ @metatable *error-type*, {| |} }" == toString(normal("Mt<{}, any> & Mt<{}, err>"))); } TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") { - ScopedFastFlag sff{"LuauNormalizeMetatableFixes", true}; - CHECK("never" == toString(normal("Mt<{}, number> & Mt<{}, string>"))); } diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp new file mode 100644 index 0000000..2052019 --- /dev/null +++ b/tests/Simplify.test.cpp @@ -0,0 +1,508 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" + +#include "Luau/Simplify.h" + +using namespace Luau; + +namespace +{ + +struct SimplifyFixture : Fixture +{ + TypeArena _arena; + const NotNull arena{&_arena}; + + ToStringOptions opts; + + Scope scope{builtinTypes->anyTypePack}; + + const TypeId anyTy = builtinTypes->anyType; + const TypeId unknownTy = builtinTypes->unknownType; + const TypeId neverTy = builtinTypes->neverType; + const TypeId errorTy = builtinTypes->errorType; + + const TypeId functionTy = builtinTypes->functionType; + const TypeId tableTy = builtinTypes->tableType; + + const TypeId numberTy = builtinTypes->numberType; + const TypeId stringTy = builtinTypes->stringType; + const TypeId booleanTy = builtinTypes->booleanType; + const TypeId nilTy = builtinTypes->nilType; + const TypeId threadTy = builtinTypes->threadType; + + const TypeId classTy = builtinTypes->classType; + + const TypeId trueTy = builtinTypes->trueType; + const TypeId falseTy = builtinTypes->falseType; + + const TypeId truthyTy = builtinTypes->truthyType; + const TypeId falsyTy = builtinTypes->falsyType; + + const TypeId freeTy = arena->addType(FreeType{&scope}); + const TypeId genericTy = arena->addType(GenericType{}); + const TypeId blockedTy = arena->addType(BlockedType{}); + const TypeId pendingTy = arena->addType(PendingExpansionType{{}, {}, {}, {}}); + + const TypeId helloTy = arena->addType(SingletonType{StringSingleton{"hello"}}); + const TypeId worldTy = arena->addType(SingletonType{StringSingleton{"world"}}); + + const TypePackId emptyTypePack = arena->addTypePack({}); + + const TypeId fn1Ty = arena->addType(FunctionType{emptyTypePack, emptyTypePack}); + const TypeId fn2Ty = arena->addType(FunctionType{builtinTypes->anyTypePack, emptyTypePack}); + + TypeId parentClassTy = nullptr; + TypeId childClassTy = nullptr; + TypeId anotherChildClassTy = nullptr; + TypeId unrelatedClassTy = nullptr; + + SimplifyFixture() + { + createSomeClasses(&frontend); + + parentClassTy = frontend.globals.globalScope->linearSearchForBinding("Parent")->typeId; + childClassTy = frontend.globals.globalScope->linearSearchForBinding("Child")->typeId; + anotherChildClassTy = frontend.globals.globalScope->linearSearchForBinding("AnotherChild")->typeId; + unrelatedClassTy = frontend.globals.globalScope->linearSearchForBinding("Unrelated")->typeId; + } + + TypeId intersect(TypeId a, TypeId b) + { + return simplifyIntersection(builtinTypes, arena, a, b).result; + } + + std::string intersectStr(TypeId a, TypeId b) + { + return toString(intersect(a, b), opts); + } + + bool isIntersection(TypeId a) + { + return bool(get(follow(a))); + } + + TypeId mkTable(std::map propTypes) + { + TableType::Props props; + for (const auto& [name, ty] : propTypes) + props[name] = Property{ty}; + + return arena->addType(TableType{props, {}, TypeLevel{}, TableState::Sealed}); + } + + TypeId mkNegation(TypeId ty) + { + return arena->addType(NegationType{ty}); + } + + TypeId mkFunction(TypeId arg, TypeId ret) + { + return arena->addType(FunctionType{arena->addTypePack({arg}), arena->addTypePack({ret})}); + } + + TypeId union_(TypeId a, TypeId b) + { + return simplifyUnion(builtinTypes, arena, a, b).result; + } +}; + +} // namespace + +TEST_SUITE_BEGIN("Simplify"); + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_other_tops_and_bottom_types") +{ + CHECK(unknownTy == intersect(unknownTy, unknownTy)); + + CHECK(unknownTy == intersect(unknownTy, anyTy)); + CHECK(unknownTy == intersect(anyTy, unknownTy)); + + CHECK(neverTy == intersect(unknownTy, neverTy)); + CHECK(neverTy == intersect(neverTy, unknownTy)); + + CHECK(neverTy == intersect(unknownTy, errorTy)); + CHECK(neverTy == intersect(errorTy, unknownTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "nil") +{ + CHECK(nilTy == intersect(nilTy, nilTy)); + CHECK(neverTy == intersect(nilTy, numberTy)); + CHECK(neverTy == intersect(nilTy, trueTy)); + CHECK(neverTy == intersect(nilTy, tableTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "boolean_singletons") +{ + CHECK(trueTy == intersect(trueTy, booleanTy)); + CHECK(trueTy == intersect(booleanTy, trueTy)); + + CHECK(falseTy == intersect(falseTy, booleanTy)); + CHECK(falseTy == intersect(booleanTy, falseTy)); + + CHECK(neverTy == intersect(falseTy, trueTy)); + CHECK(neverTy == intersect(trueTy, falseTy)); + + CHECK(booleanTy == union_(trueTy, booleanTy)); + CHECK(booleanTy == union_(booleanTy, trueTy)); + CHECK(booleanTy == union_(falseTy, booleanTy)); + CHECK(booleanTy == union_(booleanTy, falseTy)); + CHECK(booleanTy == union_(falseTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "boolean_and_truthy_and_falsy") +{ + TypeId optionalBooleanTy = arena->addType(UnionType{{booleanTy, nilTy}}); + + CHECK(trueTy == intersect(booleanTy, truthyTy)); + + CHECK(trueTy == intersect(optionalBooleanTy, truthyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "any_and_indeterminate_types") +{ + CHECK("a" == intersectStr(anyTy, freeTy)); + CHECK("a" == intersectStr(freeTy, anyTy)); + + CHECK("b" == intersectStr(anyTy, genericTy)); + CHECK("b" == intersectStr(genericTy, anyTy)); + + CHECK(blockedTy == intersect(anyTy, blockedTy)); + CHECK(blockedTy == intersect(blockedTy, anyTy)); + + CHECK(pendingTy == intersect(anyTy, pendingTy)); + CHECK(pendingTy == intersect(pendingTy, anyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_indeterminate_types") +{ + CHECK(isIntersection(intersect(unknownTy, freeTy))); + CHECK(isIntersection(intersect(freeTy, unknownTy))); + + CHECK(isIntersection(intersect(unknownTy, genericTy))); + CHECK(isIntersection(intersect(genericTy, unknownTy))); + + CHECK(isIntersection(intersect(unknownTy, blockedTy))); + CHECK(isIntersection(intersect(blockedTy, unknownTy))); + + CHECK(isIntersection(intersect(unknownTy, pendingTy))); + CHECK(isIntersection(intersect(pendingTy, unknownTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_concrete") +{ + CHECK(numberTy == intersect(numberTy, unknownTy)); + CHECK(numberTy == intersect(unknownTy, numberTy)); + CHECK(trueTy == intersect(trueTy, unknownTy)); + CHECK(trueTy == intersect(unknownTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "error_and_other_tops_and_bottom_types") +{ + CHECK(errorTy == intersect(errorTy, errorTy)); + + CHECK(errorTy == intersect(errorTy, anyTy)); + CHECK(errorTy == intersect(anyTy, errorTy)); + + CHECK(neverTy == intersect(errorTy, neverTy)); + CHECK(neverTy == intersect(neverTy, errorTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "error_and_indeterminate_types") +{ + CHECK("*error-type* & a" == intersectStr(errorTy, freeTy)); + CHECK("*error-type* & a" == intersectStr(freeTy, errorTy)); + + CHECK("*error-type* & b" == intersectStr(errorTy, genericTy)); + CHECK("*error-type* & b" == intersectStr(genericTy, errorTy)); + + CHECK(isIntersection(intersect(errorTy, blockedTy))); + CHECK(isIntersection(intersect(blockedTy, errorTy))); + + CHECK(isIntersection(intersect(errorTy, pendingTy))); + CHECK(isIntersection(intersect(pendingTy, errorTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "unknown_and_concrete") +{ + CHECK(neverTy == intersect(numberTy, errorTy)); + CHECK(neverTy == intersect(errorTy, numberTy)); + CHECK(neverTy == intersect(trueTy, errorTy)); + CHECK(neverTy == intersect(errorTy, trueTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives") +{ + // This shouldn't be possible, but we'll make it work even if it is. + TypeId numberTyDuplicate = arena->addType(PrimitiveType{PrimitiveType::Number}); + + CHECK(numberTy == intersect(numberTy, numberTyDuplicate)); + CHECK(neverTy == intersect(numberTy, stringTy)); + + CHECK(neverTy == intersect(neverTy, numberTy)); + CHECK(neverTy == intersect(numberTy, neverTy)); + + CHECK(neverTy == intersect(neverTy, functionTy)); + CHECK(neverTy == intersect(functionTy, neverTy)); + + CHECK(neverTy == intersect(neverTy, tableTy)); + CHECK(neverTy == intersect(tableTy, neverTy)); + + CHECK(numberTy == intersect(anyTy, numberTy)); + CHECK(numberTy == intersect(numberTy, anyTy)); + + CHECK(neverTy == intersect(stringTy, nilTy)); + CHECK(neverTy == intersect(nilTy, stringTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives_and_falsy") +{ + CHECK(neverTy == intersect(numberTy, falsyTy)); + CHECK(neverTy == intersect(falsyTy, numberTy)); + + CHECK(nilTy == intersect(nilTy, falsyTy)); + CHECK(nilTy == intersect(falsyTy, nilTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "primitives_and_singletons") +{ + CHECK(helloTy == intersect(helloTy, stringTy)); + CHECK(helloTy == intersect(stringTy, helloTy)); + + CHECK(neverTy == intersect(worldTy, helloTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "functions") +{ + CHECK(fn1Ty == intersect(fn1Ty, functionTy)); + CHECK(fn1Ty == intersect(functionTy, fn1Ty)); + + // Intersections of functions are super weird if you think about it. + CHECK("(() -> ()) & ((...any) -> ())" == intersectStr(fn1Ty, fn2Ty)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negated_top_function_type") +{ + TypeId negatedFunctionTy = mkNegation(functionTy); + + CHECK(numberTy == intersect(numberTy, negatedFunctionTy)); + CHECK(numberTy == intersect(negatedFunctionTy, numberTy)); + + CHECK(falsyTy == intersect(falsyTy, negatedFunctionTy)); + CHECK(falsyTy == intersect(negatedFunctionTy, falsyTy)); + + TypeId f = mkFunction(stringTy, numberTy); + + CHECK(neverTy == intersect(f, negatedFunctionTy)); + CHECK(neverTy == intersect(negatedFunctionTy, f)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "optional_overloaded_function_and_top_function") +{ + // (((number) -> string) & ((string) -> number))? & ~function + + TypeId f1 = mkFunction(numberTy, stringTy); + TypeId f2 = mkFunction(stringTy, numberTy); + + TypeId f12 = arena->addType(IntersectionType{{f1, f2}}); + + TypeId t = arena->addType(UnionType{{f12, nilTy}}); + + TypeId notFunctionTy = mkNegation(functionTy); + + CHECK(nilTy == intersect(t, notFunctionTy)); + CHECK(nilTy == intersect(notFunctionTy, t)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negated_function_does_not_intersect_cleanly_with_truthy") +{ + // ~function & ~(false?) + // ~function & ~(false | nil) + // ~function & ~false & ~nil + + TypeId negatedFunctionTy = mkNegation(functionTy); + CHECK(isIntersection(intersect(negatedFunctionTy, truthyTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables") +{ + TypeId t1 = mkTable({{"tag", stringTy}}); + + CHECK(t1 == intersect(t1, tableTy)); + CHECK(neverTy == intersect(t1, functionTy)); + + TypeId t2 = mkTable({{"tag", helloTy}}); + + CHECK(t2 == intersect(t1, t2)); + CHECK(t2 == intersect(t2, t1)); + + TypeId t3 = mkTable({}); + + CHECK(t1 == intersect(t1, t3)); + CHECK(t1 == intersect(t3, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables_and_top_table") +{ + TypeId notTableType = mkNegation(tableTy); + TypeId t1 = mkTable({{"prop", stringTy}, {"another", numberTy}}); + + CHECK(t1 == intersect(t1, tableTy)); + CHECK(t1 == intersect(tableTy, t1)); + + CHECK(neverTy == intersect(t1, notTableType)); + CHECK(neverTy == intersect(notTableType, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "tables_and_truthy") +{ + TypeId t1 = mkTable({{"prop", stringTy}, {"another", numberTy}}); + + CHECK(t1 == intersect(t1, truthyTy)); + CHECK(t1 == intersect(truthyTy, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "table_with_a_tag") +{ + // {tag: string, prop: number} & {tag: "hello"} + // I think we can decline to simplify this: + TypeId t1 = mkTable({{"tag", stringTy}, {"prop", numberTy}}); + TypeId t2 = mkTable({{"tag", helloTy}}); + + CHECK("{| prop: number, tag: string |} & {| tag: \"hello\" |}" == intersectStr(t1, t2)); + CHECK("{| prop: number, tag: string |} & {| tag: \"hello\" |}" == intersectStr(t2, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "nested_table_tag_test") +{ + TypeId t1 = mkTable({ + {"subtable", mkTable({ + {"tag", helloTy}, + {"subprop", numberTy}, + })}, + {"prop", stringTy}, + }); + TypeId t2 = mkTable({ + {"subtable", mkTable({ + {"tag", helloTy}, + })}, + }); + + CHECK(t1 == intersect(t1, t2)); + CHECK(t1 == intersect(t2, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "union") +{ + TypeId t1 = arena->addType(UnionType{{numberTy, stringTy, nilTy, tableTy}}); + + CHECK(nilTy == intersect(t1, nilTy)); + // CHECK(nilTy == intersect(nilTy, t1)); // TODO? + + CHECK(builtinTypes->stringType == intersect(builtinTypes->optionalStringType, truthyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "two_unions") +{ + TypeId t1 = arena->addType(UnionType{{numberTy, booleanTy, stringTy, nilTy, tableTy}}); + + CHECK("false?" == intersectStr(t1, falsyTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "curious_union") +{ + // (a & false) | (a & nil) + TypeId curious = + arena->addType(UnionType{{arena->addType(IntersectionType{{freeTy, falseTy}}), arena->addType(IntersectionType{{freeTy, nilTy}})}}); + + CHECK("(a & false) | (a & nil) | number" == toString(union_(curious, numberTy))); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negations") +{ + TypeId notNumberTy = mkNegation(numberTy); + TypeId notStringTy = mkNegation(stringTy); + + CHECK(neverTy == intersect(numberTy, notNumberTy)); + + CHECK(numberTy == intersect(numberTy, notStringTy)); + CHECK(numberTy == intersect(notStringTy, numberTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "top_class_type") +{ + CHECK(neverTy == intersect(classTy, stringTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "classes") +{ + CHECK(childClassTy == intersect(childClassTy, parentClassTy)); + CHECK(childClassTy == intersect(parentClassTy, childClassTy)); + + CHECK(parentClassTy == union_(childClassTy, parentClassTy)); + CHECK(parentClassTy == union_(parentClassTy, childClassTy)); + + CHECK(neverTy == intersect(childClassTy, unrelatedClassTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "negations_of_classes") +{ + TypeId notChildClassTy = mkNegation(childClassTy); + TypeId notParentClassTy = mkNegation(parentClassTy); + + CHECK(neverTy == intersect(childClassTy, notParentClassTy)); + CHECK(neverTy == intersect(notParentClassTy, childClassTy)); + + CHECK("Parent & ~Child" == intersectStr(notChildClassTy, parentClassTy)); + CHECK("Parent & ~Child" == intersectStr(parentClassTy, notChildClassTy)); + + CHECK(notParentClassTy == intersect(notChildClassTy, notParentClassTy)); + CHECK(notParentClassTy == intersect(notParentClassTy, notChildClassTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "intersection_of_intersection_of_a_free_type_can_result_in_removal_of_that_free_type") +{ + // a & string and number + // (a & number) & (string & number) + + TypeId t1 = arena->addType(IntersectionType{{freeTy, stringTy}}); + + CHECK(neverTy == intersect(t1, numberTy)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "some_tables_are_really_never") +{ + TypeId notAnyTy = mkNegation(anyTy); + + TypeId t1 = mkTable({{"someKey", notAnyTy}}); + + CHECK(neverTy == intersect(t1, numberTy)); + CHECK(neverTy == intersect(numberTy, t1)); + CHECK(neverTy == intersect(t1, t1)); +} + +TEST_CASE_FIXTURE(SimplifyFixture, "simplify_stops_at_cycles") +{ + TypeId t = mkTable({}); + TableType* tt = getMutable(t); + REQUIRE(tt); + + TypeId t2 = mkTable({}); + TableType* t2t = getMutable(t2); + REQUIRE(t2t); + + tt->props["cyclic"] = Property{t2}; + t2t->props["cyclic"] = Property{t}; + + CHECK(t == intersect(t, anyTy)); + CHECK(t == intersect(anyTy, t)); + + CHECK(t2 == intersect(t2, anyTy)); + CHECK(t2 == intersect(anyTy, t2)); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 160757e..39759c7 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -291,9 +291,9 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") { o.maxTypeLength = 30; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); } else { @@ -321,9 +321,9 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") { o.maxTypeLength = 30; CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~false & ~nil)... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> ((b) -> ((a) -> (() -> ()) | (a & ~(false?))... *TRUNCATED*"); } else { @@ -507,25 +507,25 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); CHECK(0 == opts.nameMap.types.size()); - const MetatableType* tMeta = get(tType); + const MetatableType* tMeta = get(follow(tType)); REQUIRE(tMeta); - TableType* tMeta2 = getMutable(tMeta->metatable); + TableType* tMeta2 = getMutable(follow(tMeta->metatable)); REQUIRE(tMeta2); REQUIRE(tMeta2->props.count("__index")); - const MetatableType* tMeta3 = get(tMeta2->props["__index"].type()); + const MetatableType* tMeta3 = get(follow(tMeta2->props["__index"].type())); REQUIRE(tMeta3); - TableType* tMeta4 = getMutable(tMeta3->metatable); + TableType* tMeta4 = getMutable(follow(tMeta3->metatable)); REQUIRE(tMeta4); REQUIRE(tMeta4->props.count("__index")); - TableType* tMeta5 = getMutable(tMeta4->props["__index"].type()); + TableType* tMeta5 = getMutable(follow(tMeta4->props["__index"].type())); REQUIRE(tMeta5); REQUIRE(tMeta5->props.count("one") > 0); - TableType* tMeta6 = getMutable(tMeta3->table); + TableType* tMeta6 = getMutable(follow(tMeta3->table)); REQUIRE(tMeta6); REQUIRE(tMeta6->props.count("two") > 0); diff --git a/tests/TxnLog.test.cpp b/tests/TxnLog.test.cpp index 78ab064..bfd2976 100644 --- a/tests/TxnLog.test.cpp +++ b/tests/TxnLog.test.cpp @@ -25,6 +25,8 @@ struct TxnLogFixture TypeId a = arena.freshType(globalScope.get()); TypeId b = arena.freshType(globalScope.get()); TypeId c = arena.freshType(childScope.get()); + + TypeId g = arena.addType(GenericType{"G"}); }; TEST_SUITE_BEGIN("TxnLog"); @@ -110,4 +112,13 @@ TEST_CASE_FIXTURE(TxnLogFixture, "colliding_coincident_logs_do_not_create_degene CHECK("a" == toString(b)); } +TEST_CASE_FIXTURE(TxnLogFixture, "replacing_persistent_types_is_allowed_but_makes_the_log_radioactive") +{ + persist(g); + + log.replace(g, BoundType{a}); + + CHECK(log.radioactive); +} + TEST_SUITE_END(); diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 9a101d9..b11b05d 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -20,7 +20,7 @@ struct FamilyFixture : Fixture swapFamily = TypeFamily{/* name */ "Swap", /* reducer */ [](std::vector tys, std::vector tps, NotNull arena, NotNull builtins, - NotNull log) -> TypeFamilyReductionResult { + NotNull log, NotNull scope, NotNull normalizer) -> TypeFamilyReductionResult { LUAU_ASSERT(tys.size() == 1); TypeId param = log->follow(tys.at(0)); @@ -78,18 +78,6 @@ TEST_CASE_FIXTURE(FamilyFixture, "basic_type_family") CHECK("Type family instance Swap is uninhabited" == toString(result.errors[0])); }; -TEST_CASE_FIXTURE(FamilyFixture, "type_reduction_reduces_families") -{ - if (!FFlag::DebugLuauDeferredConstraintResolution) - return; - - CheckResult result = check(R"( - local x: Swap & nil - )"); - - CHECK("never" == toString(requireType("x"))); -} - TEST_CASE_FIXTURE(FamilyFixture, "family_as_fn_ret") { if (!FFlag::DebugLuauDeferredConstraintResolution) @@ -202,4 +190,27 @@ TEST_CASE_FIXTURE(FamilyFixture, "function_internal_families") CHECK(toString(result.errors[0]) == "Type family instance Swap is uninhabited"); } +TEST_CASE_FIXTURE(Fixture, "add_family_at_work") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function add(a, b) + return a + b + end + + local a = add(1, 2) + local b = add(1, "foo") + local c = add("foo", 1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(requireType("a")) == "number"); + CHECK(toString(requireType("b")) == "Add"); + CHECK(toString(requireType("c")) == "Add"); + CHECK(toString(result.errors[0]) == "Type family instance Add is uninhabited"); + CHECK(toString(result.errors[1]) == "Type family instance Add is uninhabited"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 84b057d..4feb3a6 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -736,6 +736,18 @@ TEST_CASE_FIXTURE(Fixture, "luau_print_is_not_special_without_the_flag") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "luau_print_incomplete") +{ + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + + CheckResult result = check(R"( + local a: _luau_print + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("_luau_print requires one generic parameter", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "instantiate_type_fun_should_not_trip_rbxassert") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.cfa.test.cpp b/tests/TypeInfer.cfa.test.cpp index 7374295..04aeb54 100644 --- a/tests/TypeInfer.cfa.test.cpp +++ b/tests/TypeInfer.cfa.test.cpp @@ -352,10 +352,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions") CHECK_EQ("\"err\"", toString(requireTypeAtPosition({13, 31}))); CHECK_EQ("E", toString(requireTypeAtPosition({14, 31}))); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("{| error: E, tag: \"err\" |}", toString(requireTypeAtPosition({16, 19}))); - else - CHECK_EQ("Err", toString(requireTypeAtPosition({16, 19}))); + CHECK_EQ("Err", toString(requireTypeAtPosition({16, 19}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_assert_x") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 607fc40..d9e4bba 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -552,6 +552,8 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") local x : IndexableClass local y = x[true] )"); + + CHECK_EQ( toString(result.errors[0]), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible"); } @@ -560,6 +562,7 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") local x : IndexableClass x[true] = 42 )"); + CHECK_EQ( toString(result.errors[0]), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible"); } @@ -593,7 +596,10 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") local x : IndexableNumericKeyClass x["key"] = 1 )"); - CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + else + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); } { CheckResult result = check(R"( @@ -615,7 +621,10 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") local x : IndexableNumericKeyClass local y = x["key"] )"); - CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ(toString(result.errors[0]), "Key 'key' not found in class 'IndexableNumericKeyClass'"); + else + CHECK_EQ(toString(result.errors[0]), "Type 'string' could not be converted into 'number'"); } { CheckResult result = check(R"( diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 9712f09..78f7558 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -358,6 +358,22 @@ TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") LUAU_REQUIRE_NO_ERRORS(result); } +// We had a bug where we'd look up the type of a recursive call using the DFG, +// not the bindings tables. As a result, we would erroneously use the +// generalized type of foo() in this recursive fragment. This creates a +// constraint cycle that doesn't always work itself out. +// +// The fix is for the DFG node within the scope of foo() to retain the +// ungeneralized type of foo. +TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_calls_must_refer_to_the_ungeneralized_type") +{ + CheckResult result = check(R"( + function foo() + string.format('%s: %s', "51", foo()) + end + )"); +} + TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") { CheckResult result = check(R"( @@ -1029,7 +1045,7 @@ TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") LUAU_REQUIRE_NO_ERRORS(result); TypeId type = requireTypeAtPosition(Position(6, 14)); CHECK_EQ("(tbl, number, number) -> number", toString(type)); - auto ftv = get(type); + auto ftv = get(follow(type)); REQUIRE(ftv); CHECK(ftv->hasSelf); } @@ -1967,7 +1983,7 @@ TEST_CASE_FIXTURE(Fixture, "inner_frees_become_generic_in_dcr") LUAU_REQUIRE_NO_ERRORS(result); std::optional ty = findTypeAtPosition(Position{3, 19}); REQUIRE(ty); - CHECK(get(*ty)); + CHECK(get(follow(*ty))); } TEST_CASE_FIXTURE(Fixture, "function_exprs_are_generalized_at_signature_scope_not_enclosing") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 99abf71..738d3cd 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -132,40 +132,23 @@ TEST_CASE_FIXTURE(Fixture, "should_still_pick_an_overload_whose_arguments_are_un TEST_CASE_FIXTURE(Fixture, "propagates_name") { - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CheckResult result = check(R"( - type A={a:number} - type B={b:string} + const std::string code = R"( + type A={a:number} + type B={b:string} - local c:A&B - local b = c - )"); + local c:A&B + local b = c + )"; - LUAU_REQUIRE_NO_ERRORS(result); + const std::string expected = R"( + type A={a:number} + type B={b:string} - CHECK("{| a: number, b: string |}" == toString(requireType("b"))); - } - else - { - const std::string code = R"( - type A={a:number} - type B={b:string} + local c:A&B + local b:A&B=c + )"; - local c:A&B - local b = c - )"; - - const std::string expected = R"( - type A={a:number} - type B={b:string} - - local c:A&B - local b:A&B=c - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - } + CHECK_EQ(expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guaranteed_to_exist") @@ -328,11 +311,7 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") LUAU_REQUIRE_ERROR_COUNT(1, result); auto e = toString(result.errors[0]); - // In DCR, because of type normalization, we print a different error message - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Cannot add property 'z' to table '{| x: number, y: number |}'", e); - else - CHECK_EQ("Cannot add property 'z' to table 'X & Y'", e); + CHECK_EQ("Cannot add property 'z' to table 'X & Y'", e); } TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") @@ -406,10 +385,7 @@ local a: XYZ = 3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into '{| x: number, y: number, z: number |}')"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' + CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' caused by: Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); } @@ -426,11 +402,7 @@ local b: number = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), R"(Type '{| x: number, y: number, z: number |}' could not be converted into 'number')"); - else - CHECK_EQ( - toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); } TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") @@ -470,11 +442,7 @@ TEST_CASE_FIXTURE(Fixture, "intersect_bool_and_false") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), "Type 'false' could not be converted into 'true'"); - else - CHECK_EQ( - toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); + CHECK_EQ(toString(result.errors[0]), "Type 'boolean & false' could not be converted into 'true'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") @@ -486,14 +454,9 @@ TEST_CASE_FIXTURE(Fixture, "intersect_false_and_bool_and_false") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ(toString(result.errors[0]), "Type 'false' could not be converted into 'true'"); - else - { - // TODO: odd stringification of `false & (boolean & false)`.) - CHECK_EQ(toString(result.errors[0]), - "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); - } + // TODO: odd stringification of `false & (boolean & false)`.) + CHECK_EQ(toString(result.errors[0]), + "Type 'boolean & false & false' could not be converted into 'true'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") @@ -531,21 +494,8 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: nil, r: number? |}' could not be converted into '{| p: nil |}'\n" - "caused by:\n" - " Property 'p' is not compatible. Type 'number?' could not be converted into 'nil'\n" - "caused by:\n" - " Not all union options are compatible. Type 'number' could not be converted into 'nil' in an invariant context"); - } - else - { - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " - "'{| p: nil |}'; none of the intersection parts are compatible"); - } + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " + "'{| p: nil |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") @@ -558,27 +508,9 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") local z : { p : string?, q : number? } = x -- Not OK )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" - "caused by:\n" - " Property 'p' is not compatible. Type 'number' could not be converted into 'string' in an invariant context"); - - CHECK_EQ(toString(result.errors[1]), - "Type '{| p: number?, q: string? |}' could not be converted into '{| p: string?, q: number? |}'\n" - "caused by:\n" - " Property 'q' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into " - "'{| p: string?, q: number? |}'; none of the intersection parts are compatible"); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into " + "'{| p: string?, q: number? |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_never_properties") @@ -605,18 +537,9 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(toString(result.errors[0]), - "Type '((number?) -> {| p: number, q: number |}) & ((string?) -> {| p: number, r: number |})' could not be converted into " - "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); - } - else - { - CHECK_EQ(toString(result.errors[0]), - "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " - "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); - } + CHECK_EQ(toString(result.errors[0]), + "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " + "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") @@ -917,7 +840,8 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(never) -> never", toString(requireType("f"))); + // TODO? We do not simplify types from explicit annotations. + CHECK_EQ("({| x: number |} & {| x: string |}) -> {| x: number |} & {| x: string |}", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") @@ -933,7 +857,7 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_intersection_types_2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(never) -> never", toString(requireType("f"))); + CHECK_EQ("({| x: number |} & {| x: string |}) -> never", toString(requireType("f"))); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 90436ce..dd26cc8 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -676,9 +676,19 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") src += "end"; CheckResult result = check(src); - LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); - CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // TODO: This will eventually entirely go away, but for now the Add + // family will ensure there's one less error. + LUAU_REQUIRE_ERROR_COUNT(ops.size() - 1, result); + CHECK_EQ("Unknown type used in - operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") @@ -889,8 +899,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireType("f")) == "(a, b) -> Add"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + } result = check(Mode::Nonstrict, R"( local function f(x, y) @@ -985,31 +1003,6 @@ TEST_CASE_FIXTURE(Fixture, "unrelated_primitives_cannot_be_compared") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "mm_ops_must_return_a_value") -{ - if (!FFlag::DebugLuauDeferredConstraintResolution) - return; - - CheckResult result = check(R"( - local mm = { - __add = function(self, other) - return - end, - } - - local x = setmetatable({}, mm) - local y = x + 123 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK(requireType("y") == builtinTypes->errorRecoveryType()); - - const GenericError* ge = get(result.errors[1]); - REQUIRE(ge); - CHECK(ge->message == "Metamethod '__add' must return a value"); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") { if (!FFlag::DebugLuauDeferredConstraintResolution) @@ -1179,6 +1172,38 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.String.slice") +{ + + CheckResult result = check(R"( +--!strict +local function slice(str: string, startIndexStr: string | number, lastIndexStr: (string | number)?): string + local strLen, invalidBytePosition = utf8.len(str) + assert(strLen ~= nil, ("string `%s` has an invalid byte at position %s"):format(str, tostring(invalidBytePosition))) + local startIndex = tonumber(startIndexStr) + + + -- if no last index length set, go to str length + 1 + local lastIndex = strLen + 1 + + assert(typeof(lastIndex) == "number", "lastIndexStr should convert to number") + + if lastIndex > strLen then + lastIndex = strLen + 1 + end + + local startIndexByte = utf8.offset(str, startIndex) + + return string.sub(str, startIndexByte, startIndexByte) +end + +return slice + + + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.startswith") { // This test also exercises whether the binary operator == passes the correct expected type @@ -1204,5 +1229,24 @@ return startsWith LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "add_type_family_works") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + local function add(x, y) + return x + y + end + + local a = add(1, 2) + local b = add("foo", "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(requireType("a")) == "number"); + CHECK(toString(requireType("b")) == "Add"); + CHECK(toString(result.errors[0]) == "Type family instance Add is uninhabited"); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 606a4f4..885a978 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -532,7 +532,7 @@ return wrapStrictTable(Constants, "Constants") std::optional result = first(m->returnType); REQUIRE(result); if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("(any?) & ~table", toString(*result)); + CHECK_EQ("(any & ~table)?", toString(*result)); else CHECK_MESSAGE(get(*result), *result); } @@ -819,4 +819,61 @@ TEST_CASE_FIXTURE(Fixture, "lookup_prop_of_intersection_containing_unions_of_tab // CHECK("variable" == unknownProp->key); } +TEST_CASE_FIXTURE(Fixture, "expected_type_should_be_a_helpful_deduction_guide_for_function_calls") +{ + ScopedFastFlag sffs[]{ + {"LuauUnifyTwoOptions", true}, + {"LuauTypeMismatchInvarianceInError", true}, + }; + + CheckResult result = check(R"( + type Ref = { val: T } + + local function useRef(x: T): Ref + return { val = x } + end + + local x: Ref = useRef(nil) + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + // This is actually wrong! Sort of. It's doing the wrong thing, it's actually asking whether + // `{| val: number? |} <: {| val: nil |}` + // instead of the correct way, which is + // `{| val: nil |} <: {| val: number? |}` + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(toString(result.errors[0]), R"(Type 'Ref' could not be converted into 'Ref' +caused by: + Property 'val' is not compatible. Type 'nil' could not be converted into 'number' in an invariant context)"); + } +} + +TEST_CASE_FIXTURE(Fixture, "floating_generics_should_not_be_allowed") +{ + CheckResult result = check(R"( + local assign : (target: T, source0: U?, source1: V?, source2: W?, ...any) -> T & U & V & W = (nil :: any) + + -- We have a big problem here: The generics U, V, and W are not bound to anything! + -- Things get strange because of this. + local benchmark = assign({}) + local options = benchmark.options + do + local resolve2: any = nil + options.fn({ + resolve = function(...) + resolve2(...) + end, + }) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index c55497a..3b0654a 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1020,16 +1020,8 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(R"({| catfood: string, name: string, tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ(R"({| dogfood: string, name: string, tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); - } - else - { - CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); - } + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); } TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") @@ -1050,16 +1042,8 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag_with_implicit_else") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - CHECK_EQ(R"({| catfood: string, name: string, tag: "Cat" |})", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ(R"({| dogfood: string, name: string, tag: "Dog" |})", toString(requireTypeAtPosition({9, 33}))); - } - else - { - CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); - CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); - } + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); } TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") @@ -1403,7 +1387,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") if (FFlag::DebugLuauDeferredConstraintResolution) { CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("~string", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("unknown & ~string", toString(requireTypeAtPosition({5, 28}))); } else { @@ -1508,14 +1492,7 @@ local _ = _ ~= _ or _ or _ end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - // Without a realistic motivating case, it's hard to tell if it's important for this to work without errors. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(get(result.errors[0])); - } - else - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length") @@ -1615,7 +1592,7 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ("~false & ~nil", toString(requireTypeAtPosition({4, 30}))); + CHECK_EQ("~(false?)", toString(requireTypeAtPosition({4, 30}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 4b24fb2..82a20bc 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1059,11 +1059,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "unification_of_unions_in_a_self_referential_ const MetatableType* amtv = get(requireType("a")); REQUIRE(amtv); - CHECK_EQ(amtv->metatable, requireType("amt")); + CHECK_EQ(follow(amtv->metatable), follow(requireType("amt"))); const MetatableType* bmtv = get(requireType("b")); REQUIRE(bmtv); - CHECK_EQ(bmtv->metatable, requireType("bmt")); + CHECK_EQ(follow(bmtv->metatable), follow(requireType("bmt"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index cbb04cb..829f993 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -267,10 +267,7 @@ TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowi end )"); - if (FFlag::DebugLuauDeferredConstraintResolution) - LUAU_REQUIRE_ERROR_COUNT(1, result); - else - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 4411916..afe0552 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1060,4 +1060,14 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_param_overflow") +{ + CheckResult result = check(R"( + type Two = { a: T, b: U } + local x: Two = { a = 1, b = 'c' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 960d6f1..100abfb 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -354,10 +354,7 @@ a.x = 2 LUAU_REQUIRE_ERROR_COUNT(1, result); auto s = toString(result.errors[0]); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("Value of type '{| x: number, y: number |}?' could be nil", s); - else - CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", s); + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", s); } TEST_CASE_FIXTURE(Fixture, "optional_length_error") @@ -870,4 +867,50 @@ TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant") CHECK(expectedError == toString(result.errors[0])); } +TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Map.entries") +{ + + fileResolver.source["Module/Map"] = R"( +--!strict + +type Object = { [any]: any } +type Array = { [number]: T } +type Table = { [T]: V } +type Tuple = Array + +local Map = {} + +export type Map = { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + [K]: V, + _map: { [K]: V }, + _array: { [number]: K }, +} + +function Map:entries() + return {} +end + +local function coerceToTable(mapLike: Map | Table): Array> + local e = mapLike:entries(); + return e +end + + )"; + + CheckResult result = frontend.check("Module/Map"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp deleted file mode 100644 index 5f11a71..0000000 --- a/tests/TypeReduction.test.cpp +++ /dev/null @@ -1,1509 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeReduction.h" - -#include "Fixture.h" -#include "doctest.h" - -using namespace Luau; - -namespace -{ -struct ReductionFixture : Fixture -{ - TypeReductionOptions typeReductionOpts{/* allowTypeReductionsFromOtherArenas */ true}; - ToStringOptions toStringOpts{true}; - - TypeArena arena; - InternalErrorReporter iceHandler; - UnifierSharedState unifierState{&iceHandler}; - TypeReduction reduction{NotNull{&arena}, builtinTypes, NotNull{&iceHandler}, typeReductionOpts}; - - ReductionFixture() - { - registerHiddenTypes(&frontend); - createSomeClasses(&frontend); - } - - TypeId reductionof(TypeId ty) - { - std::optional reducedTy = reduction.reduce(ty); - REQUIRE(reducedTy); - return *reducedTy; - } - - std::optional tryReduce(const std::string& annotation) - { - check("type _Res = " + annotation); - return reduction.reduce(requireTypeAlias("_Res")); - } - - TypeId reductionof(const std::string& annotation) - { - check("type _Res = " + annotation); - return reductionof(requireTypeAlias("_Res")); - } - - std::string toStringFull(TypeId ty) - { - return toString(ty, toStringOpts); - } -}; -} // namespace - -TEST_SUITE_BEGIN("TypeReductionTests"); - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded") -{ - ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; - - CheckResult result = check(R"( - type T - = string - & (number | string | boolean) - & (number | string | boolean) - )"); - - CHECK(!reduction.reduce(requireTypeAlias("T"))); - // LUAU_REQUIRE_ERROR_COUNT(1, result); - // CHECK("Code is too complex to typecheck! Consider simplifying the code around this area" == toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_exceeded_with_normal_limit") -{ - CheckResult result = check(R"( - type T - = string -- 1 = 1 - & (number | string | boolean) -- 1 * 3 = 3 - & (number | string | boolean) -- 3 * 3 = 9 - & (number | string | boolean) -- 9 * 3 = 27 - & (number | string | boolean) -- 27 * 3 = 81 - & (number | string | boolean) -- 81 * 3 = 243 - & (number | string | boolean) -- 243 * 3 = 729 - & (number | string | boolean) -- 729 * 3 = 2187 - & (number | string | boolean) -- 2187 * 3 = 6561 - & (number | string | boolean) -- 6561 * 3 = 19683 - & (number | string | boolean) -- 19683 * 3 = 59049 - & (number | string) -- 59049 * 2 = 118098 - )"); - - CHECK(!reduction.reduce(requireTypeAlias("T"))); - // LUAU_REQUIRE_ERROR_COUNT(1, result); - // CHECK("Code is too complex to typecheck! Consider simplifying the code around this area" == toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(ReductionFixture, "cartesian_product_is_zero") -{ - ScopedFastInt sfi{"LuauTypeReductionCartesianProductLimit", 5}; - - CheckResult result = check(R"( - type T - = string - & (number | string | boolean) - & (number | string | boolean) - & never - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(ReductionFixture, "stress_test_recursion_limits") -{ - TypeId ty = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); - for (size_t i = 0; i < 20'000; ++i) - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {ty}; - ty = arena.addType(IntersectionType{{arena.addType(table), arena.addType(table)}}); - } - - CHECK(!reduction.reduce(ty)); -} - -TEST_CASE_FIXTURE(ReductionFixture, "caching") -{ - SUBCASE("free_tables") - { - TypeId ty1 = arena.addType(TableType{}); - getMutable(ty1)->state = TableState::Free; - getMutable(ty1)->props["x"] = {builtinTypes->stringType}; - - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("{- x: string -} & {| |}" == toStringFull(reductionof(intersectionTy))); - - getMutable(ty1)->state = TableState::Sealed; - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("unsealed_tables") - { - TypeId ty1 = arena.addType(TableType{}); - getMutable(ty1)->state = TableState::Unsealed; - getMutable(ty1)->props["x"] = {builtinTypes->stringType}; - - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - - getMutable(ty1)->state = TableState::Sealed; - CHECK("{| x: string |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("free_types") - { - TypeId ty1 = arena.freshType(nullptr); - TypeId ty2 = arena.addType(TableType{}); - getMutable(ty2)->state = TableState::Sealed; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("a & {| |}" == toStringFull(reductionof(intersectionTy))); - - *asMutable(ty1) = BoundType{ty2}; - CHECK("{| |}" == toStringFull(reductionof(intersectionTy))); - } - - SUBCASE("we_can_see_that_the_cache_works_if_we_mutate_a_normally_not_mutated_type") - { - TypeId ty1 = arena.addType(BoundType{builtinTypes->stringType}); - TypeId ty2 = builtinTypes->numberType; - - TypeId intersectionTy = arena.addType(IntersectionType{{ty1, ty2}}); - - CHECK("never" == toStringFull(reductionof(intersectionTy))); // Bound & number ~ never - - *asMutable(ty1) = BoundType{ty2}; - CHECK("never" == toStringFull(reductionof(intersectionTy))); // Bound & number ~ number, but the cache is `never`. - } - - SUBCASE("ptr_eq_irreducible_unions") - { - TypeId unionTy = arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->numberType}}); - TypeId reducedTy = reductionof(unionTy); - REQUIRE(unionTy == reducedTy); - } - - SUBCASE("ptr_eq_irreducible_intersections") - { - TypeId intersectionTy = arena.addType(IntersectionType{{builtinTypes->stringType, arena.addType(GenericType{"G"})}}); - TypeId reducedTy = reductionof(intersectionTy); - REQUIRE(intersectionTy == reducedTy); - } - - SUBCASE("ptr_eq_free_table") - { - TypeId tableTy = arena.addType(TableType{}); - getMutable(tableTy)->state = TableState::Free; - - TypeId reducedTy = reductionof(tableTy); - REQUIRE(tableTy == reducedTy); - } - - SUBCASE("ptr_eq_unsealed_table") - { - TypeId tableTy = arena.addType(TableType{}); - getMutable(tableTy)->state = TableState::Unsealed; - - TypeId reducedTy = reductionof(tableTy); - REQUIRE(tableTy == reducedTy); - } -} // caching - -TEST_CASE_FIXTURE(ReductionFixture, "intersections_without_negations") -{ - SUBCASE("string_and_string") - { - TypeId ty = reductionof("string & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("never_and_string") - { - TypeId ty = reductionof("never & string"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_never") - { - TypeId ty = reductionof("string & never"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("unknown_and_string") - { - TypeId ty = reductionof("unknown & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_unknown") - { - TypeId ty = reductionof("string & unknown"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("any_and_string") - { - TypeId ty = reductionof("any & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_any") - { - TypeId ty = reductionof("string & any"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_string") - { - TypeId ty = reductionof("(string | number) & string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_string_or_number") - { - TypeId ty = reductionof("string & (string | number)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_and_a") - { - TypeId ty = reductionof(R"(string & "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("boolean_and_true") - { - TypeId ty = reductionof("boolean & true"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("boolean_and_a") - { - TypeId ty = reductionof(R"(boolean & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_a") - { - TypeId ty = reductionof(R"("a" & "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("a_and_b") - { - TypeId ty = reductionof(R"("a" & "b")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_true") - { - TypeId ty = reductionof(R"("a" & true)"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("a_and_true") - { - TypeId ty = reductionof(R"(true & false)"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_type_and_function") - { - TypeId ty = reductionof("() -> () & fun"); - CHECK("() -> ()" == toStringFull(ty)); - } - - SUBCASE("function_type_and_string") - { - TypeId ty = reductionof("() -> () & string"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("parent_and_child") - { - TypeId ty = reductionof("Parent & Child"); - CHECK("Child" == toStringFull(ty)); - } - - SUBCASE("child_and_parent") - { - TypeId ty = reductionof("Child & Parent"); - CHECK("Child" == toStringFull(ty)); - } - - SUBCASE("child_and_unrelated") - { - TypeId ty = reductionof("Child & Unrelated"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_table") - { - TypeId ty = reductionof("string & {}"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_child") - { - TypeId ty = reductionof("string & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_function") - { - TypeId ty = reductionof("string & () -> ()"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_table") - { - TypeId ty = reductionof("() -> () & {}"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_class") - { - TypeId ty = reductionof("() -> () & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_and_function") - { - TypeId ty = reductionof("() -> () & () -> ()"); - CHECK("(() -> ()) & (() -> ())" == toStringFull(ty)); - } - - SUBCASE("table_and_table") - { - TypeId ty = reductionof("{} & {}"); - CHECK("{| |}" == toStringFull(ty)); - } - - SUBCASE("table_and_metatable") - { - // No setmetatable in ReductionFixture, so we mix and match. - BuiltinsFixture fixture; - fixture.check(R"( - type Ty = {} & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } } & {| |}" == toStringFull(ty)); - } - - SUBCASE("a_and_string") - { - TypeId ty = reductionof(R"("a" & string)"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("reducible_function_and_function") - { - TypeId ty = reductionof("((string | string) -> (number | number)) & fun"); - CHECK("(string) -> number" == toStringFull(ty)); - } - - SUBCASE("string_and_error") - { - TypeId ty = reductionof("string & err"); - CHECK("*error-type* & string" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_number") - { - TypeId ty = reductionof("{ p: string } & { p: number }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_string") - { - TypeId ty = reductionof("{ p: string } & { p: string }"); - CHECK("{| p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_x_table_p_string_and_table_x_table_p_number") - { - TypeId ty = reductionof("{ x: { p: string } } & { x: { p: number } }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_p_and_table_q") - { - TypeId ty = reductionof("{ p: string } & { q: number }"); - CHECK("{| p: string, q: number |}" == toStringFull(ty)); - } - - SUBCASE("table_tag_a_or_table_tag_b_and_table_b") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { b: string }"); - CHECK("{| a: number, b: string, tag: string |} | {| b: string, tag: number |}" == toStringFull(ty)); - } - - SUBCASE("table_string_number_indexer_and_table_string_number_indexer") - { - TypeId ty = reductionof("{ [string]: number } & { [string]: number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("table_string_number_indexer_and_empty_table") - { - TypeId ty = reductionof("{ [string]: number } & {}"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("empty_table_table_string_number_indexer") - { - TypeId ty = reductionof("{} & { [string]: number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("string_number_indexer_and_number_number_indexer") - { - TypeId ty = reductionof("{ [string]: number } & { [number]: number }"); - CHECK("{number} & {| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_indexer_number_number") - { - TypeId ty = reductionof("{ p: string } & { [number]: number }"); - CHECK("{| [number]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_indexer_string_number") - { - TypeId ty = reductionof("{ p: string } & { [string]: number }"); - CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_string_plus_indexer_string_number") - { - TypeId ty = reductionof("{ p: string } & { p: string, [string]: number }"); - CHECK("{| [string]: number, p: string |}" == toStringFull(ty)); - } - - SUBCASE("array_number_and_array_string") - { - TypeId ty = reductionof("{number} & {string}"); - CHECK("{never}" == toStringFull(ty)); - } - - SUBCASE("array_string_and_array_string") - { - TypeId ty = reductionof("{string} & {string}"); - CHECK("{string}" == toStringFull(ty)); - } - - SUBCASE("array_string_or_number_and_array_string") - { - TypeId ty = reductionof("{string | number} & {string}"); - CHECK("{string}" == toStringFull(ty)); - } - - SUBCASE("fresh_type_and_string") - { - TypeId freshTy = arena.freshType(nullptr); - TypeId ty = reductionof(arena.addType(IntersectionType{{freshTy, builtinTypes->stringType}})); - CHECK("a & string" == toStringFull(ty)); - } - - SUBCASE("string_and_fresh_type") - { - TypeId freshTy = arena.freshType(nullptr); - TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, freshTy}})); - CHECK("a & string" == toStringFull(ty)); - } - - SUBCASE("generic_and_string") - { - TypeId genericTy = arena.addType(GenericType{"G"}); - TypeId ty = reductionof(arena.addType(IntersectionType{{genericTy, builtinTypes->stringType}})); - CHECK("G & string" == toStringFull(ty)); - } - - SUBCASE("string_and_generic") - { - TypeId genericTy = arena.addType(GenericType{"G"}); - TypeId ty = reductionof(arena.addType(IntersectionType{{builtinTypes->stringType, genericTy}})); - CHECK("G & string" == toStringFull(ty)); - } - - SUBCASE("parent_and_child_or_parent_and_anotherchild_or_parent_and_unrelated") - { - TypeId ty = reductionof("Parent & (Child | AnotherChild | Unrelated)"); - CHECK("AnotherChild | Child" == toString(ty)); - } - - SUBCASE("parent_and_child_or_parent_and_anotherchild_or_parent_and_unrelated_2") - { - TypeId ty = reductionof("(Parent & Child) | (Parent & AnotherChild) | (Parent & Unrelated)"); - CHECK("AnotherChild | Child" == toString(ty)); - } - - SUBCASE("top_table_and_table") - { - TypeId ty = reductionof("tbl & {}"); - CHECK("{| |}" == toString(ty)); - } - - SUBCASE("top_table_and_non_table") - { - TypeId ty = reductionof("tbl & \"foo\""); - CHECK("never" == toString(ty)); - } - - SUBCASE("top_table_and_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = tbl & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } }" == toString(ty)); - } -} // intersections_without_negations - -TEST_CASE_FIXTURE(ReductionFixture, "intersections_with_negations") -{ - SUBCASE("nil_and_not_nil") - { - TypeId ty = reductionof("nil & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("nil_and_not_false") - { - TypeId ty = reductionof("nil & Not"); - CHECK("nil" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_nil") - { - TypeId ty = reductionof("(string?) & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_false_or_nil") - { - TypeId ty = reductionof("(string?) & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_nil_and_not_false_and_not_nil") - { - TypeId ty = reductionof("(string?) & Not & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("not_false_and_bool") - { - TypeId ty = reductionof("Not & boolean"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("function_type_and_not_function") - { - TypeId ty = reductionof("() -> () & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("function_type_and_not_string") - { - TypeId ty = reductionof("() -> () & Not"); - CHECK("() -> ()" == toStringFull(ty)); - } - - SUBCASE("not_a_and_string_or_nil") - { - TypeId ty = reductionof(R"(Not<"a"> & (string | nil))"); - CHECK(R"((string & ~"a")?)" == toStringFull(ty)); - } - - SUBCASE("not_a_and_a") - { - TypeId ty = reductionof(R"(Not<"a"> & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_a_and_b") - { - TypeId ty = reductionof(R"(Not<"a"> & "b")"); - CHECK(R"("b")" == toStringFull(ty)); - } - - SUBCASE("not_string_and_a") - { - TypeId ty = reductionof(R"(Not & "a")"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_bool_and_true") - { - TypeId ty = reductionof("Not & true"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_string_and_true") - { - TypeId ty = reductionof("Not & true"); - CHECK("true" == toStringFull(ty)); - } - - SUBCASE("parent_and_not_child") - { - TypeId ty = reductionof("Parent & Not"); - CHECK("Parent & ~Child" == toStringFull(ty)); - } - - SUBCASE("not_child_and_parent") - { - TypeId ty = reductionof("Not & Parent"); - CHECK("Parent & ~Child" == toStringFull(ty)); - } - - SUBCASE("child_and_not_parent") - { - TypeId ty = reductionof("Child & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_parent_and_child") - { - TypeId ty = reductionof("Not & Child"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_parent_and_unrelated") - { - TypeId ty = reductionof("Not & Unrelated"); - CHECK("Unrelated" == toStringFull(ty)); - } - - SUBCASE("unrelated_and_not_parent") - { - TypeId ty = reductionof("Unrelated & Not"); - CHECK("Unrelated" == toStringFull(ty)); - } - - SUBCASE("not_unrelated_and_parent") - { - TypeId ty = reductionof("Not & Parent"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("parent_and_not_unrelated") - { - TypeId ty = reductionof("Parent & Not"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("reducible_function_and_not_function") - { - TypeId ty = reductionof("((string | string) -> (number | number)) & Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("string_and_not_error") - { - TypeId ty = reductionof("string & Not"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_not_number") - { - TypeId ty = reductionof("{ p: string } & { p: Not }"); - CHECK("{| p: string |}" == toStringFull(ty)); - } - - SUBCASE("table_p_string_and_table_p_not_string") - { - TypeId ty = reductionof("{ p: string } & { p: Not }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("table_x_table_p_string_and_table_x_table_p_not_number") - { - TypeId ty = reductionof("{ x: { p: string } } & { x: { p: Not } }"); - CHECK("{| x: {| p: string |} |}" == toStringFull(ty)); - } - - SUBCASE("table_or_nil_and_truthy") - { - TypeId ty = reductionof("({ x: number | string }?) & Not"); - CHECK("{| x: number | string |}" == toString(ty)); - } - - SUBCASE("not_top_table_and_table") - { - TypeId ty = reductionof("Not & {}"); - CHECK("never" == toString(ty)); - } - - SUBCASE("not_top_table_and_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = Not & typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("never" == toString(ty)); - } -} // intersections_with_negations - -TEST_CASE_FIXTURE(ReductionFixture, "unions_without_negations") -{ - SUBCASE("never_or_string") - { - TypeId ty = reductionof("never | string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_never") - { - TypeId ty = reductionof("string | never"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("unknown_or_string") - { - TypeId ty = reductionof("unknown | string"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("string_or_unknown") - { - TypeId ty = reductionof("string | unknown"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("any_or_string") - { - TypeId ty = reductionof("any | string"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("string_or_any") - { - TypeId ty = reductionof("string | any"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("string_or_string_and_number") - { - TypeId ty = reductionof("string | (string & number)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_string") - { - TypeId ty = reductionof("string | string"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("string_or_number") - { - TypeId ty = reductionof("string | number"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("number_or_string") - { - TypeId ty = reductionof("number | string"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_or_string") - { - TypeId ty = reductionof("(string | number) | string"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_number_or_string_2") - { - TypeId ty = reductionof("string | (number | string)"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_number") - { - TypeId ty = reductionof("string | (string | number)"); - CHECK("number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_number_or_boolean") - { - TypeId ty = reductionof("string | (string | number | boolean)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_string_or_boolean_or_number") - { - TypeId ty = reductionof("string | (string | boolean | number)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("string_or_boolean_or_string_or_number") - { - TypeId ty = reductionof("string | (boolean | string | number)"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("boolean_or_string_or_number_or_string") - { - TypeId ty = reductionof("(boolean | string | number) | string"); - CHECK("boolean | number | string" == toStringFull(ty)); - } - - SUBCASE("boolean_or_true") - { - TypeId ty = reductionof("boolean | true"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("boolean_or_false") - { - TypeId ty = reductionof("boolean | false"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("boolean_or_true_or_false") - { - TypeId ty = reductionof("boolean | true | false"); - CHECK("boolean" == toStringFull(ty)); - } - - SUBCASE("string_or_a") - { - TypeId ty = reductionof(R"(string | "a")"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("a_or_a") - { - TypeId ty = reductionof(R"("a" | "a")"); - CHECK(R"("a")" == toStringFull(ty)); - } - - SUBCASE("a_or_b") - { - TypeId ty = reductionof(R"("a" | "b")"); - CHECK(R"("a" | "b")" == toStringFull(ty)); - } - - SUBCASE("a_or_b_or_string") - { - TypeId ty = reductionof(R"("a" | "b" | string)"); - CHECK("string" == toStringFull(ty)); - } - - SUBCASE("unknown_or_any") - { - TypeId ty = reductionof("unknown | any"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("any_or_unknown") - { - TypeId ty = reductionof("any | unknown"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("function_type_or_function") - { - TypeId ty = reductionof("() -> () | fun"); - CHECK("function" == toStringFull(ty)); - } - - SUBCASE("function_or_string") - { - TypeId ty = reductionof("fun | string"); - CHECK("function | string" == toStringFull(ty)); - } - - SUBCASE("parent_or_child") - { - TypeId ty = reductionof("Parent | Child"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("child_or_parent") - { - TypeId ty = reductionof("Child | Parent"); - CHECK("Parent" == toStringFull(ty)); - } - - SUBCASE("parent_or_unrelated") - { - TypeId ty = reductionof("Parent | Unrelated"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_child_or_unrelated") - { - TypeId ty = reductionof("Parent | Child | Unrelated"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_unrelated_or_child") - { - TypeId ty = reductionof("Parent | Unrelated | Child"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("parent_or_child_or_unrelated_or_child") - { - TypeId ty = reductionof("Parent | Child | Unrelated | Child"); - CHECK("Parent | Unrelated" == toStringFull(ty)); - } - - SUBCASE("string_or_true") - { - TypeId ty = reductionof("string | true"); - CHECK("string | true" == toStringFull(ty)); - } - - SUBCASE("string_or_function") - { - TypeId ty = reductionof("string | () -> ()"); - CHECK("(() -> ()) | string" == toStringFull(ty)); - } - - SUBCASE("string_or_err") - { - TypeId ty = reductionof("string | err"); - CHECK("*error-type* | string" == toStringFull(ty)); - } - - SUBCASE("top_table_or_table") - { - TypeId ty = reductionof("tbl | {}"); - CHECK("table" == toString(ty)); - } - - SUBCASE("top_table_or_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = tbl | typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("table" == toString(ty)); - } - - SUBCASE("top_table_or_non_table") - { - TypeId ty = reductionof("tbl | number"); - CHECK("number | table" == toString(ty)); - } -} // unions_without_negations - -TEST_CASE_FIXTURE(ReductionFixture, "unions_with_negations") -{ - SUBCASE("string_or_not_string") - { - TypeId ty = reductionof("string | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_string_or_string") - { - TypeId ty = reductionof("Not | string"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_number_or_string") - { - TypeId ty = reductionof("Not | string"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("string_or_not_number") - { - TypeId ty = reductionof("string | Not"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("not_hi_or_string_and_not_hi") - { - TypeId ty = reductionof(R"(Not<"hi"> | (string & Not<"hi">))"); - CHECK(R"(~"hi")" == toStringFull(ty)); - } - - SUBCASE("string_and_not_hi_or_not_hi") - { - TypeId ty = reductionof(R"((string & Not<"hi">) | Not<"hi">)"); - CHECK(R"(~"hi")" == toStringFull(ty)); - } - - SUBCASE("string_or_not_never") - { - TypeId ty = reductionof("string | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_a") - { - TypeId ty = reductionof(R"(Not<"a"> | Not<"a">)"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("not_a_or_a") - { - TypeId ty = reductionof(R"(Not<"a"> | "a")"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("a_or_not_a") - { - TypeId ty = reductionof(R"("a" | Not<"a">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_string") - { - TypeId ty = reductionof(R"(Not<"a"> | string)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("string_or_not_a") - { - TypeId ty = reductionof(R"(string | Not<"a">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_string_or_a") - { - TypeId ty = reductionof(R"(Not | "a")"); - CHECK(R"("a" | ~string)" == toStringFull(ty)); - } - - SUBCASE("a_or_not_string") - { - TypeId ty = reductionof(R"("a" | Not)"); - CHECK(R"("a" | ~string)" == toStringFull(ty)); - } - - SUBCASE("not_number_or_a") - { - TypeId ty = reductionof(R"(Not | "a")"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("a_or_not_number") - { - TypeId ty = reductionof(R"("a" | Not)"); - CHECK("~number" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_b") - { - TypeId ty = reductionof(R"(Not<"a"> | Not<"b">)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("boolean_or_not_false") - { - TypeId ty = reductionof("boolean | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("boolean_or_not_true") - { - TypeId ty = reductionof("boolean | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("false_or_not_false") - { - TypeId ty = reductionof("false | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("true_or_not_false") - { - TypeId ty = reductionof("true | Not"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("not_boolean_or_true") - { - TypeId ty = reductionof("Not | true"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("not_false_or_not_boolean") - { - TypeId ty = reductionof("Not | Not"); - CHECK("~false" == toStringFull(ty)); - } - - SUBCASE("function_type_or_not_function") - { - TypeId ty = reductionof("() -> () | Not"); - CHECK("(() -> ()) | ~function" == toStringFull(ty)); - } - - SUBCASE("not_parent_or_child") - { - TypeId ty = reductionof("Not | Child"); - CHECK("Child | ~Parent" == toStringFull(ty)); - } - - SUBCASE("child_or_not_parent") - { - TypeId ty = reductionof("Child | Not"); - CHECK("Child | ~Parent" == toStringFull(ty)); - } - - SUBCASE("parent_or_not_child") - { - TypeId ty = reductionof("Parent | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_child_or_parent") - { - TypeId ty = reductionof("Not | Parent"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("parent_or_not_unrelated") - { - TypeId ty = reductionof("Parent | Not"); - CHECK("~Unrelated" == toStringFull(ty)); - } - - SUBCASE("not_string_or_string_and_not_a") - { - TypeId ty = reductionof(R"(Not | (string & Not<"a">))"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("not_string_or_not_string") - { - TypeId ty = reductionof("Not | Not"); - CHECK("~string" == toStringFull(ty)); - } - - SUBCASE("not_string_or_not_number") - { - TypeId ty = reductionof("Not | Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_not_boolean") - { - TypeId ty = reductionof(R"(Not<"a"> | Not)"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_a_or_boolean") - { - TypeId ty = reductionof(R"(Not<"a"> | boolean)"); - CHECK(R"(~"a")" == toStringFull(ty)); - } - - SUBCASE("string_or_err") - { - TypeId ty = reductionof("string | Not"); - CHECK("string | ~*error-type*" == toStringFull(ty)); - } - - SUBCASE("not_top_table_or_table") - { - TypeId ty = reductionof("Not | {}"); - CHECK("{| |} | ~table" == toString(ty)); - } - - SUBCASE("not_top_table_or_metatable") - { - BuiltinsFixture fixture; - registerHiddenTypes(&fixture.frontend); - fixture.check(R"( - type Ty = Not | typeof(setmetatable({}, {})) - )"); - - TypeId ty = reductionof(fixture.requireTypeAlias("Ty")); - CHECK("{ @metatable { }, { } } | ~table" == toString(ty)); - } -} // unions_with_negations - -TEST_CASE_FIXTURE(ReductionFixture, "tables") -{ - SUBCASE("reduce_props") - { - TypeId ty = reductionof("{ x: string | string, y: number | number }"); - CHECK("{| x: string, y: number |}" == toStringFull(ty)); - } - - SUBCASE("reduce_indexers") - { - TypeId ty = reductionof("{ [string | string]: number | number }"); - CHECK("{| [string]: number |}" == toStringFull(ty)); - } - - SUBCASE("reduce_instantiated_type_parameters") - { - check(R"( - type Foo = { x: T } - local foo: Foo = { x = "hello" } - )"); - - TypeId ty = reductionof(requireType("foo")); - CHECK("Foo" == toString(ty)); - } - - SUBCASE("reduce_instantiated_type_pack_parameters") - { - check(R"( - type Foo = { x: () -> T... } - local foo: Foo = { x = function() return "hi", 5 end } - )"); - - TypeId ty = reductionof(requireType("foo")); - CHECK("Foo" == toString(ty)); - } - - SUBCASE("reduce_tables_within_tables") - { - TypeId ty = reductionof("{ x: { y: string & number } }"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("array_of_never") - { - TypeId ty = reductionof("{never}"); - CHECK("{never}" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "metatables") -{ - SUBCASE("reduce_table_part") - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; - TypeId tableTy = arena.addType(std::move(table)); - - TypeId ty = reductionof(arena.addType(MetatableType{tableTy, arena.addType(TableType{})})); - CHECK("{ @metatable { }, {| x: string |} }" == toStringFull(ty)); - } - - SUBCASE("reduce_metatable_part") - { - TableType table; - table.state = TableState::Sealed; - table.props["x"] = {arena.addType(UnionType{{builtinTypes->stringType, builtinTypes->stringType}})}; - TypeId tableTy = arena.addType(std::move(table)); - - TypeId ty = reductionof(arena.addType(MetatableType{arena.addType(TableType{}), tableTy})); - CHECK("{ @metatable {| x: string |}, { } }" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "functions") -{ - SUBCASE("reduce_parameters") - { - TypeId ty = reductionof("(string | string) -> ()"); - CHECK("(string) -> ()" == toStringFull(ty)); - } - - SUBCASE("reduce_returns") - { - TypeId ty = reductionof("() -> (string | string)"); - CHECK("() -> string" == toStringFull(ty)); - } - - SUBCASE("reduce_parameters_and_returns") - { - TypeId ty = reductionof("(string | string) -> (number | number)"); - CHECK("(string) -> number" == toStringFull(ty)); - } - - SUBCASE("reduce_tail") - { - TypeId ty = reductionof("() -> ...(string | string)"); - CHECK("() -> (...string)" == toStringFull(ty)); - } - - SUBCASE("reduce_head_and_tail") - { - TypeId ty = reductionof("() -> (string | string, number | number, ...(boolean | boolean))"); - CHECK("() -> (string, number, ...boolean)" == toStringFull(ty)); - } - - SUBCASE("reduce_overloaded_functions") - { - TypeId ty = reductionof("((number | number) -> ()) & ((string | string) -> ())"); - CHECK("((number) -> ()) & ((string) -> ())" == toStringFull(ty)); - } -} // functions - -TEST_CASE_FIXTURE(ReductionFixture, "negations") -{ - SUBCASE("not_unknown") - { - TypeId ty = reductionof("Not"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_never") - { - TypeId ty = reductionof("Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_any") - { - TypeId ty = reductionof("Not"); - CHECK("any" == toStringFull(ty)); - } - - SUBCASE("not_not_reduction") - { - TypeId ty = reductionof("Not>"); - CHECK("never" == toStringFull(ty)); - } - - SUBCASE("not_string") - { - TypeId ty = reductionof("Not"); - CHECK("~string" == toStringFull(ty)); - } - - SUBCASE("not_string_or_number") - { - TypeId ty = reductionof("Not"); - CHECK("~number & ~string" == toStringFull(ty)); - } - - SUBCASE("not_string_and_number") - { - TypeId ty = reductionof("Not"); - CHECK("unknown" == toStringFull(ty)); - } - - SUBCASE("not_error") - { - TypeId ty = reductionof("Not"); - CHECK("~*error-type*" == toStringFull(ty)); - } -} // negations - -TEST_CASE_FIXTURE(ReductionFixture, "discriminable_unions") -{ - SUBCASE("cat_or_dog_and_dog") - { - TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: "dog" })"); - CHECK(R"({| dogfood: string, tag: "dog" |})" == toStringFull(ty)); - } - - SUBCASE("cat_or_dog_and_not_dog") - { - TypeId ty = reductionof(R"(({ tag: "cat", catfood: string } | { tag: "dog", dogfood: string }) & { tag: Not<"dog"> })"); - CHECK(R"({| catfood: string, tag: "cat" |})" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_number") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: string }"); - CHECK("{| a: number, tag: string |}" == toStringFull(ty)); - } - - SUBCASE("string_or_number_and_number") - { - TypeId ty = reductionof("({ tag: string, a: number } | { tag: number, b: string }) & { tag: number }"); - CHECK("{| b: string, tag: number |}" == toStringFull(ty)); - } - - SUBCASE("child_or_unrelated_and_parent") - { - TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Parent }"); - CHECK("{| tag: Child, x: number |}" == toStringFull(ty)); - } - - SUBCASE("child_or_unrelated_and_not_parent") - { - TypeId ty = reductionof("({ tag: Child, x: number } | { tag: Unrelated, y: string }) & { tag: Not }"); - CHECK("{| tag: Unrelated, y: string |}" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "cycles") -{ - SUBCASE("recursively_defined_function") - { - check("type F = (f: F) -> ()"); - - TypeId ty = reductionof(requireTypeAlias("F")); - CHECK("t1 where t1 = (t1) -> ()" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_function_and_function") - { - check("type F = (f: F & fun) -> ()"); - - TypeId ty = reductionof(requireTypeAlias("F")); - CHECK("t1 where t1 = (function & t1) -> ()" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table") - { - check("type T = { x: T }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table") - { - check("type T = { x: T & {} }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 & {| |} |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table_2") - { - check("type T = { x: T } & { x: number }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: number |} & {| x: t1 |}" == toStringFull(ty)); - } - - SUBCASE("recursively_defined_table_and_table_3") - { - check("type T = { x: T } & { x: T }"); - - TypeId ty = reductionof(requireTypeAlias("T")); - CHECK("t1 where t1 = {| x: t1 |} & {| x: t1 |}" == toStringFull(ty)); - } -} - -TEST_CASE_FIXTURE(ReductionFixture, "string_singletons") -{ - TypeId ty = reductionof("(string & Not<\"A\">)?"); - CHECK("(string & ~\"A\")?" == toStringFull(ty)); -} - -TEST_CASE_FIXTURE(ReductionFixture, "string_singletons_2") -{ - TypeId ty = reductionof("Not<\"A\"> & Not<\"B\"> & (string?)"); - CHECK("(string & ~\"A\" & ~\"B\")?" == toStringFull(ty)); -} - -TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index b751126..8cdd36e 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -2,7 +2,6 @@ #include "Luau/Scope.h" #include "Luau/Type.h" #include "Luau/TypeInfer.h" -#include "Luau/TypeReduction.h" #include "Luau/VisitType.h" #include "Fixture.h" diff --git a/tools/faillist.txt b/tools/faillist.txt index a26e5c9..fe3353a 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -1,13 +1,10 @@ -AnnotationTests.too_many_type_params AstQuery.last_argument_function_call_type -AutocompleteTest.autocomplete_response_perf1 BuiltinTests.aliased_string_format BuiltinTests.assert_removes_falsy_types BuiltinTests.assert_removes_falsy_types2 BuiltinTests.assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy BuiltinTests.bad_select_should_not_crash -BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.gmatch_definition BuiltinTests.math_max_checks_for_numbers BuiltinTests.select_slightly_out_of_range @@ -22,7 +19,6 @@ BuiltinTests.string_format_tostring_specifier_type_constraint BuiltinTests.string_format_use_correct_argument2 DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props -GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions @@ -35,6 +31,7 @@ GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_type_pack_parentheses GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument_2 +GenericsTests.infer_generic_function_function_argument_3 GenericsTests.infer_generic_function_function_argument_overloaded GenericsTests.infer_generic_lib_function_function_argument GenericsTests.instantiated_function_argument_names @@ -42,23 +39,24 @@ GenericsTests.no_stack_overflow_from_quantifying GenericsTests.self_recursive_instantiated_param IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect -isSubtype.any_is_unknown_union_error ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal -ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean -ProvisionalTests.generic_type_leak_to_module_interface_variadic +ProvisionalTests.expected_type_should_be_a_helpful_deduction_guide_for_function_calls ProvisionalTests.greedy_inference_with_shared_self_triggers_function_with_no_returns ProvisionalTests.luau-polyfill.Array.filter ProvisionalTests.setmetatable_constrains_free_type_into_free_table ProvisionalTests.specialization_binds_with_prototypes_too_early ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete -RefinementTest.type_guard_can_filter_for_intersection_of_tables +RefinementTest.discriminate_from_truthiness_of_x +RefinementTest.not_t_or_some_prop_of_t +RefinementTest.truthy_constraint_on_properties RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table +RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.checked_prop_too_early TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode @@ -71,9 +69,6 @@ TableTests.expected_indexer_value_type_extra TableTests.expected_indexer_value_type_extra_2 TableTests.explicitly_typed_table TableTests.explicitly_typed_table_with_indexer -TableTests.found_like_key_in_table_function_call -TableTests.found_like_key_in_table_property_access -TableTests.found_multiple_like_keys TableTests.fuzz_table_unify_instantiated_table TableTests.generic_table_instantiation_potential_regression TableTests.give_up_after_one_metatable_index_look_up @@ -92,7 +87,6 @@ TableTests.oop_polymorphic TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table -TableTests.result_is_always_any_if_lhs_is_any TableTests.result_is_bool_for_equality_operators_if_lhs_is_any TableTests.right_table_missing_key2 TableTests.shared_selfs @@ -101,7 +95,6 @@ TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic TableTests.table_simple_call TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors -TableTests.table_unification_4 TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon ToString.toStringDetailed2 @@ -122,7 +115,6 @@ TypeAliases.type_alias_local_mutation TypeAliases.type_alias_local_rename TypeAliases.type_alias_locations TypeAliases.type_alias_of_an_imported_recursive_generic_type -TypeFamilyTests.function_internal_families TypeInfer.check_type_infer_recursion_count TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError @@ -131,18 +123,14 @@ TypeInfer.follow_on_new_types_in_substitution TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.no_stack_overflow_from_isoptional -TypeInfer.no_stack_overflow_from_isoptional2 -TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.for_in_loop_iterator_is_any2 TypeInferClasses.class_type_mismatch_with_name_conflict -TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.index_instance_property -TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties -TypeInferClasses.warn_when_prop_almost_matches TypeInferFunctions.cannot_hoist_interior_defns_into_signature +TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite @@ -177,6 +165,8 @@ TypeInferOperators.CallOrOfFunctions TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops +TypeInferOperators.luau-polyfill.String.slice +TypeInferOperators.luau_polyfill_is_array TypeInferOperators.operator_eq_completely_incompatible TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs @@ -191,7 +181,6 @@ TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check TypePackTests.type_alias_backwards_compatible TypePackTests.type_alias_default_type_errors -TypePackTests.type_alias_type_packs_errors TypePackTests.unify_variadic_tails_in_arguments TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons @@ -202,6 +191,7 @@ TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_type_error_escapes TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere +UnionTypes.dont_allow_cyclic_unions_to_be_inferred UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.optional_union_follow diff --git a/tools/lvmexecute_split.py b/tools/lvmexecute_split.py deleted file mode 100644 index 6e64bcd..0000000 --- a/tools/lvmexecute_split.py +++ /dev/null @@ -1,112 +0,0 @@ -#!/usr/bin/python3 -# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - -# This code can be used to split lvmexecute.cpp VM switch into separate functions for use as native code generation fallbacks -import sys -import re - -input = sys.stdin.readlines() - -inst = "" - -header = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#pragma once - -#include - -struct lua_State; -struct Closure; -typedef uint32_t Instruction; -typedef struct lua_TValue TValue; -typedef TValue* StkId; - -""" - -source = """// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details -// This file was generated by 'tools/lvmexecute_split.py' script, do not modify it by hand -#include "Fallbacks.h" -#include "FallbacksProlog.h" - -""" - -function = "" -signature = "" - -includeInsts = ["LOP_NEWCLOSURE", "LOP_NAMECALL", "LOP_FORGPREP", "LOP_GETVARARGS", "LOP_DUPCLOSURE", "LOP_PREPVARARGS", "LOP_BREAK", "LOP_GETGLOBAL", "LOP_SETGLOBAL", "LOP_GETTABLEKS", "LOP_SETTABLEKS", "LOP_SETLIST"] - -state = 0 - -# parse with the state machine -for line in input: - # find the start of an instruction - if state == 0: - match = re.match("\s+VM_CASE\((LOP_[A-Z_0-9]+)\)", line) - - if match: - inst = match[1] - signature = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)" - function = signature + "\n" - function += "{\n" - function += " [[maybe_unused]] Closure* cl = clvalue(L->ci->func);\n" - state = 1 - - # first line of the instruction which is "{" - elif state == 1: - assert(line == " {\n") - state = 2 - - # find the end of an instruction - elif state == 2: - # remove jumps back into the native code - if line == "#if LUA_CUSTOM_EXECUTION\n": - state = 3 - continue - - if line[0] == ' ': - finalline = line[12:-1] + "\n" - else: - finalline = line - - finalline = finalline.replace("VM_NEXT();", "return pc;"); - finalline = finalline.replace("goto exit;", "return NULL;"); - finalline = finalline.replace("return;", "return NULL;"); - - function += finalline - match = re.match(" }", line) - - if match: - # break is not supported - if inst == "LOP_BREAK": - function = "const Instruction* execute_" + inst + "(lua_State* L, const Instruction* pc, StkId base, TValue* k)\n" - function += "{\n LUAU_ASSERT(!\"Unsupported deprecated opcode\");\n LUAU_UNREACHABLE();\n}\n" - # handle fallthrough - elif inst == "LOP_NAMECALL": - function = function[:-len(finalline)] - function += " return pc;\n}\n" - - if inst in includeInsts: - header += signature + ";\n" - source += function + "\n" - - state = 0 - - # skip LUA_CUSTOM_EXECUTION code blocks - elif state == 3: - if line == "#endif\n": - state = 4 - continue - - # skip extra line - elif state == 4: - state = 2 - -# make sure we found the ending -assert(state == 0) - -with open("Fallbacks.h", "w") as fp: - fp.writelines(header) - -with open("Fallbacks.cpp", "w") as fp: - fp.writelines(source)