From d2ab5df62b9e6efc20aaece3254bedd80bb437fa Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 24 Feb 2023 13:49:38 -0800 Subject: [PATCH] Sync to upstream/release/565 (#845) We've made a few small changes to reduce the amount of stack we use when typechecking nested method calls (eg `foo:bar():baz():quux()`). We've also fixed a small bytecode compiler issue that caused us to emit redundant jump instructions in code that conditionally uses `break` or `continue`. On the new solver, we've switched to a new, better way to handle augmentations to unsealed tables. We've also made some substantial improvements to type inference and error reporting on function calls. These things should both be on par with the old solver now. The main improvements to the native code generator have been elimination of some redundant type tag checks. Also, we are starting to inline particular fastcalls directly to IR. --------- Co-authored-by: Arseny Kapoulkine Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Constraint.h | 26 +- .../include/Luau/ConstraintGraphBuilder.h | 27 +- Analysis/include/Luau/ConstraintSolver.h | 17 +- Analysis/include/Luau/DcrLogger.h | 41 +- Analysis/include/Luau/Scope.h | 2 +- Analysis/include/Luau/Symbol.h | 5 + Analysis/include/Luau/Type.h | 23 + Analysis/include/Luau/TypeInfer.h | 4 +- Analysis/include/Luau/TypeReduction.h | 41 +- Analysis/include/Luau/Unifier.h | 6 + Analysis/src/ConstraintGraphBuilder.cpp | 163 ++-- Analysis/src/ConstraintSolver.cpp | 358 +++++++- Analysis/src/DcrLogger.cpp | 227 +++-- Analysis/src/Frontend.cpp | 4 +- Analysis/src/Normalize.cpp | 2 + Analysis/src/Quantify.cpp | 2 +- Analysis/src/Scope.cpp | 6 +- Analysis/src/ToString.cpp | 6 + Analysis/src/TypeChecker2.cpp | 248 ++++- Analysis/src/TypeInfer.cpp | 205 +++-- Analysis/src/TypeReduction.cpp | 362 ++++---- Analysis/src/Unifier.cpp | 7 +- CLI/Reduce.cpp | 2 +- CodeGen/include/Luau/IrBuilder.h | 3 + CodeGen/include/Luau/IrData.h | 95 +- CodeGen/include/Luau/IrUtils.h | 9 +- CodeGen/include/Luau/OptimizeConstProp.h | 16 + CodeGen/src/CodeGen.cpp | 496 +--------- CodeGen/src/EmitBuiltinsX64.cpp | 24 +- CodeGen/src/EmitBuiltinsX64.h | 13 +- CodeGen/src/EmitCommonX64.h | 8 - CodeGen/src/EmitInstructionX64.cpp | 867 +----------------- CodeGen/src/EmitInstructionX64.h | 61 +- CodeGen/src/IrAnalysis.cpp | 2 + CodeGen/src/IrBuilder.cpp | 63 +- CodeGen/src/IrDump.cpp | 55 +- CodeGen/src/IrLoweringX64.cpp | 373 ++++---- CodeGen/src/IrLoweringX64.h | 32 +- CodeGen/src/IrRegAllocX64.cpp | 181 ++++ CodeGen/src/IrRegAllocX64.h | 51 ++ CodeGen/src/IrTranslateBuiltins.cpp | 40 + CodeGen/src/IrTranslateBuiltins.h | 27 + CodeGen/src/IrTranslation.cpp | 81 +- CodeGen/src/IrTranslation.h | 3 + CodeGen/src/IrUtils.cpp | 70 +- CodeGen/src/OptimizeConstProp.cpp | 565 ++++++++++++ CodeGen/src/OptimizeFinalX64.cpp | 5 +- Common/include/Luau/Bytecode.h | 2 +- Compiler/src/Compiler.cpp | 22 +- Sources.cmake | 6 + tests/Compiler.test.cpp | 61 +- tests/ConstraintGraphBuilderFixture.cpp | 3 +- tests/IrBuilder.test.cpp | 689 +++++++++++++- tests/Module.test.cpp | 47 +- tests/NonstrictMode.test.cpp | 2 +- tests/ToString.test.cpp | 14 +- tests/TypeInfer.aliases.test.cpp | 12 +- tests/TypeInfer.functions.test.cpp | 15 +- tests/TypeInfer.refinements.test.cpp | 34 + tests/TypeInfer.tables.test.cpp | 25 +- tests/TypeInfer.tryUnify.test.cpp | 9 +- tests/TypeInfer.unknownnever.test.cpp | 2 +- tools/faillist.txt | 64 +- 63 files changed, 3501 insertions(+), 2430 deletions(-) create mode 100644 CodeGen/include/Luau/OptimizeConstProp.h create mode 100644 CodeGen/src/IrRegAllocX64.cpp create mode 100644 CodeGen/src/IrRegAllocX64.h create mode 100644 CodeGen/src/IrTranslateBuiltins.cpp create mode 100644 CodeGen/src/IrTranslateBuiltins.h create mode 100644 CodeGen/src/OptimizeConstProp.cpp diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 65599e4..1c41bbb 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -159,6 +159,20 @@ struct SetPropConstraint TypeId propType; }; +// result ~ setIndexer subjectType indexType propType +// +// If the subject is a table or table-like thing that already has an indexer, +// unify its indexType and propType with those from this constraint. +// +// If the table is a free or unsealed table, we augment it with a new indexer. +struct SetIndexerConstraint +{ + TypeId resultType; + TypeId subjectType; + TypeId indexType; + TypeId propType; +}; + // if negation: // result ~ if isSingleton D then ~D else unknown where D = discriminantType // if not negation: @@ -170,9 +184,19 @@ struct SingletonOrTopTypeConstraint bool negated; }; +// resultType ~ unpack sourceTypePack +// +// Similar to PackSubtypeConstraint, but with one important difference: If the +// sourcePack is blocked, this constraint blocks. +struct UnpackConstraint +{ + TypePackId resultPack; + TypePackId sourcePack; +}; + using ConstraintV = Variant; + HasPropConstraint, SetPropConstraint, SetIndexerConstraint, SingletonOrTopTypeConstraint, UnpackConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 085b673..7b2711f 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -191,7 +191,7 @@ struct ConstraintGraphBuilder Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); - TypePackId checkLValues(const ScopePtr& scope, AstArray exprs); + std::vector checkLValues(const ScopePtr& scope, AstArray exprs); TypeId checkLValue(const ScopePtr& scope, AstExpr* expr); @@ -244,10 +244,31 @@ struct ConstraintGraphBuilder **/ TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& list, bool inTypeArguments); + /** + * Creates generic types given a list of AST definitions, resolving default + * types as required. + * @param scope the scope that the generics should belong to. + * @param generics the AST generics to create types for. + * @param useCache whether to use the generic type cache for the given + * scope. + * @param addTypes whether to add the types to the scope's + * privateTypeBindings map. + **/ std::vector> createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache = false); + const ScopePtr& scope, AstArray generics, bool useCache = false, bool addTypes = true); + + /** + * Creates generic type packs given a list of AST definitions, resolving + * default type packs as required. + * @param scope the scope that the generic packs should belong to. + * @param generics the AST generics to create type packs for. + * @param useCache whether to use the generic type pack cache for the given + * scope. + * @param addTypes whether to add the types to the scope's + * privateTypePackBindings map. + **/ std::vector> createGenericPacks( - const ScopePtr& scope, AstArray packs, bool useCache = false); + const ScopePtr& scope, AstArray packs, bool useCache = false, bool addTypes = true); Inference flattenPack(const ScopePtr& scope, Location location, InferencePack pack); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index de7b3a0..62687ae 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -8,6 +8,7 @@ #include "Luau/Normalize.h" #include "Luau/ToString.h" #include "Luau/Type.h" +#include "Luau/TypeReduction.h" #include "Luau/Variant.h" #include @@ -19,7 +20,12 @@ struct DcrLogger; // TypeId, TypePackId, or Constraint*. It is impossible to know which, but we // never dereference this pointer. -using BlockedConstraintId = const void*; +using BlockedConstraintId = Variant; + +struct HashBlockedConstraintId +{ + size_t operator()(const BlockedConstraintId& bci) const; +}; struct ModuleResolver; @@ -47,6 +53,7 @@ struct ConstraintSolver NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; + NotNull reducer; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; NotNull rootScope; @@ -65,7 +72,7 @@ struct ConstraintSolver // anything. std::unordered_map, size_t> blockedConstraints; // A mapping of type/pack pointers to the constraints they block. - std::unordered_map>> blocked; + std::unordered_map>, HashBlockedConstraintId> blocked; // Memoized instantiations of type aliases. DenseHashMap instantiatedAliases{{}}; @@ -78,7 +85,8 @@ struct ConstraintSolver DcrLogger* logger; explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); + ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, + DcrLogger* logger); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); @@ -112,7 +120,9 @@ struct ConstraintSolver bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); bool tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force); bool tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint); + bool tryDispatch(const UnpackConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod @@ -123,6 +133,7 @@ struct ConstraintSolver TypeId nextTy, TypeId tableTy, TypeId firstIndexTy, const IterableConstraint& c, NotNull constraint, bool force); std::optional lookupTableProp(TypeId subjectType, const std::string& propName); + std::optional lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen); void block(NotNull target, NotNull constraint); /** diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h index 45c84c6..1e170d5 100644 --- a/Analysis/include/Luau/DcrLogger.h +++ b/Analysis/include/Luau/DcrLogger.h @@ -4,6 +4,7 @@ #include "Luau/Constraint.h" #include "Luau/NotNull.h" #include "Luau/Scope.h" +#include "Luau/Module.h" #include "Luau/ToString.h" #include "Luau/Error.h" #include "Luau/Variant.h" @@ -34,11 +35,26 @@ struct TypeBindingSnapshot std::string typeString; }; +struct ExprTypesAtLocation +{ + Location location; + TypeId ty; + std::optional expectedTy; +}; + +struct AnnotationTypesAtLocation +{ + Location location; + TypeId resolvedTy; +}; + struct ConstraintGenerationLog { std::string source; - std::unordered_map constraintLocations; std::vector errors; + + std::vector exprTypeLocations; + std::vector annotationTypeLocations; }; struct ScopeSnapshot @@ -49,16 +65,11 @@ struct ScopeSnapshot std::vector children; }; -enum class ConstraintBlockKind -{ - TypeId, - TypePackId, - ConstraintId, -}; +using ConstraintBlockTarget = Variant>; struct ConstraintBlock { - ConstraintBlockKind kind; + ConstraintBlockTarget target; std::string stringification; }; @@ -71,16 +82,18 @@ struct ConstraintSnapshot struct BoundarySnapshot { - std::unordered_map constraints; + DenseHashMap unsolvedConstraints{nullptr}; ScopeSnapshot rootScope; + DenseHashMap typeStrings{nullptr}; }; struct StepSnapshot { - std::string currentConstraint; + const Constraint* currentConstraint; bool forced; - std::unordered_map unsolvedConstraints; + DenseHashMap unsolvedConstraints{nullptr}; ScopeSnapshot rootScope; + DenseHashMap typeStrings{nullptr}; }; struct TypeSolveLog @@ -95,8 +108,6 @@ struct TypeCheckLog std::vector errors; }; -using ConstraintBlockTarget = Variant>; - struct DcrLogger { std::string compileOutput(); @@ -104,6 +115,7 @@ struct DcrLogger void captureSource(std::string source); void captureGenerationError(const TypeError& error); void captureConstraintLocation(NotNull constraint, Location location); + void captureGenerationModule(const ModulePtr& module); void pushBlock(NotNull constraint, TypeId block); void pushBlock(NotNull constraint, TypePackId block); @@ -126,9 +138,10 @@ private: TypeSolveLog solveLog; TypeCheckLog checkLog; - ToStringOptions opts; + ToStringOptions opts{true}; std::vector snapshotBlocks(NotNull constraint); + void captureBoundaryState(BoundarySnapshot& target, const Scope* rootScope, const std::vector>& unsolvedConstraints); }; } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index a8f83e2..85a36fc 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -52,7 +52,7 @@ struct Scope std::optional lookup(Symbol sym) const; std::optional lookup(DefId def) const; - std::optional> lookupEx(Symbol sym); + std::optional> lookupEx(Symbol sym); std::optional lookupType(const Name& name); std::optional lookupImportedType(const Name& moduleAlias, const Name& name); diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h index 0432946..b47554e 100644 --- a/Analysis/include/Luau/Symbol.h +++ b/Analysis/include/Luau/Symbol.h @@ -37,6 +37,11 @@ struct Symbol AstLocal* local; AstName global; + explicit operator bool() const + { + return local != nullptr || global.value != nullptr; + } + bool operator==(const Symbol& rhs) const { if (local) diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 00e6d6c..d009001 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -246,6 +246,18 @@ struct WithPredicate { T type; PredicateVec predicates; + + WithPredicate() = default; + explicit WithPredicate(T type) + : type(type) + { + } + + WithPredicate(T type, PredicateVec predicates) + : type(type) + , predicates(std::move(predicates)) + { + } }; using MagicFunction = std::function>( @@ -853,4 +865,15 @@ bool hasTag(TypeId ty, const std::string& tagName); bool hasTag(const Property& prop, const std::string& tagName); bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. +/* + * Use this to change the kind of a particular type. + * + * LUAU_NOINLINE so that the calling frame doesn't have to pay the stack storage for the new variant. + */ +template +LUAU_NOINLINE T* emplaceType(Type* ty, Args&&... args) +{ + return &ty->ty.emplace(std::forward(args)...); +} + } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index d748a1f..678bd41 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -146,10 +146,12 @@ struct TypeChecker WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExpr& expr); WithPredicate checkExprPackHelper(const ScopePtr& scope, const AstExprCall& expr); + WithPredicate checkExprPackHelper2( + const ScopePtr& scope, const AstExprCall& expr, TypeId selfType, TypeId actualFunctionType, TypeId functionType, TypePackId retPack); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); - std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + std::unique_ptr> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, diff --git a/Analysis/include/Luau/TypeReduction.h b/Analysis/include/Luau/TypeReduction.h index 0ad034a..80a7ac5 100644 --- a/Analysis/include/Luau/TypeReduction.h +++ b/Analysis/include/Luau/TypeReduction.h @@ -12,11 +12,36 @@ namespace Luau namespace detail { template -struct ReductionContext +struct ReductionEdge { T type = nullptr; bool irreducible = false; }; + +struct TypeReductionMemoization +{ + TypeReductionMemoization() = default; + + TypeReductionMemoization(const TypeReductionMemoization&) = delete; + TypeReductionMemoization& operator=(const TypeReductionMemoization&) = delete; + + TypeReductionMemoization(TypeReductionMemoization&&) = default; + TypeReductionMemoization& operator=(TypeReductionMemoization&&) = default; + + DenseHashMap> types{nullptr}; + DenseHashMap> typePacks{nullptr}; + + bool isIrreducible(TypeId ty); + bool isIrreducible(TypePackId tp); + + TypeId memoize(TypeId ty, TypeId reducedTy); + TypePackId memoize(TypePackId tp, TypePackId reducedTp); + + // Reducing A into B may have a non-irreducible edge A to B for which B is not irreducible, which means B could be reduced into C. + // Because reduction should always be transitive, A should point to C if A points to B and B points to C. + std::optional> memoizedof(TypeId ty) const; + std::optional> memoizedof(TypePackId tp) const; +}; } // namespace detail struct TypeReductionOptions @@ -42,29 +67,19 @@ struct TypeReduction std::optional reduce(TypePackId tp); std::optional reduce(const TypeFun& fun); - /// Creating a child TypeReduction will allow the parent TypeReduction to share its memoization with the child TypeReductions. - /// This is safe as long as the parent's TypeArena continues to outlive both TypeReduction memoization. - TypeReduction fork(NotNull arena, const TypeReductionOptions& opts = {}) const; - private: - const TypeReduction* parent = nullptr; - NotNull arena; NotNull builtinTypes; NotNull handle; - TypeReductionOptions options; - DenseHashMap> memoizedTypes{nullptr}; - DenseHashMap> memoizedTypePacks{nullptr}; + TypeReductionOptions options; + detail::TypeReductionMemoization memoization; // Computes an *estimated length* of the cartesian product of the given type. size_t cartesianProductSize(TypeId ty) const; bool hasExceededCartesianProductLimit(TypeId ty) const; bool hasExceededCartesianProductLimit(TypePackId tp) const; - - std::optional memoizedof(TypeId ty) const; - std::optional memoizedof(TypePackId tp) const; }; } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 988ad9c..ebfff4c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -67,6 +67,12 @@ struct Unifier UnifierSharedState& sharedState; + // When the Unifier is forced to unify two blocked types (or packs), they + // get added to these vectors. The ConstraintSolver can use this to know + // when it is safe to reattempt dispatching a constraint. + std::vector blockedTypes; + std::vector blockedTypePacks; + Unifier( NotNull normalizer, Mode mode, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index aa605bd..fe41263 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -320,6 +320,9 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) prepopulateGlobalScope(scope, block); visitBlockWithoutChildScope(scope, block); + + if (FFlag::DebugLuauLogSolverToJson) + logger->captureGenerationModule(module); } void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) @@ -357,13 +360,11 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, for (const auto& [name, gen] : createGenerics(defnScope, alias->generics, /* useCache */ true)) { initialFun.typeParams.push_back(gen); - defnScope->privateTypeBindings[name] = TypeFun{gen.ty}; } for (const auto& [name, genPack] : createGenericPacks(defnScope, alias->genericPacks, /* useCache */ true)) { initialFun.typePackParams.push_back(genPack); - defnScope->privateTypePackBindings[name] = genPack.tp; } if (alias->exported) @@ -503,13 +504,13 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (j - i < packTypes.head.size()) varTypes[j] = packTypes.head[j - i]; else - varTypes[j] = freshType(scope); + varTypes[j] = arena->addType(BlockedType{}); } } std::vector tailValues{varTypes.begin() + i, varTypes.end()}; TypePackId tailPack = arena->addTypePack(std::move(tailValues)); - addConstraint(scope, local->location, PackSubtypeConstraint{exprPack, tailPack}); + addConstraint(scope, local->location, UnpackConstraint{tailPack, exprPack}); } } } @@ -686,6 +687,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct Checkpoint start = checkpoint(this); FunctionSignature sig = checkFunctionSignature(scope, function->func); + std::unordered_set excludeList; + if (AstExprLocal* localName = function->name->as()) { std::optional existingFunctionTy = scope->lookup(localName->local); @@ -716,9 +719,20 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprIndexName* indexName = function->name->as()) { + Checkpoint check1 = checkpoint(this); TypeId lvalueType = checkLValue(scope, indexName); + Checkpoint check2 = checkpoint(this); + + forEachConstraint(check1, check2, this, [&excludeList](const ConstraintPtr& c) { + excludeList.insert(c.get()); + }); + // TODO figure out how to populate the location field of the table Property. - addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); + + if (get(lvalueType)) + asMutable(lvalueType)->ty.emplace(generalizedType); + else + addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); } else if (AstExprError* err = function->name->as()) { @@ -735,8 +749,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{generalizedType, sig.signature}); - forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { - c->dependencies.push_back(NotNull{constraint.get()}); + forEachConstraint(start, end, this, [&c, &excludeList](const ConstraintPtr& constraint) { + if (!excludeList.count(constraint.get())) + c->dependencies.push_back(NotNull{constraint.get()}); }); addConstraint(scope, std::move(c)); @@ -763,16 +778,31 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) visitBlockWithoutChildScope(innerScope, block); } +static void bindFreeType(TypeId a, TypeId b) +{ + FreeType* af = getMutable(a); + FreeType* bf = getMutable(b); + + LUAU_ASSERT(af || bf); + + if (!bf) + asMutable(a)->ty.emplace(b); + else if (!af) + asMutable(b)->ty.emplace(a); + else if (subsumes(bf->scope, af->scope)) + asMutable(a)->ty.emplace(b); + else if (subsumes(af->scope, bf->scope)) + asMutable(b)->ty.emplace(a); +} + void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { - TypePackId varPackId = checkLValues(scope, assign->vars); - - TypePack expectedPack = extendTypePack(*arena, builtinTypes, varPackId, assign->values.size); + std::vector varTypes = checkLValues(scope, assign->vars); std::vector> expectedTypes; - expectedTypes.reserve(expectedPack.head.size()); + expectedTypes.reserve(varTypes.size()); - for (TypeId ty : expectedPack.head) + for (TypeId ty : varTypes) { ty = follow(ty); if (get(ty)) @@ -781,9 +811,10 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) expectedTypes.push_back(ty); } - TypePackId valuePack = checkPack(scope, assign->values, expectedTypes).tp; + TypePackId exprPack = checkPack(scope, assign->values, expectedTypes).tp; + TypePackId varPack = arena->addTypePack({varTypes}); - addConstraint(scope, assign->location, PackSubtypeConstraint{valuePack, varPackId}); + addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack}); } void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) @@ -865,11 +896,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia asMutable(aliasTy)->ty.emplace(ty); std::vector typeParams; - for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true)) + for (auto tyParam : createGenerics(*defnScope, alias->generics, /* useCache */ true, /* addTypes */ false)) typeParams.push_back(tyParam.second.ty); std::vector typePackParams; - for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true)) + for (auto tpParam : createGenericPacks(*defnScope, alias->genericPacks, /* useCache */ true, /* addTypes */ false)) typePackParams.push_back(tpParam.second.tp); addConstraint(scope, alias->type->location, @@ -1010,7 +1041,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction for (auto& [name, generic] : generics) { genericTys.push_back(generic.ty); - scope->privateTypeBindings[name] = TypeFun{generic.ty}; } std::vector genericTps; @@ -1018,7 +1048,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction for (auto& [name, generic] : genericPacks) { genericTps.push_back(generic.tp); - scope->privateTypePackBindings[name] = generic.tp; } ScopePtr funScope = scope; @@ -1161,7 +1190,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa TypePackId expectedArgPack = arena->freshTypePack(scope.get()); TypePackId expectedRetPack = arena->freshTypePack(scope.get()); - TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack}); + TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack, std::nullopt, call->self}); TypeId instantiatedFnType = arena->addType(BlockedType{}); addConstraint(scope, call->location, InstantiationConstraint{instantiatedFnType, fnType}); @@ -1264,7 +1293,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa // TODO: How do expectedTypes play into this? Do they? TypePackId rets = arena->addTypePack(BlockedTypePack{}); TypePackId argPack = arena->addTypePack(TypePack{args, argTail}); - FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets); + FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); NotNull fcc = addConstraint(scope, call->func->location, FunctionCallConstraint{ @@ -1457,7 +1486,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr).ty; - TypeId result = freshType(scope); + TypeId result = arena->addType(BlockedType{}); std::optional def = dfg->getDef(indexName); if (def) @@ -1468,13 +1497,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* scope->dcrRefinements[*def] = result; } - TableType::Props props{{indexName->index.value, Property{result}}}; - const std::optional indexer; - TableType ttv{std::move(props), indexer, TypeLevel{}, scope.get(), TableState::Free}; - - TypeId expectedTableType = arena->addType(std::move(ttv)); - - addConstraint(scope, indexName->expr->location, SubtypeConstraint{obj, expectedTableType}); + addConstraint(scope, indexName->expr->location, HasPropConstraint{result, obj, indexName->index.value}); if (def) return Inference{result, refinementArena.proposition(*def, builtinTypes->truthyType)}; @@ -1589,6 +1612,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( else if (typeguard->type == "number") discriminantTy = builtinTypes->numberType; else if (typeguard->type == "boolean") + discriminantTy = builtinTypes->booleanType; + else if (typeguard->type == "thread") discriminantTy = builtinTypes->threadType; else if (typeguard->type == "table") discriminantTy = builtinTypes->tableType; @@ -1596,8 +1621,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( discriminantTy = builtinTypes->functionType; else if (typeguard->type == "userdata") { - // For now, we don't really care about being accurate with userdata if the typeguard was using typeof - discriminantTy = builtinTypes->neverType; // TODO: replace with top class type + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. + discriminantTy = builtinTypes->classType; } else if (!typeguard->isTypeof && typeguard->type == "vector") discriminantTy = builtinTypes->neverType; // TODO: figure out a way to deal with this quirky type @@ -1649,18 +1674,15 @@ std::tuple ConstraintGraphBuilder::checkBinary( } } -TypePackId ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) +std::vector ConstraintGraphBuilder::checkLValues(const ScopePtr& scope, AstArray exprs) { std::vector types; types.reserve(exprs.size); - for (size_t i = 0; i < exprs.size; ++i) - { - AstExpr* const expr = exprs.data[i]; + for (AstExpr* expr : exprs) types.push_back(checkLValue(scope, expr)); - } - return arena->addTypePack(std::move(types)); + return types; } /** @@ -1679,6 +1701,28 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) indexExpr->location, indexExpr->expr, syntheticIndex, constantString->location, indexExpr->expr->location.end, '.'}; return checkLValue(scope, &synthetic); } + + // An indexer is only interesting in an lvalue-ey way if it is at the + // tail of an expression. + // + // If the indexer is not at the tail, then we are not interested in + // augmenting the lhs data structure with a new indexer. Constraint + // generation can treat it as an ordinary lvalue. + // + // eg + // + // a.b.c[1] = 44 -- lvalue + // a.b[4].c = 2 -- rvalue + + TypeId resultType = arena->addType(BlockedType{}); + TypeId subjectType = check(scope, indexExpr->expr).ty; + TypeId indexType = check(scope, indexExpr->index).ty; + TypeId propType = arena->addType(BlockedType{}); + addConstraint(scope, expr->location, SetIndexerConstraint{resultType, subjectType, indexType, propType}); + + module->astTypes[expr] = propType; + + return propType; } else if (!expr->is()) return check(scope, expr).ty; @@ -1718,7 +1762,8 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) auto lookupResult = scope->lookupEx(sym); if (!lookupResult) return check(scope, expr).ty; - const auto [subjectType, symbolScope] = std::move(*lookupResult); + const auto [subjectBinding, symbolScope] = std::move(*lookupResult); + TypeId subjectType = subjectBinding->typeId; TypeId propTy = freshType(scope); @@ -1739,14 +1784,17 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) module->astTypes[expr] = prevSegmentTy; module->astTypes[e] = updatedType; - symbolScope->bindings[sym].typeId = updatedType; - - std::optional def = dfg->getDef(sym); - if (def) + if (!subjectType->persistent) { - // This can fail if the user is erroneously trying to augment a builtin - // table like os or string. - symbolScope->dcrRefinements[*def] = updatedType; + symbolScope->bindings[sym].typeId = updatedType; + + std::optional def = dfg->getDef(sym); + if (def) + { + // This can fail if the user is erroneously trying to augment a builtin + // table like os or string. + symbolScope->dcrRefinements[*def] = updatedType; + } } return propTy; @@ -1904,13 +1952,11 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); - signatureScope->privateTypeBindings[name] = TypeFun{g.ty}; } for (const auto& [name, g] : genericPackDefinitions) { genericTypePacks.push_back(g.tp); - signatureScope->privateTypePackBindings[name] = g.tp; } // Local variable works around an odd gcc 11.3 warning: may be used uninitialized @@ -2023,15 +2069,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS actualFunction.generics = std::move(genericTypes); actualFunction.genericPacks = std::move(genericTypePacks); actualFunction.argNames = std::move(argNames); + actualFunction.hasSelf = fn->self != nullptr; TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); module->astTypes[fn] = actualFunctionType; if (expectedType && get(*expectedType)) - { - asMutable(*expectedType)->ty.emplace(actualFunctionType); - } + bindFreeType(*expectedType, actualFunctionType); return { /* signature */ actualFunctionType, @@ -2179,13 +2224,11 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b for (const auto& [name, g] : genericDefinitions) { genericTypes.push_back(g.ty); - signatureScope->privateTypeBindings[name] = TypeFun{g.ty}; } for (const auto& [name, g] : genericPackDefinitions) { genericTypePacks.push_back(g.tp); - signatureScope->privateTypePackBindings[name] = g.tp; } } else @@ -2330,7 +2373,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(const ScopePtr& scope, const } std::vector> ConstraintGraphBuilder::createGenerics( - const ScopePtr& scope, AstArray generics, bool useCache) + const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) { std::vector> result; for (const auto& generic : generics) @@ -2350,6 +2393,9 @@ std::vector> ConstraintGraphBuilder::crea if (generic.defaultValue) defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); + if (addTypes) + scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; + result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); } @@ -2357,7 +2403,7 @@ std::vector> ConstraintGraphBuilder::crea } std::vector> ConstraintGraphBuilder::createGenericPacks( - const ScopePtr& scope, AstArray generics, bool useCache) + const ScopePtr& scope, AstArray generics, bool useCache, bool addTypes) { std::vector> result; for (const auto& generic : generics) @@ -2378,6 +2424,9 @@ std::vector> ConstraintGraphBuilder:: if (generic.defaultValue) defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); + if (addTypes) + scope->privateTypePackBindings[generic.name.value] = genericTy; + result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); } @@ -2394,11 +2443,9 @@ Inference ConstraintGraphBuilder::flattenPack(const ScopePtr& scope, Location lo if (auto f = first(tp)) return Inference{*f, refinement}; - TypeId typeResult = freshType(scope); - TypePack onePack{{typeResult}, freshTypePack(scope)}; - TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - - addConstraint(scope, location, PackSubtypeConstraint{tp, oneTypePack}); + TypeId typeResult = arena->addType(BlockedType{}); + TypePackId resultPack = arena->addTypePack({typeResult}, arena->freshTypePack(scope.get())); + addConstraint(scope, location, UnpackConstraint{resultPack, tp}); return Inference{typeResult, refinement}; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 879dac3..96673e3 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -22,6 +22,22 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { +size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const +{ + size_t result = 0; + + if (const TypeId* ty = get_if(&bci)) + result = std::hash()(*ty); + else if (const TypePackId* tp = get_if(&bci)) + result = std::hash()(*tp); + else if (Constraint const* const* c = get_if(&bci)) + result = std::hash()(*c); + else + LUAU_ASSERT(!"Should be unreachable"); + + return result; +} + [[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) { for (const auto& [k, v] : scope->bindings) @@ -221,10 +237,12 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, - ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) + ModuleName moduleName, NotNull reducer, NotNull moduleResolver, std::vector requireCycles, + DcrLogger* logger) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) + , reducer(reducer) , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) @@ -326,6 +344,27 @@ void ConstraintSolver::run() if (force) printf("Force "); printf("Dispatched\n\t%s\n", saveMe.c_str()); + + if (force) + { + printf("Blocked on:\n"); + + for (const auto& [bci, cv] : blocked) + { + if (end(cv) == std::find(begin(cv), end(cv), c)) + continue; + + if (auto bty = get_if(&bci)) + printf("\tType %s\n", toString(*bty, opts).c_str()); + else if (auto btp = get_if(&bci)) + printf("\tPack %s\n", toString(*btp, opts).c_str()); + else if (auto cc = get_if(&bci)) + printf("\tCons %s\n", toString(**cc, opts).c_str()); + else + LUAU_ASSERT(!"Unreachable??"); + } + } + dump(this, opts); } } @@ -411,8 +450,12 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*hpc, constraint); else if (auto spc = get(*constraint)) success = tryDispatch(*spc, constraint, force); + else if (auto spc = get(*constraint)) + success = tryDispatch(*spc, constraint, force); else if (auto sottc = get(*constraint)) success = tryDispatch(*sottc, constraint); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint); else LUAU_ASSERT(false); @@ -424,26 +467,46 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) { - if (!recursiveBlock(c.subType, constraint)) - return false; - if (!recursiveBlock(c.superType, constraint)) - return false; - if (isBlocked(c.subType)) return block(c.subType, constraint); else if (isBlocked(c.superType)) return block(c.superType, constraint); - unify(c.subType, c.superType, constraint->scope); + Unifier u{normalizer, Mode::Strict, constraint->scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(c.subType, c.superType); + + if (!u.blockedTypes.empty() || !u.blockedTypePacks.empty()) + { + for (TypeId bt : u.blockedTypes) + block(bt, constraint); + for (TypePackId btp : u.blockedTypePacks) + block(btp, constraint); + return false; + } + + if (!u.errors.empty()) + { + TypeId errorType = errorRecoveryType(); + u.tryUnify(c.subType, errorType); + u.tryUnify(c.superType, errorType); + } + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); + + // unify(c.subType, c.superType, constraint->scope); return true; } bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { - if (!recursiveBlock(c.subPack, constraint) || !recursiveBlock(c.superPack, constraint)) - return false; - if (isBlocked(c.subPack)) return block(c.subPack, constraint); else if (isBlocked(c.superPack)) @@ -1183,8 +1246,26 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(BlockedType{}); TypeId inferredTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope.get(), argsPack, c.result}); - auto ic = pushConstraint(constraint->scope, constraint->location, InstantiationConstraint{instantiatedTy, fn}); - auto sc = pushConstraint(constraint->scope, constraint->location, SubtypeConstraint{instantiatedTy, inferredTy}); + auto pushConstraintGreedy = [this, constraint](ConstraintV cv) -> Constraint* { + std::unique_ptr c = std::make_unique(constraint->scope, constraint->location, std::move(cv)); + NotNull borrow{c.get()}; + + bool ok = tryDispatch(borrow, false); + if (ok) + return nullptr; + + solverConstraints.push_back(std::move(c)); + unsolvedConstraints.push_back(borrow); + + return borrow; + }; + + // HACK: We don't want other constraints to act on the free type pack + // created above until after these two constraints are solved, so we try to + // dispatch them directly. + + auto ic = pushConstraintGreedy(InstantiationConstraint{instantiatedTy, fn}); + auto sc = pushConstraintGreedy(SubtypeConstraint{instantiatedTy, inferredTy}); // Anything that is blocked on this constraint must also be blocked on our // synthesized constraints. @@ -1193,8 +1274,10 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) { - block(ic, blockedConstraint); - block(sc, blockedConstraint); + if (ic) + block(NotNull{ic}, blockedConstraint); + if (sc) + block(NotNull{sc}, blockedConstraint); } } @@ -1230,6 +1313,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullreduce(subjectType).value_or(subjectType); + std::optional resultType = lookupTableProp(subjectType, c.prop); if (!resultType) { @@ -1360,11 +1445,18 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNullscope); + if (!isBlocked(c.propType)) + unify(c.propType, *existingPropType, constraint->scope); bind(c.resultType, c.subjectType); return true; } + if (get(subjectType) || get(subjectType) || get(subjectType)) + { + bind(c.resultType, subjectType); + return true; + } + if (get(subjectType)) { TypeId ty = arena->freshType(constraint->scope); @@ -1381,21 +1473,27 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) { if (ttv->state == TableState::Free) { + LUAU_ASSERT(!subjectType->persistent); + ttv->props[c.path[0]] = Property{c.propType}; bind(c.resultType, c.subjectType); return true; } else if (ttv->state == TableState::Unsealed) { + LUAU_ASSERT(!subjectType->persistent); + std::optional augmented = updateTheTableType(NotNull{arena}, subjectType, c.path, c.propType); bind(c.resultType, augmented.value_or(subjectType)); + bind(subjectType, c.resultType); return true; } else @@ -1411,16 +1509,62 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType) || get(subjectType) || get(subjectType)) - { - bind(c.resultType, subjectType); - return true; - } LUAU_ASSERT(0); return true; } +bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) +{ + TypeId subjectType = follow(c.subjectType); + if (isBlocked(subjectType)) + return block(subjectType, constraint); + + if (auto ft = get(subjectType)) + { + Scope* scope = ft->scope; + TableType* tt = &asMutable(subjectType)->ty.emplace(TableState::Free, TypeLevel{}, scope); + tt->indexer = TableIndexer{c.indexType, c.propType}; + + asMutable(c.resultType)->ty.emplace(subjectType); + asMutable(c.propType)->ty.emplace(scope); + unblock(c.propType); + unblock(c.resultType); + + return true; + } + else if (auto tt = get(subjectType)) + { + if (tt->indexer) + { + // TODO This probably has to be invariant. + unify(c.indexType, tt->indexer->indexType, constraint->scope); + asMutable(c.propType)->ty.emplace(tt->indexer->indexResultType); + asMutable(c.resultType)->ty.emplace(subjectType); + unblock(c.propType); + unblock(c.resultType); + return true; + } + else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) + { + auto mtt = getMutable(subjectType); + mtt->indexer = TableIndexer{c.indexType, c.propType}; + asMutable(c.propType)->ty.emplace(tt->scope); + asMutable(c.resultType)->ty.emplace(subjectType); + unblock(c.propType); + unblock(c.resultType); + return true; + } + // Do not augment sealed or generic tables that lack indexers + } + + asMutable(c.propType)->ty.emplace(builtinTypes->errorRecoveryType()); + asMutable(c.resultType)->ty.emplace(builtinTypes->errorRecoveryType()); + unblock(c.propType); + unblock(c.resultType); + return true; +} + bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNull constraint) { if (isBlocked(c.discriminantType)) @@ -1439,6 +1583,69 @@ bool ConstraintSolver::tryDispatch(const SingletonOrTopTypeConstraint& c, NotNul return true; } +bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNull constraint) +{ + TypePackId sourcePack = follow(c.sourcePack); + TypePackId resultPack = follow(c.resultPack); + + if (isBlocked(sourcePack)) + return block(sourcePack, constraint); + + if (isBlocked(resultPack)) + { + asMutable(resultPack)->ty.emplace(sourcePack); + unblock(resultPack); + return true; + } + + TypePack srcPack = extendTypePack(*arena, builtinTypes, sourcePack, size(resultPack)); + + auto destIter = begin(resultPack); + auto destEnd = end(resultPack); + + size_t i = 0; + while (destIter != destEnd) + { + if (i >= srcPack.head.size()) + break; + TypeId srcTy = follow(srcPack.head[i]); + + if (isBlocked(*destIter)) + { + if (follow(srcTy) == *destIter) + { + // Cyclic type dependency. (????) + asMutable(*destIter)->ty.emplace(constraint->scope); + } + else + asMutable(*destIter)->ty.emplace(srcTy); + unblock(*destIter); + } + else + unify(*destIter, srcTy, constraint->scope); + + ++destIter; + ++i; + } + + // We know that resultPack does not have a tail, but we don't know if + // sourcePack is long enough to fill every value. Replace every remaining + // result TypeId with the error recovery type. + + while (destIter != destEnd) + { + if (isBlocked(*destIter)) + { + asMutable(*destIter)->ty.emplace(builtinTypes->errorRecoveryType()); + unblock(*destIter); + } + + ++destIter; + } + + return true; +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { auto block_ = [&](auto&& t) { @@ -1628,10 +1835,20 @@ bool ConstraintSolver::tryDispatchIterableFunction( std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName) { + std::unordered_set seen; + return lookupTableProp(subjectType, propName, seen); +} + +std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, const std::string& propName, std::unordered_set& seen) +{ + if (!seen.insert(subjectType).second) + return std::nullopt; + auto collectParts = [&](auto&& unionOrIntersection) -> std::pair, std::vector> { std::optional blocked; std::vector parts; + std::vector freeParts; for (TypeId expectedPart : unionOrIntersection) { expectedPart = follow(expectedPart); @@ -1644,6 +1861,29 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons else if (ttv->indexer && maybeString(ttv->indexer->indexType)) parts.push_back(ttv->indexer->indexResultType); } + else if (get(expectedPart)) + { + freeParts.push_back(expectedPart); + } + } + + // If the only thing resembling a match is a single fresh type, we can + // confidently tablify it. If other types match or if there are more + // than one free type, we can't do anything. + if (parts.empty() && 1 == freeParts.size()) + { + TypeId freePart = freeParts.front(); + const FreeType* ft = get(freePart); + LUAU_ASSERT(ft); + Scope* scope = ft->scope; + + TableType* tt = &asMutable(freePart)->ty.emplace(); + tt->state = TableState::Free; + tt->scope = scope; + TypeId propType = arena->freshType(scope); + tt->props[propName] = Property{propType}; + + parts.push_back(propType); } return {blocked, parts}; @@ -1651,12 +1891,75 @@ std::optional ConstraintSolver::lookupTableProp(TypeId subjectType, cons std::optional resultType; - if (auto ttv = get(subjectType)) + if (get(subjectType) || get(subjectType)) + { + return subjectType; + } + else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) resultType = prop->second.type; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) resultType = ttv->indexer->indexResultType; + else if (ttv->state == TableState::Free) + { + resultType = arena->addType(FreeType{ttv->scope}); + ttv->props[propName] = Property{*resultType}; + } + } + else if (auto mt = get(subjectType)) + { + if (auto p = lookupTableProp(mt->table, propName, seen)) + return p; + + TypeId mtt = follow(mt->metatable); + + if (get(mtt)) + return mtt; + else if (auto metatable = get(mtt)) + { + auto indexProp = metatable->props.find("__index"); + if (indexProp == metatable->props.end()) + return std::nullopt; + + // TODO: __index can be an overloaded function. + + TypeId indexType = follow(indexProp->second.type); + + if (auto ft = get(indexType)) + { + std::optional ret = first(ft->retTypes); + if (ret) + return *ret; + else + return std::nullopt; + } + + return lookupTableProp(indexType, propName, seen); + } + } + else if (auto ct = get(subjectType)) + { + while (ct) + { + if (auto prop = ct->props.find(propName); prop != ct->props.end()) + return prop->second.type; + else if (ct->parent) + ct = get(follow(*ct->parent)); + else + break; + } + } + else if (auto pt = get(subjectType); pt && pt->metatable) + { + const TableType* metatable = get(follow(*pt->metatable)); + LUAU_ASSERT(metatable); + + auto indexProp = metatable->props.find("__index"); + if (indexProp == metatable->props.end()) + return std::nullopt; + + return lookupTableProp(indexProp->second.type, propName, seen); } else if (auto utv = get(subjectType)) { @@ -1704,7 +2007,7 @@ void ConstraintSolver::block(NotNull target, NotNull constraint) @@ -1715,7 +2018,7 @@ bool ConstraintSolver::block(TypeId target, NotNull constraint if (FFlag::DebugLuauLogSolver) printf("block TypeId %s on\t%s\n", toString(target, opts).c_str(), toString(*constraint, opts).c_str()); - block_(target, constraint); + block_(follow(target), constraint); return false; } @@ -1802,7 +2105,7 @@ void ConstraintSolver::unblock(NotNull progressed) if (FFlag::DebugLuauLogSolverToJson) logger->popBlock(progressed); - return unblock_(progressed); + return unblock_(progressed.get()); } void ConstraintSolver::unblock(TypeId progressed) @@ -1810,7 +2113,10 @@ void ConstraintSolver::unblock(TypeId progressed) if (FFlag::DebugLuauLogSolverToJson) logger->popBlock(progressed); - return unblock_(progressed); + unblock_(progressed); + + if (auto bt = get(progressed)) + unblock(bt->boundTo); } void ConstraintSolver::unblock(TypePackId progressed) diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp index a1ef650..9f66b02 100644 --- a/Analysis/src/DcrLogger.cpp +++ b/Analysis/src/DcrLogger.cpp @@ -9,17 +9,39 @@ namespace Luau { +template +static std::string toPointerId(const T* ptr) +{ + return std::to_string(reinterpret_cast(ptr)); +} + +static std::string toPointerId(NotNull ptr) +{ + return std::to_string(reinterpret_cast(ptr.get())); +} + namespace Json { +template +void write(JsonEmitter& emitter, const T* ptr) +{ + write(emitter, toPointerId(ptr)); +} + +void write(JsonEmitter& emitter, NotNull ptr) +{ + write(emitter, toPointerId(ptr)); +} + void write(JsonEmitter& emitter, const Location& location) { - ObjectEmitter o = emitter.writeObject(); - o.writePair("beginLine", location.begin.line); - o.writePair("beginColumn", location.begin.column); - o.writePair("endLine", location.end.line); - o.writePair("endColumn", location.end.column); - o.finish(); + ArrayEmitter a = emitter.writeArray(); + a.writeValue(location.begin.line); + a.writeValue(location.begin.column); + a.writeValue(location.end.line); + a.writeValue(location.end.column); + a.finish(); } void write(JsonEmitter& emitter, const ErrorSnapshot& snapshot) @@ -47,24 +69,43 @@ void write(JsonEmitter& emitter, const TypeBindingSnapshot& snapshot) o.finish(); } +template +void write(JsonEmitter& emitter, const DenseHashMap& map) +{ + ObjectEmitter o = emitter.writeObject(); + for (const auto& [k, v] : map) + o.writePair(toPointerId(k), v); + o.finish(); +} + +void write(JsonEmitter& emitter, const ExprTypesAtLocation& tys) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("location", tys.location); + o.writePair("ty", toPointerId(tys.ty)); + + if (tys.expectedTy) + o.writePair("expectedTy", toPointerId(*tys.expectedTy)); + + o.finish(); +} + +void write(JsonEmitter& emitter, const AnnotationTypesAtLocation& tys) +{ + ObjectEmitter o = emitter.writeObject(); + o.writePair("location", tys.location); + o.writePair("resolvedTy", toPointerId(tys.resolvedTy)); + o.finish(); +} + void write(JsonEmitter& emitter, const ConstraintGenerationLog& log) { ObjectEmitter o = emitter.writeObject(); o.writePair("source", log.source); - - emitter.writeComma(); - write(emitter, "constraintLocations"); - emitter.writeRaw(":"); - - ObjectEmitter locationEmitter = emitter.writeObject(); - - for (const auto& [id, location] : log.constraintLocations) - { - locationEmitter.writePair(id, location); - } - - locationEmitter.finish(); o.writePair("errors", log.errors); + o.writePair("exprTypeLocations", log.exprTypeLocations); + o.writePair("annotationTypeLocations", log.annotationTypeLocations); + o.finish(); } @@ -78,26 +119,34 @@ void write(JsonEmitter& emitter, const ScopeSnapshot& snapshot) o.finish(); } -void write(JsonEmitter& emitter, const ConstraintBlockKind& kind) -{ - switch (kind) - { - case ConstraintBlockKind::TypeId: - return write(emitter, "type"); - case ConstraintBlockKind::TypePackId: - return write(emitter, "typePack"); - case ConstraintBlockKind::ConstraintId: - return write(emitter, "constraint"); - default: - LUAU_ASSERT(0); - } -} - void write(JsonEmitter& emitter, const ConstraintBlock& block) { ObjectEmitter o = emitter.writeObject(); - o.writePair("kind", block.kind); o.writePair("stringification", block.stringification); + + auto go = [&o](auto&& t) { + using T = std::decay_t; + + o.writePair("id", toPointerId(t)); + + if constexpr (std::is_same_v) + { + o.writePair("kind", "type"); + } + else if constexpr (std::is_same_v) + { + o.writePair("kind", "typePack"); + } + else if constexpr (std::is_same_v>) + { + o.writePair("kind", "constraint"); + } + else + static_assert(always_false_v, "non-exhaustive possibility switch"); + }; + + visit(go, block.target); + o.finish(); } @@ -114,7 +163,8 @@ void write(JsonEmitter& emitter, const BoundarySnapshot& snapshot) { ObjectEmitter o = emitter.writeObject(); o.writePair("rootScope", snapshot.rootScope); - o.writePair("constraints", snapshot.constraints); + o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints); + o.writePair("typeStrings", snapshot.typeStrings); o.finish(); } @@ -125,6 +175,7 @@ void write(JsonEmitter& emitter, const StepSnapshot& snapshot) o.writePair("forced", snapshot.forced); o.writePair("unsolvedConstraints", snapshot.unsolvedConstraints); o.writePair("rootScope", snapshot.rootScope); + o.writePair("typeStrings", snapshot.typeStrings); o.finish(); } @@ -146,11 +197,6 @@ void write(JsonEmitter& emitter, const TypeCheckLog& log) } // namespace Json -static std::string toPointerId(NotNull ptr) -{ - return std::to_string(reinterpret_cast(ptr.get())); -} - static ScopeSnapshot snapshotScope(const Scope* scope, ToStringOptions& opts) { std::unordered_map bindings; @@ -230,6 +276,32 @@ void DcrLogger::captureSource(std::string source) generationLog.source = std::move(source); } +void DcrLogger::captureGenerationModule(const ModulePtr& module) +{ + generationLog.exprTypeLocations.reserve(module->astTypes.size()); + for (const auto& [expr, ty] : module->astTypes) + { + ExprTypesAtLocation tys; + tys.location = expr->location; + tys.ty = ty; + + if (auto expectedTy = module->astExpectedTypes.find(expr)) + tys.expectedTy = *expectedTy; + + generationLog.exprTypeLocations.push_back(tys); + } + + generationLog.annotationTypeLocations.reserve(module->astResolvedTypes.size()); + for (const auto& [annot, ty] : module->astResolvedTypes) + { + AnnotationTypesAtLocation tys; + tys.location = annot->location; + tys.resolvedTy = ty; + + generationLog.annotationTypeLocations.push_back(tys); + } +} + void DcrLogger::captureGenerationError(const TypeError& error) { std::string stringifiedError = toString(error); @@ -239,12 +311,6 @@ void DcrLogger::captureGenerationError(const TypeError& error) }); } -void DcrLogger::captureConstraintLocation(NotNull constraint, Location location) -{ - std::string id = toPointerId(constraint); - generationLog.constraintLocations[id] = location; -} - void DcrLogger::pushBlock(NotNull constraint, TypeId block) { constraintBlocks[constraint].push_back(block); @@ -284,44 +350,70 @@ void DcrLogger::popBlock(NotNull block) } } -void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +static void snapshotTypeStrings(const std::vector& interestedExprs, + const std::vector& interestedAnnots, DenseHashMap& map, ToStringOptions& opts) { - solveLog.initialState.rootScope = snapshotScope(rootScope, opts); - solveLog.initialState.constraints.clear(); + for (const ExprTypesAtLocation& tys : interestedExprs) + { + map[tys.ty] = toString(tys.ty, opts); + + if (tys.expectedTy) + map[*tys.expectedTy] = toString(*tys.expectedTy, opts); + } + + for (const AnnotationTypesAtLocation& tys : interestedAnnots) + { + map[tys.resolvedTy] = toString(tys.resolvedTy, opts); + } +} + +void DcrLogger::captureBoundaryState( + BoundarySnapshot& target, const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + target.rootScope = snapshotScope(rootScope, opts); + target.unsolvedConstraints.clear(); for (NotNull c : unsolvedConstraints) { - std::string id = toPointerId(c); - solveLog.initialState.constraints[id] = { + target.unsolvedConstraints[c.get()] = { toString(*c.get(), opts), c->location, snapshotBlocks(c), }; } + + snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, target.typeStrings, opts); +} + +void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) +{ + captureBoundaryState(solveLog.initialState, rootScope, unsolvedConstraints); } StepSnapshot DcrLogger::prepareStepSnapshot( const Scope* rootScope, NotNull current, bool force, const std::vector>& unsolvedConstraints) { ScopeSnapshot scopeSnapshot = snapshotScope(rootScope, opts); - std::string currentId = toPointerId(current); - std::unordered_map constraints; + DenseHashMap constraints{nullptr}; for (NotNull c : unsolvedConstraints) { - std::string id = toPointerId(c); - constraints[id] = { + constraints[c.get()] = { toString(*c.get(), opts), c->location, snapshotBlocks(c), }; } + DenseHashMap typeStrings{nullptr}; + snapshotTypeStrings(generationLog.exprTypeLocations, generationLog.annotationTypeLocations, typeStrings, opts); + return StepSnapshot{ - currentId, + current, force, - constraints, + std::move(constraints), scopeSnapshot, + std::move(typeStrings), }; } @@ -332,18 +424,7 @@ void DcrLogger::commitStepSnapshot(StepSnapshot snapshot) void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vector>& unsolvedConstraints) { - solveLog.finalState.rootScope = snapshotScope(rootScope, opts); - solveLog.finalState.constraints.clear(); - - for (NotNull c : unsolvedConstraints) - { - std::string id = toPointerId(c); - solveLog.finalState.constraints[id] = { - toString(*c.get(), opts), - c->location, - snapshotBlocks(c), - }; - } + captureBoundaryState(solveLog.finalState, rootScope, unsolvedConstraints); } void DcrLogger::captureTypeCheckError(const TypeError& error) @@ -370,21 +451,21 @@ std::vector DcrLogger::snapshotBlocks(NotNull if (const TypeId* ty = get_if(&target)) { snapshot.push_back({ - ConstraintBlockKind::TypeId, + *ty, toString(*ty, opts), }); } else if (const TypePackId* tp = get_if(&target)) { snapshot.push_back({ - ConstraintBlockKind::TypePackId, + *tp, toString(*tp, opts), }); } else if (const NotNull* c = get_if>(&target)) { snapshot.push_back({ - ConstraintBlockKind::ConstraintId, + *c, toString(*(c->get()), opts), }); } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index fb61b4a..91c72e4 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -899,8 +899,8 @@ ModulePtr check( cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, moduleResolver, - requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, + NotNull{result->reduction.get()}, moduleResolver, requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index a7b2b72..0b76081 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -1441,6 +1441,8 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor if (!unionNormals(here, *tn)) return false; } + else if (get(there)) + LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); else LUAU_ASSERT(!"Unreachable"); diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index aac7864..845ae3a 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -183,7 +183,7 @@ struct PureQuantifier : Substitution else if (ttv->state == TableState::Generic) seenGenericType = true; - return ttv->state == TableState::Unsealed || (ttv->state == TableState::Free && subsumes(scope, ttv->scope)); + return (ttv->state == TableState::Unsealed || ttv->state == TableState::Free) && subsumes(scope, ttv->scope); } return false; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 84925f7..cac7212 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -31,12 +31,12 @@ std::optional Scope::lookup(Symbol sym) const { auto r = const_cast(this)->lookupEx(sym); if (r) - return r->first; + return r->first->typeId; else return std::nullopt; } -std::optional> Scope::lookupEx(Symbol sym) +std::optional> Scope::lookupEx(Symbol sym) { Scope* s = this; @@ -44,7 +44,7 @@ std::optional> Scope::lookupEx(Symbol sym) { auto it = s->bindings.find(sym); if (it != s->bindings.end()) - return std::pair{it->second.typeId, s}; + return std::pair{&it->second, s}; if (s->parent) s = s->parent.get(); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 1972177..d0c5398 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1533,6 +1533,10 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) const std::string pathStr = c.path.size() == 1 ? "\"" + c.path[0] + "\"" : "[\"" + join(c.path, "\", \"") + "\"]"; return tos(c.resultType) + " ~ setProp " + tos(c.subjectType) + ", " + pathStr + " " + tos(c.propType); } + else if constexpr (std::is_same_v) + { + return tos(c.resultType) + " ~ setIndexer " + tos(c.subjectType) + " [ " + tos(c.indexType) + " ] " + tos(c.propType); + } else if constexpr (std::is_same_v) { std::string result = tos(c.resultType); @@ -1543,6 +1547,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) else return result + " ~ if isSingleton D then D else unknown where D = " + discriminant; } + else if constexpr (std::is_same_v) + return tos(c.resultPack) + " ~ unpack " + tos(c.sourcePack); else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 4322a0d..f23fad7 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/DcrLogger.h" #include "Luau/Error.h" #include "Luau/Instantiation.h" @@ -329,11 +330,12 @@ struct TypeChecker2 for (size_t i = 0; i < count; ++i) { AstExpr* value = i < local->values.size ? local->values.data[i] : nullptr; + const bool isPack = value && (value->is() || value->is()); if (value) visit(value, RValue); - if (i != local->values.size - 1 || value) + if (i != local->values.size - 1 || !isPack) { AstLocal* var = i < local->vars.size ? local->vars.data[i] : nullptr; @@ -351,16 +353,19 @@ struct TypeChecker2 visit(var->annotation); } } - else + else if (value) { - LUAU_ASSERT(value); + TypePackId valuePack = lookupPack(value); + TypePack valueTypes; + if (i < local->vars.size) + valueTypes = extendTypePack(module->internalTypes, builtinTypes, valuePack, local->vars.size - i); - TypePackId valueTypes = lookupPack(value); - auto it = begin(valueTypes); + Location errorLocation; for (size_t j = i; j < local->vars.size; ++j) { - if (it == end(valueTypes)) + if (j - i >= valueTypes.head.size()) { + errorLocation = local->vars.data[j]->location; break; } @@ -368,14 +373,28 @@ struct TypeChecker2 if (var->annotation) { TypeId varType = lookupAnnotation(var->annotation); - ErrorVec errors = tryUnify(stack.back(), value->location, *it, varType); + ErrorVec errors = tryUnify(stack.back(), value->location, valueTypes.head[j - i], varType); if (!errors.empty()) reportErrors(std::move(errors)); visit(var->annotation); } + } - ++it; + if (valueTypes.head.size() < local->vars.size - i) + { + reportError( + CountMismatch{ + // We subtract 1 here because the final AST + // expression is not worth one value. It is worth 0 + // or more depending on valueTypes.head + local->values.size - 1 + valueTypes.head.size(), + std::nullopt, + local->vars.size, + local->values.data[local->values.size - 1]->is() ? CountMismatch::FunctionResult + : CountMismatch::ExprListResult, + }, + errorLocation); } } } @@ -810,6 +829,95 @@ struct TypeChecker2 // TODO! } + ErrorVec visitOverload(AstExprCall* call, NotNull overloadFunctionType, const std::vector& argLocs, + TypePackId expectedArgTypes, TypePackId expectedRetType) + { + ErrorVec overloadErrors = + tryUnify(stack.back(), call->location, overloadFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult); + + size_t argIndex = 0; + auto inferredArgIt = begin(overloadFunctionType->argTypes); + auto expectedArgIt = begin(expectedArgTypes); + while (inferredArgIt != end(overloadFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) + { + Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; + ErrorVec argErrors = tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt); + for (TypeError e : argErrors) + overloadErrors.emplace_back(e); + + ++argIndex; + ++inferredArgIt; + ++expectedArgIt; + } + + // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad + ErrorVec argumentErrors = tryUnify(stack.back(), call->location, expectedArgTypes, overloadFunctionType->argTypes); + for (TypeError e : argumentErrors) + if (get(e) != nullptr) + overloadErrors.emplace_back(std::move(e)); + + return overloadErrors; + } + + void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, + const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) + { + if (overloads.size() == 1) + { + reportErrors(std::get<0>(overloadsErrors.front())); + return; + } + + std::vector overloadTypes = overloadsThatMatchArgCount; + if (overloadsThatMatchArgCount.size() == 0) + { + reportError(GenericError{"No overload for function accepts " + std::to_string(size(expectedArgTypes)) + " arguments."}, call->location); + // If no overloads match argument count, just list all overloads. + overloadTypes = overloads; + } + else + { + // Report errors of the first argument-count-matching, but failing overload + TypeId overload = overloadsThatMatchArgCount[0]; + + // Remove the overload we are reporting errors about from the list of alternatives + overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); + + const FunctionType* ftv = get(overload); + LUAU_ASSERT(ftv); // overload must be a function type here + + auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [ftv](const std::pair& e) { + return ftv == std::get<1>(e); + }); + + LUAU_ASSERT(error != overloadsErrors.end()); + reportErrors(std::get<0>(*error)); + + // If only one overload matched, we don't need this error because we provided the previous errors. + if (overloadsThatMatchArgCount.size() == 1) + return; + } + + std::string s; + for (size_t i = 0; i < overloadTypes.size(); ++i) + { + TypeId overload = follow(overloadTypes[i]); + + if (i > 0) + s += "; "; + + if (i > 0 && i == overloadTypes.size() - 1) + s += "and "; + + s += toString(overload); + } + + if (overloadsThatMatchArgCount.size() == 0) + reportError(ExtraInformation{"Available overloads: " + s}, call->func->location); + else + reportError(ExtraInformation{"Other overloads are also not viable: " + s}, call->func->location); + } + void visit(AstExprCall* call) { visit(call->func, RValue); @@ -865,6 +973,10 @@ struct TypeChecker2 return; } } + else if (auto itv = get(functionType)) + { + // We do nothing here because we'll flatten the intersection later, but we don't want to report it as a non-function. + } else if (auto utv = get(functionType)) { // Sometimes it's okay to call a union of functions, but only if all of the functions are the same. @@ -930,48 +1042,105 @@ struct TypeChecker2 TypePackId expectedArgTypes = arena->addTypePack(args); - const FunctionType* inferredFunctionType = get(testFunctionType); - LUAU_ASSERT(inferredFunctionType); // testFunctionType should always be a FunctionType here + std::vector overloads = flattenIntersection(testFunctionType); + std::vector> overloadsErrors; + overloadsErrors.reserve(overloads.size()); - size_t argIndex = 0; - auto inferredArgIt = begin(inferredFunctionType->argTypes); - auto expectedArgIt = begin(expectedArgTypes); - while (inferredArgIt != end(inferredFunctionType->argTypes) && expectedArgIt != end(expectedArgTypes)) + std::vector overloadsThatMatchArgCount; + + for (TypeId overload : overloads) { - Location argLoc = (argIndex >= argLocs.size()) ? argLocs.back() : argLocs[argIndex]; - reportErrors(tryUnify(stack.back(), argLoc, *expectedArgIt, *inferredArgIt)); + overload = follow(overload); - ++argIndex; - ++inferredArgIt; - ++expectedArgIt; + const FunctionType* overloadFn = get(overload); + if (!overloadFn) + { + reportError(CannotCallNonFunction{overload}, call->func->location); + return; + } + else + { + // We may have to instantiate the overload in order for it to typecheck. + if (std::optional instantiatedFunctionType = instantiation.substitute(overload)) + { + overloadFn = get(*instantiatedFunctionType); + } + else + { + overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overloadFn); + return; + } + } + + ErrorVec overloadErrors = visitOverload(call, NotNull{overloadFn}, argLocs, expectedArgTypes, expectedRetType); + if (overloadErrors.empty()) + return; + + bool argMismatch = false; + for (auto error : overloadErrors) + { + CountMismatch* cm = get(error); + if (!cm) + continue; + + if (cm->context == CountMismatch::Arg) + { + argMismatch = true; + break; + } + } + + if (!argMismatch) + overloadsThatMatchArgCount.push_back(overload); + + overloadsErrors.emplace_back(std::move(overloadErrors), overloadFn); } - // piggyback on the unifier for arity checking, but we can't do this for checking the actual arguments since the locations would be bad - ErrorVec errors = tryUnify(stack.back(), call->location, expectedArgTypes, inferredFunctionType->argTypes); - for (TypeError e : errors) - if (get(e) != nullptr) - reportError(std::move(e)); + reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); + } - reportErrors(tryUnify(stack.back(), call->location, inferredFunctionType->retTypes, expectedRetType, CountMismatch::FunctionResult)); + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) + { + visit(expr, RValue); + + TypeId leftType = lookupType(expr); + const NormalizedType* norm = normalizer.normalize(leftType); + if (!norm) + reportError(NormalizationTooComplex{}, location); + + checkIndexTypeFromType(leftType, *norm, propName, location, context); } void visit(AstExprIndexName* indexName, ValueContext context) { - visit(indexName->expr, RValue); - - TypeId leftType = lookupType(indexName->expr); - const NormalizedType* norm = normalizer.normalize(leftType); - if (!norm) - reportError(NormalizationTooComplex{}, indexName->indexLocation); - - checkIndexTypeFromType(leftType, *norm, indexName->index.value, indexName->location, context); + visitExprName(indexName->expr, indexName->location, indexName->index.value, context); } void visit(AstExprIndexExpr* indexExpr, ValueContext context) { + if (auto str = indexExpr->index->as()) + { + const std::string stringValue(str->value.data, str->value.size); + visitExprName(indexExpr->expr, indexExpr->location, stringValue, context); + return; + } + // TODO! visit(indexExpr->expr, LValue); visit(indexExpr->index, RValue); + + NotNull scope = stack.back(); + + TypeId exprType = lookupType(indexExpr->expr); + TypeId indexType = lookupType(indexExpr->index); + + if (auto tt = get(exprType)) + { + if (tt->indexer) + reportErrors(tryUnify(scope, indexExpr->index->location, indexType, tt->indexer->indexType)); + else + reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); + } } void visit(AstExprFunction* fn) @@ -1879,8 +2048,17 @@ struct TypeChecker2 ty = *mtIndex; } - if (getTableType(ty)) - return bool(findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)); + if (auto tt = getTableType(ty)) + { + if (findTablePropertyRespectingMeta(builtinTypes, module->errors, ty, prop, location)) + return true; + + else if (tt->indexer && isPrim(tt->indexer->indexResultType, PrimitiveType::String)) + return tt->indexer->indexResultType; + + else + return false; + } else if (const ClassType* cls = get(ty)) return bool(lookupClassProp(cls, prop)); else if (const UnionType* utv = get(ty)) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e59c7e0..adca034 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1759,7 +1759,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate result; @@ -1767,23 +1767,23 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (auto a = expr.as()) result = checkExpr(scope, *a->expr, expectedType); else if (expr.is()) - result = {nilType}; + result = WithPredicate{nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) { if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) - result = {singletonType(bexpr->value)}; + result = WithPredicate{singletonType(bexpr->value)}; else - result = {booleanType}; + result = WithPredicate{booleanType}; } else if (const AstExprConstantString* sexpr = expr.as()) { if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) - result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + result = WithPredicate{singletonType(std::string(sexpr->value.data, sexpr->value.size))}; else - result = {stringType}; + result = WithPredicate{stringType}; } else if (expr.is()) - result = {numberType}; + result = WithPredicate{numberType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1837,7 +1837,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) @@ -1849,7 +1849,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) @@ -1859,26 +1859,26 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (get(varargPack)) { if (std::optional ty = first(varargPack)) - return {*ty}; + return WithPredicate{*ty}; - return {nilType}; + return WithPredicate{nilType}; } else if (get(varargPack)) { TypeId head = freshType(scope); TypePackId tail = freshTypePack(scope); *asMutable(varargPack) = TypePack{{head}, tail}; - return {head}; + return WithPredicate{head}; } if (get(varargPack)) - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) - return {vtp->ty}; + return WithPredicate{vtp->ty}; else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); @@ -1929,9 +1929,9 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp lhsType = stripFromNilAndReport(lhsType, expr.expr->location); if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, /* addErrors= */ true)) - return {*ty}; + return WithPredicate{*ty}; - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors) @@ -2138,7 +2138,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (std::optional refiTy = resolveLValue(scope, *lvalue)) return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - return {ty}; + return WithPredicate{ty}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) @@ -2147,7 +2147,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp checkFunctionBody(funScope, funTy, expr); - return {quantify(funScope, funTy, expr.location)}; + return WithPredicate{quantify(funScope, funTy, expr.location)}; } TypeId TypeChecker::checkExprTable( @@ -2252,7 +2252,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } std::vector> fieldTypes(expr.items.size); @@ -2339,7 +2339,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp expectedIndexResultType = fieldTypes[i].second; } - return {checkExprTable(scope, expr, fieldTypes, expectedType)}; + return WithPredicate{checkExprTable(scope, expr, fieldTypes, expectedType)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) @@ -2356,7 +2356,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp const bool operandIsAny = get(operandType) || get(operandType) || get(operandType); if (operandIsAny) - return {operandType}; + return WithPredicate{operandType}; if (typeCouldHaveMetatable(operandType)) { @@ -2377,16 +2377,16 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (!state.errors.empty()) retType = errorRecoveryType(retType); - return {retType}; + return WithPredicate{retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } reportErrors(tryUnify(operandType, numberType, scope, expr.location)); - return {numberType}; + return WithPredicate{numberType}; } case AstExprUnary::Len: { @@ -2396,7 +2396,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // # operator is guaranteed to return number if (get(operandType) || get(operandType) || get(operandType)) - return {numberType}; + return WithPredicate{numberType}; DenseHashSet seen{nullptr}; @@ -2420,7 +2420,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (!hasLength(operandType, seen, &recursionCount)) reportError(TypeError{expr.location, NotATable{operandType}}); - return {numberType}; + return WithPredicate{numberType}; } default: ice("Unknown AstExprUnary " + std::to_string(int(expr.op))); @@ -3014,7 +3014,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate rhs = checkExpr(scope, *expr.right); // Intentionally discarding predicates with other operators. - return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; + return WithPredicate{checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; } } @@ -3045,7 +3045,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorRecoveryType(scope)}; + return WithPredicate{errorRecoveryType(scope)}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) @@ -3061,12 +3061,12 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); if (falseType.type == trueType.type) - return {trueType.type}; + return WithPredicate{trueType.type}; std::vector types = reduceUnion({trueType.type, falseType.type}); if (types.empty()) - return {neverType}; - return {types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})}; + return WithPredicate{neverType}; + return WithPredicate{types.size() == 1 ? types[0] : addType(UnionType{std::move(types)})}; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInterpString& expr) @@ -3074,7 +3074,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp for (AstExpr* expr : expr.expressions) checkExpr(scope, *expr); - return {stringType}; + return WithPredicate{stringType}; } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx) @@ -3704,7 +3704,7 @@ WithPredicate TypeChecker::checkExprPack(const ScopePtr& scope, cons { WithPredicate result = checkExprPackHelper(scope, expr); if (containsNever(result.type)) - return {uninhabitableTypePack}; + return WithPredicate{uninhabitableTypePack}; return result; } @@ -3715,14 +3715,14 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope else if (expr.is()) { if (!scope->varargPack) - return {errorRecoveryTypePack(scope)}; + return WithPredicate{errorRecoveryTypePack(scope)}; - return {*scope->varargPack}; + return WithPredicate{*scope->varargPack}; } else { TypeId type = checkExpr(scope, expr).type; - return {addTypePack({type})}; + return WithPredicate{addTypePack({type})}; } } @@ -3994,71 +3994,77 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope { retPack = freshTypePack(free->level); TypePackId freshArgPack = freshTypePack(free->level); - asMutable(actualFunctionType)->ty.emplace(free->level, freshArgPack, retPack); + emplaceType(asMutable(actualFunctionType), free->level, freshArgPack, retPack); } else retPack = freshTypePack(scope->level); - // checkExpr will log the pre-instantiated type of the function. - // That's not nearly as interesting as the instantiated type, which will include details about how - // generic functions are being instantiated for this particular callsite. - currentModule->astOriginalCallTypes[expr.func] = follow(functionType); - currentModule->astTypes[expr.func] = actualFunctionType; + // We break this function up into a lambda here to limit our stack footprint. + // The vectors used by this function aren't allocated until the lambda is actually called. + auto the_rest = [&]() -> WithPredicate { + // checkExpr will log the pre-instantiated type of the function. + // That's not nearly as interesting as the instantiated type, which will include details about how + // generic functions are being instantiated for this particular callsite. + currentModule->astOriginalCallTypes[expr.func] = follow(functionType); + currentModule->astTypes[expr.func] = actualFunctionType; - std::vector overloads = flattenIntersection(actualFunctionType); + std::vector overloads = flattenIntersection(actualFunctionType); - std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); + std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); - WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); - TypePackId argPack = argListResult.type; + WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + TypePackId argPack = argListResult.type; - if (get(argPack)) - return {errorRecoveryTypePack(scope)}; + if (get(argPack)) + return WithPredicate{errorRecoveryTypePack(scope)}; - TypePack* args = nullptr; - if (expr.self) - { - argPack = addTypePack(TypePack{{selfType}, argPack}); - argListResult.type = argPack; - } - args = getMutable(argPack); - LUAU_ASSERT(args); + TypePack* args = nullptr; + if (expr.self) + { + argPack = addTypePack(TypePack{{selfType}, argPack}); + argListResult.type = argPack; + } + args = getMutable(argPack); + LUAU_ASSERT(args); - std::vector argLocations; - argLocations.reserve(expr.args.size + 1); - if (expr.self) - argLocations.push_back(expr.func->as()->expr->location); - for (AstExpr* arg : expr.args) - argLocations.push_back(arg->location); + std::vector argLocations; + argLocations.reserve(expr.args.size + 1); + if (expr.self) + argLocations.push_back(expr.func->as()->expr->location); + for (AstExpr* arg : expr.args) + argLocations.push_back(arg->location); - std::vector errors; // errors encountered for each overload + std::vector errors; // errors encountered for each overload - std::vector overloadsThatMatchArgCount; - std::vector overloadsThatDont; + std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; - for (TypeId fn : overloads) - { - fn = follow(fn); + for (TypeId fn : overloads) + { + fn = follow(fn); - if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) - return *ret; - } + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + return *ret; + } - if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) - return {retPack}; + if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) + return WithPredicate{retPack}; - reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); - const FunctionType* overload = nullptr; - if (!overloadsThatMatchArgCount.empty()) - overload = get(overloadsThatMatchArgCount[0]); - if (!overload && !overloadsThatDont.empty()) - overload = get(overloadsThatDont[0]); - if (overload) - return {errorRecoveryTypePack(overload->retTypes)}; + const FunctionType* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return WithPredicate{errorRecoveryTypePack(overload->retTypes)}; - return {errorRecoveryTypePack(retPack)}; + return WithPredicate{errorRecoveryTypePack(retPack)}; + }; + + return the_rest(); } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) @@ -4119,8 +4125,13 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st return expectedTypes; } -std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, +/* + * Note: We return a std::unique_ptr here rather than an optional to manage our stack consumption. + * If this was an optional, callers would have to pay the stack cost for the result. This is problematic + * for functions that need to support recursion up to 600 levels deep. + */ +std::unique_ptr> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, + TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector* argLocations, const WithPredicate& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { LUAU_ASSERT(argLocations); @@ -4130,16 +4141,16 @@ std::optional> TypeChecker::checkCallOverload(const Sc if (get(fn)) { unify(anyTypePack, argPack, scope, expr.location); - return {{anyTypePack}}; + return std::make_unique>(anyTypePack); } if (get(fn)) { - return {{errorRecoveryTypePack(scope)}}; + return std::make_unique>(errorRecoveryTypePack(scope)); } if (get(fn)) - return {{uninhabitableTypePack}}; + return std::make_unique>(uninhabitableTypePack); if (auto ftv = get(fn)) { @@ -4152,7 +4163,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc options.isFunctionCall = true; unify(r, fn, scope, expr.location, options); - return {{retPack}}; + return std::make_unique>(retPack); } std::vector metaArgLocations; @@ -4191,7 +4202,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); unify(errorRecoveryTypePack(scope), retPack, scope, expr.func->location); - return {{errorRecoveryTypePack(retPack)}}; + return std::make_unique>(errorRecoveryTypePack(retPack)); } // When this function type has magic functions and did return something, we select that overload instead. @@ -4200,7 +4211,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) - return *ret; + return std::make_unique>(std::move(*ret)); } Unifier state = mkUnifier(scope, expr.location); @@ -4209,7 +4220,7 @@ std::optional> TypeChecker::checkCallOverload(const Sc checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { - return {}; + return nullptr; } checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, *argLocations); @@ -4244,10 +4255,10 @@ std::optional> TypeChecker::checkCallOverload(const Sc currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload - return {{retPack}}; + return std::make_unique>(retPack); } - return {}; + return nullptr; } bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, @@ -4404,7 +4415,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons }; if (exprs.size == 0) - return {pack}; + return WithPredicate{pack}; TypePack* tp = getMutable(pack); @@ -4484,7 +4495,7 @@ WithPredicate TypeChecker::checkExprList(const ScopePtr& scope, cons log.commit(); if (uninhabitable) - return {uninhabitableTypePack}; + return WithPredicate{uninhabitableTypePack}; return {pack, predicates}; } diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 94fb4ad..2393829 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -16,11 +16,167 @@ LUAU_FASTFLAGVARIABLE(DebugLuauDontReduceTypes, false) namespace Luau { +namespace detail +{ +bool TypeReductionMemoization::isIrreducible(TypeId ty) +{ + ty = follow(ty); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (auto edge = types.find(ty); edge && edge->irreducible) + return true; + else if (get(ty) || get(ty) || get(ty)) + return false; + else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) + return false; + else + return true; +} + +bool TypeReductionMemoization::isIrreducible(TypePackId tp) +{ + tp = follow(tp); + + // Only does shallow check, the TypeReducer itself already does deep traversal. + if (auto edge = typePacks.find(tp); edge && edge->irreducible) + return true; + else if (get(tp) || get(tp)) + return false; + else if (auto vtp = get(tp)) + return isIrreducible(vtp->ty); + else + return true; +} + +TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) +{ + ty = follow(ty); + reducedTy = follow(reducedTy); + + // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. + // We don't need to recurse much further than that, because we already record the irreducibility from + // the bottom up. + bool irreducible = isIrreducible(reducedTy); + if (auto it = get(reducedTy)) + { + for (TypeId part : it) + irreducible &= isIrreducible(part); + } + else if (auto ut = get(reducedTy)) + { + for (TypeId option : ut) + irreducible &= isIrreducible(option); + } + else if (auto tt = get(reducedTy)) + { + for (auto& [k, p] : tt->props) + irreducible &= isIrreducible(p.type); + + if (tt->indexer) + { + irreducible &= isIrreducible(tt->indexer->indexType); + irreducible &= isIrreducible(tt->indexer->indexResultType); + } + + for (auto ta : tt->instantiatedTypeParams) + irreducible &= isIrreducible(ta); + + for (auto tpa : tt->instantiatedTypePackParams) + irreducible &= isIrreducible(tpa); + } + else if (auto mt = get(reducedTy)) + { + irreducible &= isIrreducible(mt->table); + irreducible &= isIrreducible(mt->metatable); + } + else if (auto ft = get(reducedTy)) + { + irreducible &= isIrreducible(ft->argTypes); + irreducible &= isIrreducible(ft->retTypes); + } + else if (auto nt = get(reducedTy)) + irreducible &= isIrreducible(nt->ty); + + types[ty] = {reducedTy, irreducible}; + types[reducedTy] = {reducedTy, irreducible}; + return reducedTy; +} + +TypePackId TypeReductionMemoization::memoize(TypePackId tp, TypePackId reducedTp) +{ + tp = follow(tp); + reducedTp = follow(reducedTp); + + bool irreducible = isIrreducible(reducedTp); + TypePackIterator it = begin(tp); + while (it != end(tp)) + { + irreducible &= isIrreducible(*it); + ++it; + } + + if (it.tail()) + irreducible &= isIrreducible(*it.tail()); + + typePacks[tp] = {reducedTp, irreducible}; + typePacks[reducedTp] = {reducedTp, irreducible}; + return reducedTp; +} + +std::optional> TypeReductionMemoization::memoizedof(TypeId ty) const +{ + auto fetchContext = [this](TypeId ty) -> std::optional> { + if (auto edge = types.find(ty)) + return *edge; + else + return std::nullopt; + }; + + TypeId currentTy = ty; + std::optional> lastEdge; + while (auto edge = fetchContext(currentTy)) + { + lastEdge = edge; + if (edge->irreducible) + return edge; + else if (edge->type == currentTy) + return edge; + else + currentTy = edge->type; + } + + return lastEdge; +} + +std::optional> TypeReductionMemoization::memoizedof(TypePackId tp) const +{ + auto fetchContext = [this](TypePackId tp) -> std::optional> { + if (auto edge = typePacks.find(tp)) + return *edge; + else + return std::nullopt; + }; + + TypePackId currentTp = tp; + std::optional> lastEdge; + while (auto edge = fetchContext(currentTp)) + { + lastEdge = edge; + if (edge->irreducible) + return edge; + else if (edge->type == currentTp) + return edge; + else + currentTp = edge->type; + } + + return lastEdge; +} +} // namespace detail + namespace { -using detail::ReductionContext; - template std::pair get2(const Thing& one, const Thing& two) { @@ -34,9 +190,7 @@ struct TypeReducer NotNull arena; NotNull builtinTypes; NotNull handle; - - DenseHashMap>* memoizedTypes; - DenseHashMap>* memoizedTypePacks; + NotNull memoization; DenseHashSet* cyclics; int depth = 0; @@ -50,12 +204,6 @@ struct TypeReducer TypeId functionType(TypeId ty); TypeId negationType(TypeId ty); - bool isIrreducible(TypeId ty); - bool isIrreducible(TypePackId tp); - - TypeId memoize(TypeId ty, TypeId reducedTy); - TypePackId memoize(TypePackId tp, TypePackId reducedTp); - using BinaryFold = std::optional (TypeReducer::*)(TypeId, TypeId); using UnaryFold = TypeId (TypeReducer::*)(TypeId); @@ -64,12 +212,15 @@ struct TypeReducer { ty = follow(ty); - if (auto ctx = memoizedTypes->find(ty)) - return {ctx->type, getMutable(ctx->type)}; + if (auto edge = memoization->memoizedof(ty)) + return {edge->type, getMutable(edge->type)}; + // We specifically do not want to use [`detail::TypeReductionMemoization::memoize`] because that will + // potentially consider these copiedTy to be reducible, but we need this to resolve cyclic references + // without attempting to recursively reduce it, causing copies of copies of copies of... TypeId copiedTy = arena->addType(*t); - (*memoizedTypes)[ty] = {copiedTy, true}; - (*memoizedTypes)[copiedTy] = {copiedTy, true}; + memoization->types[ty] = {copiedTy, true}; + memoization->types[copiedTy] = {copiedTy, true}; return {copiedTy, getMutable(copiedTy)}; } @@ -175,8 +326,13 @@ TypeId TypeReducer::reduce(TypeId ty) { ty = follow(ty); - if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible) - return ctx->type; + if (auto edge = memoization->memoizedof(ty)) + { + if (edge->irreducible) + return edge->type; + else + ty = edge->type; + } else if (cyclics->contains(ty)) return ty; @@ -196,15 +352,20 @@ TypeId TypeReducer::reduce(TypeId ty) else result = ty; - return memoize(ty, result); + return memoization->memoize(ty, result); } TypePackId TypeReducer::reduce(TypePackId tp) { tp = follow(tp); - if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible) - return ctx->type; + if (auto edge = memoization->memoizedof(tp)) + { + if (edge->irreducible) + return edge->type; + else + tp = edge->type; + } else if (cyclics->contains(tp)) return tp; @@ -237,11 +398,11 @@ TypePackId TypeReducer::reduce(TypePackId tp) } if (!didReduce) - return memoize(tp, tp); + return memoization->memoize(tp, tp); else if (head.empty() && tail) - return memoize(tp, *tail); + return memoization->memoize(tp, *tail); else - return memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); + return memoization->memoize(tp, arena->addTypePack(TypePack{std::move(head), tail})); } std::optional TypeReducer::intersectionType(TypeId left, TypeId right) @@ -832,111 +993,6 @@ TypeId TypeReducer::negationType(TypeId ty) return ty; // for all T except the ones handled above, ~T ~ ~T } -bool TypeReducer::isIrreducible(TypeId ty) -{ - ty = follow(ty); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto ctx = memoizedTypes->find(ty); ctx && ctx->irreducible) - return true; - else if (get(ty) || get(ty) || get(ty)) - return false; - else if (auto tt = get(ty); tt && (tt->state == TableState::Free || tt->state == TableState::Unsealed)) - return false; - else - return true; -} - -bool TypeReducer::isIrreducible(TypePackId tp) -{ - tp = follow(tp); - - // Only does shallow check, the TypeReducer itself already does deep traversal. - if (auto ctx = memoizedTypePacks->find(tp); ctx && ctx->irreducible) - return true; - else if (get(tp) || get(tp)) - return false; - else if (auto vtp = get(tp)) - return isIrreducible(vtp->ty); - else - return true; -} - -TypeId TypeReducer::memoize(TypeId ty, TypeId reducedTy) -{ - ty = follow(ty); - reducedTy = follow(reducedTy); - - // The irreducibility of this [`reducedTy`] depends on whether its contents are themselves irreducible. - // We don't need to recurse much further than that, because we already record the irreducibility from - // the bottom up. - bool irreducible = isIrreducible(reducedTy); - if (auto it = get(reducedTy)) - { - for (TypeId part : it) - irreducible &= isIrreducible(part); - } - else if (auto ut = get(reducedTy)) - { - for (TypeId option : ut) - irreducible &= isIrreducible(option); - } - else if (auto tt = get(reducedTy)) - { - for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type); - - if (tt->indexer) - { - irreducible &= isIrreducible(tt->indexer->indexType); - irreducible &= isIrreducible(tt->indexer->indexResultType); - } - - for (auto ta : tt->instantiatedTypeParams) - irreducible &= isIrreducible(ta); - - for (auto tpa : tt->instantiatedTypePackParams) - irreducible &= isIrreducible(tpa); - } - else if (auto mt = get(reducedTy)) - { - irreducible &= isIrreducible(mt->table); - irreducible &= isIrreducible(mt->metatable); - } - else if (auto ft = get(reducedTy)) - { - irreducible &= isIrreducible(ft->argTypes); - irreducible &= isIrreducible(ft->retTypes); - } - else if (auto nt = get(reducedTy)) - irreducible &= isIrreducible(nt->ty); - - (*memoizedTypes)[ty] = {reducedTy, irreducible}; - (*memoizedTypes)[reducedTy] = {reducedTy, irreducible}; - return reducedTy; -} - -TypePackId TypeReducer::memoize(TypePackId tp, TypePackId reducedTp) -{ - tp = follow(tp); - reducedTp = follow(reducedTp); - - bool irreducible = isIrreducible(reducedTp); - TypePackIterator it = begin(tp); - while (it != end(tp)) - { - irreducible &= isIrreducible(*it); - ++it; - } - - if (it.tail()) - irreducible &= isIrreducible(*it.tail()); - - (*memoizedTypePacks)[tp] = {reducedTp, irreducible}; - (*memoizedTypePacks)[reducedTp] = {reducedTp, irreducible}; - return reducedTp; -} - struct MarkCycles : TypeVisitor { DenseHashSet cyclics{nullptr}; @@ -961,7 +1017,6 @@ struct MarkCycles : TypeVisitor return !cyclics.find(follow(tp)); } }; - } // namespace TypeReduction::TypeReduction( @@ -981,8 +1036,13 @@ std::optional TypeReduction::reduce(TypeId ty) return ty; else if (!options.allowTypeReductionsFromOtherArenas && ty->owningArena != arena) return ty; - else if (auto memoized = memoizedof(ty)) - return *memoized; + else if (auto edge = memoization.memoizedof(ty)) + { + if (edge->irreducible) + return edge->type; + else + ty = edge->type; + } else if (hasExceededCartesianProductLimit(ty)) return std::nullopt; @@ -991,7 +1051,7 @@ std::optional TypeReduction::reduce(TypeId ty) MarkCycles finder; finder.traverse(ty); - TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics}; + TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; return reducer.reduce(ty); } catch (const RecursionLimitException&) @@ -1008,8 +1068,13 @@ std::optional TypeReduction::reduce(TypePackId tp) return tp; else if (!options.allowTypeReductionsFromOtherArenas && tp->owningArena != arena) return tp; - else if (auto memoized = memoizedof(tp)) - return *memoized; + else if (auto edge = memoization.memoizedof(tp)) + { + if (edge->irreducible) + return edge->type; + else + tp = edge->type; + } else if (hasExceededCartesianProductLimit(tp)) return std::nullopt; @@ -1018,7 +1083,7 @@ std::optional TypeReduction::reduce(TypePackId tp) MarkCycles finder; finder.traverse(tp); - TypeReducer reducer{arena, builtinTypes, handle, &memoizedTypes, &memoizedTypePacks, &finder.cyclics}; + TypeReducer reducer{arena, builtinTypes, handle, NotNull{&memoization}, &finder.cyclics}; return reducer.reduce(tp); } catch (const RecursionLimitException&) @@ -1039,13 +1104,6 @@ std::optional TypeReduction::reduce(const TypeFun& fun) return std::nullopt; } -TypeReduction TypeReduction::fork(NotNull arena, const TypeReductionOptions& opts) const -{ - TypeReduction child{arena, builtinTypes, handle, opts}; - child.parent = this; - return child; -} - size_t TypeReduction::cartesianProductSize(TypeId ty) const { ty = follow(ty); @@ -1093,24 +1151,4 @@ bool TypeReduction::hasExceededCartesianProductLimit(TypePackId tp) const return false; } -std::optional TypeReduction::memoizedof(TypeId ty) const -{ - if (auto ctx = memoizedTypes.find(ty); ctx && ctx->irreducible) - return ctx->type; - else if (parent) - return parent->memoizedof(ty); - else - return std::nullopt; -} - -std::optional TypeReduction::memoizedof(TypePackId tp) const -{ - if (auto ctx = memoizedTypePacks.find(tp); ctx && ctx->irreducible) - return ctx->type; - else if (parent) - return parent->memoizedof(tp); - else - return std::nullopt; -} - } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 7104f2e..6364a5a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -520,7 +520,12 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool size_t errorCount = errors.size(); - if (const UnionType* subUnion = log.getMutable(subTy)) + if (log.getMutable(subTy) && log.getMutable(superTy)) + { + blockedTypes.push_back(subTy); + blockedTypes.push_back(superTy); + } + else if (const UnionType* subUnion = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, subUnion, superTy); } diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp index d24c987..b7c7801 100644 --- a/CLI/Reduce.cpp +++ b/CLI/Reduce.cpp @@ -487,7 +487,7 @@ int main(int argc, char** argv) if (args.size() < 4) help(args); - for (int i = 1; i < args.size(); ++i) + for (size_t i = 1; i < args.size(); ++i) { if (args[i] == "--help") help(args); diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index ebbba68..2955342 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -42,6 +42,7 @@ struct IrBuilder IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d); IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e); + IrOp inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e, IrOp f); IrOp block(IrBlockKind kind); // Requested kind can be ignored if we are in an outlined sequence IrOp blockAtInst(uint32_t index); @@ -57,6 +58,8 @@ struct IrBuilder IrFunction function; + uint32_t activeBlockIdx = ~0u; + std::vector instIndexToBlock; // Block index at the bytecode instruction }; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 6a70946..18d510c 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -5,6 +5,7 @@ #include "Luau/RegisterX64.h" #include "Luau/RegisterA64.h" +#include #include #include @@ -186,6 +187,16 @@ enum class IrCmd : uint8_t // A: int INT_TO_NUM, + // Adjust stack top (L->top) to point at 'B' TValues *after* the specified register + // This is used to return muliple values + // A: Rn + // B: int (offset) + ADJUST_STACK_TO_REG, + + // Restore stack top (L->top) to point to the function stack top (L->ci->top) + // This is used to recover after calling a variadic function + ADJUST_STACK_TO_TOP, + // Fallback functions // Perform an arithmetic operation on TValues of any type @@ -329,7 +340,7 @@ enum class IrCmd : uint8_t // Call specified function // A: unsigned int (bytecode instruction index) // B: Rn (function, followed by arguments) - // C: int (argument count or -1 to preserve all arguments up to stack top) + // C: int (argument count or -1 to use all arguments up to stack top) // D: int (result count or -1 to preserve all results and adjust stack top) // Note: return values are placed starting from Rn specified in 'B' LOP_CALL, @@ -337,13 +348,13 @@ enum class IrCmd : uint8_t // Return specified values from the function // A: unsigned int (bytecode instruction index) // B: Rn (value start) - // B: int (result count or -1 to return all values up to stack top) + // C: int (result count or -1 to return all values up to stack top) LOP_RETURN, // Perform a fast call of a built-in function // A: unsigned int (bytecode instruction index) // B: Rn (argument start) - // C: int (argument count or -1 preserve all arguments up to stack top) + // C: int (argument count or -1 use all arguments up to stack top) // D: block (fallback) // Note: return values are placed starting from Rn specified in 'B' LOP_FASTCALL, @@ -560,6 +571,7 @@ struct IrInst IrOp c; IrOp d; IrOp e; + IrOp f; uint32_t lastUse = 0; uint16_t useCount = 0; @@ -584,9 +596,10 @@ struct IrBlock uint16_t useCount = 0; - // Start points to an instruction index in a stream - // End is implicit + // 'start' and 'finish' define an inclusive range of instructions which belong to this block inside the function + // When block has been constructed, 'finish' always points to the first and only terminating instruction uint32_t start = ~0u; + uint32_t finish = ~0u; Label label; }; @@ -633,6 +646,19 @@ struct IrFunction return value.valueTag; } + std::optional asTagOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Tag) + return std::nullopt; + + return value.valueTag; + } + bool boolOp(IrOp op) { IrConst& value = constOp(op); @@ -641,6 +667,19 @@ struct IrFunction return value.valueBool; } + std::optional asBoolOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Bool) + return std::nullopt; + + return value.valueBool; + } + int intOp(IrOp op) { IrConst& value = constOp(op); @@ -649,6 +688,19 @@ struct IrFunction return value.valueInt; } + std::optional asIntOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Int) + return std::nullopt; + + return value.valueInt; + } + unsigned uintOp(IrOp op) { IrConst& value = constOp(op); @@ -657,6 +709,19 @@ struct IrFunction return value.valueUint; } + std::optional asUintOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Uint) + return std::nullopt; + + return value.valueUint; + } + double doubleOp(IrOp op) { IrConst& value = constOp(op); @@ -665,11 +730,31 @@ struct IrFunction return value.valueDouble; } + std::optional asDoubleOp(IrOp op) + { + if (op.kind != IrOpKind::Constant) + return std::nullopt; + + IrConst& value = constOp(op); + + if (value.kind != IrConstKind::Double) + return std::nullopt; + + return value.valueDouble; + } + IrCondition conditionOp(IrOp op) { LUAU_ASSERT(op.kind == IrOpKind::Condition); return IrCondition(op.index); } + + uint32_t getBlockIndex(const IrBlock& block) + { + // Can only be called with blocks from our vector + LUAU_ASSERT(&block >= blocks.data() && &block <= blocks.data() + blocks.size()); + return uint32_t(&block - blocks.data()); + } }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 3e95813..0a23b3f 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -162,6 +162,8 @@ inline bool isPseudo(IrCmd cmd) return cmd == IrCmd::NOP || cmd == IrCmd::SUBSTITUTE; } +bool isGCO(uint8_t tag); + // Remove a single instruction void kill(IrFunction& function, IrInst& inst); @@ -179,7 +181,7 @@ void replace(IrFunction& function, IrOp& original, IrOp replacement); // Replace a single instruction // Target instruction index instead of reference is used to handle introduction of a new block terminator -void replace(IrFunction& function, uint32_t instIdx, IrInst replacement); +void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst replacement); // Replace instruction with a different value (using IrCmd::SUBSTITUTE) void substitute(IrFunction& function, IrInst& inst, IrOp replacement); @@ -188,10 +190,13 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement); void applySubstitutions(IrFunction& function, IrOp& op); void applySubstitutions(IrFunction& function, IrInst& inst); +// Compare numbers using IR condition value +bool compare(double a, double b, IrCondition cond); + // Perform constant folding on instruction at index // For most instructions, successful folding results in a IrCmd::SUBSTITUTE // But it can also be successful on conditional control-flow, replacing it with an unconditional IrCmd::JUMP -void foldConstants(IrBuilder& build, IrFunction& function, uint32_t instIdx); +void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIdx); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/OptimizeConstProp.h b/CodeGen/include/Luau/OptimizeConstProp.h new file mode 100644 index 0000000..3be0441 --- /dev/null +++ b/CodeGen/include/Luau/OptimizeConstProp.h @@ -0,0 +1,16 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrData.h" + +namespace Luau +{ +namespace CodeGen +{ + +struct IrBuilder; + +void constPropInBlockChains(IrBuilder& build); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 78f001f..5076cba 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -7,6 +7,7 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" +#include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeFinalX64.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" @@ -31,7 +32,7 @@ #endif #endif -LUAU_FASTFLAGVARIABLE(DebugUseOldCodegen, false) +LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) namespace Luau { @@ -40,12 +41,6 @@ namespace CodeGen constexpr uint32_t kFunctionAlignment = 32; -struct InstructionOutline -{ - int pcpos; - int length; -}; - static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) { if (build.logText) @@ -64,346 +59,6 @@ static void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers) emitContinueCallInVm(build); } -static int emitInst(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, LuauOpcode op, const Instruction* pc, int i, - Label* labelarr, Label& next, Label& fallback) -{ - int skip = 0; - - switch (op) - { - case LOP_NOP: - break; - case LOP_LOADNIL: - emitInstLoadNil(build, pc); - break; - case LOP_LOADB: - emitInstLoadB(build, pc, i, labelarr); - break; - case LOP_LOADN: - emitInstLoadN(build, pc); - break; - case LOP_LOADK: - emitInstLoadK(build, pc); - break; - case LOP_LOADKX: - emitInstLoadKX(build, pc); - break; - case LOP_MOVE: - emitInstMove(build, pc); - break; - case LOP_GETGLOBAL: - emitInstGetGlobal(build, pc, i, fallback); - break; - case LOP_SETGLOBAL: - emitInstSetGlobal(build, pc, i, next, fallback); - break; - case LOP_NAMECALL: - emitInstNameCall(build, pc, i, proto->k, next, fallback); - break; - case LOP_CALL: - emitInstCall(build, helpers, pc, i); - break; - case LOP_RETURN: - emitInstReturn(build, helpers, pc, i); - break; - case LOP_GETTABLE: - emitInstGetTable(build, pc, fallback); - break; - case LOP_SETTABLE: - emitInstSetTable(build, pc, next, fallback); - break; - case LOP_GETTABLEKS: - emitInstGetTableKS(build, pc, i, fallback); - break; - case LOP_SETTABLEKS: - emitInstSetTableKS(build, pc, i, next, fallback); - break; - case LOP_GETTABLEN: - emitInstGetTableN(build, pc, fallback); - break; - case LOP_SETTABLEN: - emitInstSetTableN(build, pc, next, fallback); - break; - case LOP_JUMP: - emitInstJump(build, pc, i, labelarr); - break; - case LOP_JUMPBACK: - emitInstJumpBack(build, pc, i, labelarr); - break; - case LOP_JUMPIF: - emitInstJumpIf(build, pc, i, labelarr, /* not_ */ false); - break; - case LOP_JUMPIFNOT: - emitInstJumpIf(build, pc, i, labelarr, /* not_ */ true); - break; - case LOP_JUMPIFEQ: - emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ false, fallback); - break; - case LOP_JUMPIFLE: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::LessEqual, fallback); - break; - case LOP_JUMPIFLT: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::Less, fallback); - break; - case LOP_JUMPIFNOTEQ: - emitInstJumpIfEq(build, pc, i, labelarr, /* not_ */ true, fallback); - break; - case LOP_JUMPIFNOTLE: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLessEqual, fallback); - break; - case LOP_JUMPIFNOTLT: - emitInstJumpIfCond(build, pc, i, labelarr, ConditionX64::NotLess, fallback); - break; - case LOP_JUMPX: - emitInstJumpX(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKNIL: - emitInstJumpxEqNil(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKB: - emitInstJumpxEqB(build, pc, i, labelarr); - break; - case LOP_JUMPXEQKN: - emitInstJumpxEqN(build, pc, proto->k, i, labelarr); - break; - case LOP_JUMPXEQKS: - emitInstJumpxEqS(build, pc, i, labelarr); - break; - case LOP_ADD: - emitInstBinary(build, pc, TM_ADD, fallback); - break; - case LOP_SUB: - emitInstBinary(build, pc, TM_SUB, fallback); - break; - case LOP_MUL: - emitInstBinary(build, pc, TM_MUL, fallback); - break; - case LOP_DIV: - emitInstBinary(build, pc, TM_DIV, fallback); - break; - case LOP_MOD: - emitInstBinary(build, pc, TM_MOD, fallback); - break; - case LOP_POW: - emitInstBinary(build, pc, TM_POW, fallback); - break; - case LOP_ADDK: - emitInstBinaryK(build, pc, TM_ADD, fallback); - break; - case LOP_SUBK: - emitInstBinaryK(build, pc, TM_SUB, fallback); - break; - case LOP_MULK: - emitInstBinaryK(build, pc, TM_MUL, fallback); - break; - case LOP_DIVK: - emitInstBinaryK(build, pc, TM_DIV, fallback); - break; - case LOP_MODK: - emitInstBinaryK(build, pc, TM_MOD, fallback); - break; - case LOP_POWK: - emitInstPowK(build, pc, proto->k, fallback); - break; - case LOP_NOT: - emitInstNot(build, pc); - break; - case LOP_MINUS: - emitInstMinus(build, pc, fallback); - break; - case LOP_LENGTH: - emitInstLength(build, pc, fallback); - break; - case LOP_NEWTABLE: - emitInstNewTable(build, pc, i, next); - break; - case LOP_DUPTABLE: - emitInstDupTable(build, pc, i, next); - break; - case LOP_SETLIST: - emitInstSetList(build, pc, next); - break; - case LOP_GETUPVAL: - emitInstGetUpval(build, pc); - break; - case LOP_SETUPVAL: - emitInstSetUpval(build, pc, next); - break; - case LOP_CLOSEUPVALS: - emitInstCloseUpvals(build, pc, next); - break; - case LOP_FASTCALL: - // We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1 - skip = emitInstFastCall(build, pc, i, next) + 1; - break; - case LOP_FASTCALL1: - // We want to lower next instruction at skip+2, but this instruction is only 1 long, so we need to add 1 - skip = emitInstFastCall1(build, pc, i, next) + 1; - break; - case LOP_FASTCALL2: - skip = emitInstFastCall2(build, pc, i, next); - break; - case LOP_FASTCALL2K: - skip = emitInstFastCall2K(build, pc, i, next); - break; - case LOP_FORNPREP: - emitInstForNPrep(build, pc, i, next, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_FORNLOOP: - emitInstForNLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next); - break; - case LOP_FORGLOOP: - emitinstForGLoop(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)], next, fallback); - break; - case LOP_FORGPREP_NEXT: - emitInstForGPrepNext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback); - break; - case LOP_FORGPREP_INEXT: - emitInstForGPrepInext(build, pc, labelarr[i + 1 + LUAU_INSN_D(*pc)], fallback); - break; - case LOP_AND: - emitInstAnd(build, pc); - break; - case LOP_ANDK: - emitInstAndK(build, pc); - break; - case LOP_OR: - emitInstOr(build, pc); - break; - case LOP_ORK: - emitInstOrK(build, pc); - break; - case LOP_GETIMPORT: - emitInstGetImport(build, pc, fallback); - break; - case LOP_CONCAT: - emitInstConcat(build, pc, i, next); - break; - case LOP_COVERAGE: - emitInstCoverage(build, i); - break; - default: - emitFallback(build, data, op, i); - break; - } - - return skip; -} - -static void emitInstFallback(AssemblyBuilderX64& build, NativeState& data, LuauOpcode op, const Instruction* pc, int i, Label* labelarr) -{ - switch (op) - { - case LOP_GETIMPORT: - emitSetSavedPc(build, i + 1); - emitInstGetImportFallback(build, LUAU_INSN_A(*pc), pc[1]); - break; - case LOP_GETTABLE: - emitInstGetTableFallback(build, pc, i); - break; - case LOP_SETTABLE: - emitInstSetTableFallback(build, pc, i); - break; - case LOP_GETTABLEN: - emitInstGetTableNFallback(build, pc, i); - break; - case LOP_SETTABLEN: - emitInstSetTableNFallback(build, pc, i); - break; - case LOP_NAMECALL: - // TODO: fast-paths that we've handled can be removed from the fallback - emitFallback(build, data, op, i); - break; - case LOP_JUMPIFEQ: - emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ false); - break; - case LOP_JUMPIFLE: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::LessEqual); - break; - case LOP_JUMPIFLT: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::Less); - break; - case LOP_JUMPIFNOTEQ: - emitInstJumpIfEqFallback(build, pc, i, labelarr, /* not_ */ true); - break; - case LOP_JUMPIFNOTLE: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLessEqual); - break; - case LOP_JUMPIFNOTLT: - emitInstJumpIfCondFallback(build, pc, i, labelarr, ConditionX64::NotLess); - break; - case LOP_ADD: - emitInstBinaryFallback(build, pc, i, TM_ADD); - break; - case LOP_SUB: - emitInstBinaryFallback(build, pc, i, TM_SUB); - break; - case LOP_MUL: - emitInstBinaryFallback(build, pc, i, TM_MUL); - break; - case LOP_DIV: - emitInstBinaryFallback(build, pc, i, TM_DIV); - break; - case LOP_MOD: - emitInstBinaryFallback(build, pc, i, TM_MOD); - break; - case LOP_POW: - emitInstBinaryFallback(build, pc, i, TM_POW); - break; - case LOP_ADDK: - emitInstBinaryKFallback(build, pc, i, TM_ADD); - break; - case LOP_SUBK: - emitInstBinaryKFallback(build, pc, i, TM_SUB); - break; - case LOP_MULK: - emitInstBinaryKFallback(build, pc, i, TM_MUL); - break; - case LOP_DIVK: - emitInstBinaryKFallback(build, pc, i, TM_DIV); - break; - case LOP_MODK: - emitInstBinaryKFallback(build, pc, i, TM_MOD); - break; - case LOP_POWK: - emitInstBinaryKFallback(build, pc, i, TM_POW); - break; - case LOP_MINUS: - emitInstMinusFallback(build, pc, i); - break; - case LOP_LENGTH: - emitInstLengthFallback(build, pc, i); - break; - case LOP_FORGLOOP: - emitinstForGLoopFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_FORGPREP_NEXT: - case LOP_FORGPREP_INEXT: - emitInstForGPrepXnextFallback(build, pc, i, labelarr[i + 1 + LUAU_INSN_D(*pc)]); - break; - case LOP_GETGLOBAL: - // TODO: luaV_gettable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - case LOP_SETGLOBAL: - // TODO: luaV_settable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - case LOP_GETTABLEKS: - // Full fallback required for LOP_GETTABLEKS because 'luaV_gettable' doesn't handle builtin vector field access - // It is also required to perform cached slot update - // TODO: extra fast-paths could be lowered before the full fallback - emitFallback(build, data, op, i); - break; - case LOP_SETTABLEKS: - // TODO: luaV_settable + cachedslot update instead of full fallback - emitFallback(build, data, op, i); - break; - default: - LUAU_ASSERT(!"Expected fallback for instruction"); - } -} - static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) { NativeProto* result = new NativeProto(); @@ -423,153 +78,32 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat build.logAppend("\n"); } - if (!FFlag::DebugUseOldCodegen) - { - build.align(kFunctionAlignment, AlignmentDataX64::Ud2); - - Label start = build.setLabel(); - - IrBuilder builder; - builder.buildFunctionIr(proto); - - optimizeMemoryOperandsX64(builder.function); - - IrLoweringX64 lowering(build, helpers, data, proto, builder.function); - - lowering.lower(options); - - result->instTargets = new uintptr_t[proto->sizecode]; - - for (int i = 0; i < proto->sizecode; i++) - { - auto [irLocation, asmLocation] = builder.function.bcMapping[i]; - - result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location; - } - - result->location = start.location; - - if (build.logText) - build.logAppend("\n"); - - return result; - } - - std::vector