From 4b267aa5c561b3254f43b3f3ead05fa4be499629 Mon Sep 17 00:00:00 2001 From: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> Date: Fri, 28 Apr 2023 22:55:13 +0300 Subject: [PATCH] Sync to upstream/release/574 (#910) * Added a limit on how many instructions the Compiler can safely produce (reported by @TheGreatSageEqualToHeaven) C++ API Changes: * With work started on read-only and write-only properties, `Property::type` member variable has been replaced with `TypeId type()` and `setType(TypeId)` functions. * New `LazyType` unwrap callback now has a `void` return type, all that's required from the callback is to write into `unwrapped` field. In our work on the new type solver, the following issues were fixed: * Work has started to support https://github.com/Roblox/luau/pull/77 and https://github.com/Roblox/luau/pull/79 * Refinements are no longer applied on l-values, removing some false-positive errors * Improved overload resolution against expected result type * `Frontend::prepareModuleScope` now works in the new solver * Cofinite strings are now comparable And these are the changes in native code generation (JIT): * Fixed MIN_NUM and MAX_NUM constant fold when one of the arguments is NaN * Added constant folding for number conversions and bit operations * Value spilling and rematerialization is now supported on arm64 * Improved FASTCALL2K IR generation to support second argument constant * Added value numbering and load/store propagation optimizations * Added STORE_VECTOR on arm64, completing the IR lowering on this target --- .../include/Luau/ConstraintGraphBuilder.h | 12 +- Analysis/include/Luau/Frontend.h | 13 +- Analysis/include/Luau/Normalize.h | 3 + Analysis/include/Luau/Type.h | 35 +- Analysis/include/Luau/TypeInfer.h | 7 +- Analysis/include/Luau/TypeUtils.h | 6 + Analysis/include/Luau/VisitType.h | 8 +- Analysis/src/AstQuery.cpp | 4 +- Analysis/src/Autocomplete.cpp | 4 +- Analysis/src/BuiltinDefinitions.cpp | 7 +- Analysis/src/Clone.cpp | 39 +- Analysis/src/ConstraintGraphBuilder.cpp | 75 ++- Analysis/src/ConstraintSolver.cpp | 35 +- Analysis/src/Frontend.cpp | 140 ++--- Analysis/src/Normalize.cpp | 39 +- Analysis/src/Quantify.cpp | 2 +- Analysis/src/Substitution.cpp | 8 +- Analysis/src/ToDot.cpp | 4 +- Analysis/src/ToString.cpp | 2 +- Analysis/src/Type.cpp | 104 +++- Analysis/src/TypeAttach.cpp | 4 +- Analysis/src/TypeChecker2.cpp | 94 +-- Analysis/src/TypeInfer.cpp | 33 +- Analysis/src/TypeReduction.cpp | 10 +- Analysis/src/TypeUtils.cpp | 6 +- Analysis/src/Unifier.cpp | 24 +- CLI/Repl.cpp | 35 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 1 + CodeGen/include/Luau/AssemblyBuilderX64.h | 2 +- CodeGen/include/Luau/IrBuilder.h | 2 + CodeGen/include/Luau/IrData.h | 81 ++- CodeGen/include/Luau/IrUtils.h | 3 +- CodeGen/include/Luau/OptimizeConstProp.h | 4 +- CodeGen/include/Luau/RegisterA64.h | 34 ++ CodeGen/src/AssemblyBuilderA64.cpp | 24 +- CodeGen/src/AssemblyBuilderX64.cpp | 4 +- CodeGen/src/BitUtils.h | 20 + CodeGen/src/CodeGen.cpp | 103 ++-- CodeGen/src/CodeGenA64.cpp | 21 +- CodeGen/src/EmitBuiltinsX64.cpp | 52 +- CodeGen/src/EmitBuiltinsX64.h | 2 +- CodeGen/src/EmitCommonA64.h | 3 +- CodeGen/src/EmitInstructionX64.cpp | 11 +- CodeGen/src/IrAnalysis.cpp | 7 +- CodeGen/src/IrBuilder.cpp | 11 +- CodeGen/src/IrDump.cpp | 10 +- CodeGen/src/IrLoweringA64.cpp | 281 ++++----- CodeGen/src/IrLoweringA64.h | 3 + CodeGen/src/IrLoweringX64.cpp | 106 ++-- CodeGen/src/IrRegAllocA64.cpp | 154 +++-- CodeGen/src/IrRegAllocA64.h | 2 +- CodeGen/src/IrTranslateBuiltins.cpp | 297 
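Before the file-by-file diff, a short illustration of the first C++ API change listed above: reads of `Property::type` become `type()` and writes become `setType()`. This is a sketch, not code from the patch; `bumpAllProps` and its arguments are made up, and it assumes the default configuration where `FFlag::DebugLuauReadWriteProperties` is off (the accessors assert that).

#include "Luau/Type.h"

// Before release/574: reads were `prop.type`, writes were `prop.type = ty`.
// After: reads go through `prop.type()`, writes through `prop.setType(ty)`.
void bumpAllProps(Luau::TableType* ttv, Luau::TypeId replacementTy)
{
    for (auto& [_, prop] : ttv->props)
    {
        Luau::TypeId current = prop.type();  // old code: prop.type
        if (current != replacementTy)
            prop.setType(replacementTy);     // old code: prop.type = replacementTy
    }
}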
+++++---- CodeGen/src/IrTranslation.cpp | 20 +- CodeGen/src/IrUtils.cpp | 176 +++++- CodeGen/src/IrValueLocationTracking.cpp | 1 - CodeGen/src/NativeState.h | 7 +- CodeGen/src/OptimizeConstProp.cpp | 193 +++++- CodeGen/src/OptimizeFinalX64.cpp | 1 - Compiler/include/Luau/BytecodeBuilder.h | 2 + Compiler/src/BytecodeBuilder.cpp | 5 + Compiler/src/Compiler.cpp | 6 + Makefile | 2 +- VM/src/ldo.cpp | 39 +- VM/src/ltablib.cpp | 47 +- tests/AssemblyBuilderA64.test.cpp | 10 + tests/Conformance.test.cpp | 6 +- tests/ConstraintGraphBuilderFixture.cpp | 4 +- tests/Fixture.cpp | 9 +- tests/Frontend.test.cpp | 17 + tests/IrBuilder.test.cpp | 563 ++++++++++++++++-- tests/Module.test.cpp | 16 +- tests/NonstrictMode.test.cpp | 8 +- tests/ToString.test.cpp | 12 +- tests/TypeInfer.annotations.test.cpp | 8 +- tests/TypeInfer.anyerror.test.cpp | 2 +- tests/TypeInfer.builtins.test.cpp | 2 +- tests/TypeInfer.functions.test.cpp | 4 +- tests/TypeInfer.generics.test.cpp | 6 +- tests/TypeInfer.negations.test.cpp | 29 + tests/TypeInfer.refinements.test.cpp | 32 + tests/TypeInfer.tables.test.cpp | 22 +- tests/TypeInfer.tryUnify.test.cpp | 8 +- tools/faillist.txt | 7 +- tools/test_dcr.py | 14 +- 84 files changed, 2232 insertions(+), 1037 deletions(-) diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index cbf679c..5800d14 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -11,6 +11,7 @@ #include "Luau/Refinement.h" #include "Luau/Symbol.h" #include "Luau/Type.h" +#include "Luau/TypeUtils.h" #include "Luau/Variant.h" #include @@ -91,10 +92,14 @@ struct ConstraintGraphBuilder const NotNull ice; ScopePtr globalScope; + + std::function prepareModuleScope; + DcrLogger* logger; ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull builtinTypes, - NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, NotNull dfg); + NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, + DcrLogger* logger, NotNull dfg); /** * Fabricates a new free type belonging to a given scope. @@ -174,11 +179,12 @@ struct ConstraintGraphBuilder * surrounding context. Used to implement bidirectional type checking. * @return the type of the expression. */ - Inference check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType = {}, bool forceSingleton = false); + Inference check(const ScopePtr& scope, AstExpr* expr, ValueContext context = ValueContext::RValue, std::optional expectedType = {}, + bool forceSingleton = false); Inference check(const ScopePtr& scope, AstExprConstantString* string, std::optional expectedType, bool forceSingleton); Inference check(const ScopePtr& scope, AstExprConstantBool* bool_, std::optional expectedType, bool forceSingleton); - Inference check(const ScopePtr& scope, AstExprLocal* local); + Inference check(const ScopePtr& scope, AstExprLocal* local, ValueContext context); Inference check(const ScopePtr& scope, AstExprGlobal* global); Inference check(const ScopePtr& scope, AstExprIndexName* indexName); Inference check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 856c5da..67e840e 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -53,9 +53,7 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleNa * error when we try during typechecking. 
*/ std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr); -// TODO: Deprecate this code path when we move away from the old solver -LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition, - const std::string& packageName, bool captureComments); + struct SourceNode { bool hasDirtySourceModule() const @@ -209,10 +207,6 @@ public: GlobalTypes globals; GlobalTypes globalsForAutocomplete; - // TODO: remove with FFlagLuauOnDemandTypecheckers - TypeChecker typeChecker_DEPRECATED; - TypeChecker typeCheckerForAutocomplete_DEPRECATED; - ConfigResolver* configResolver; FrontendOptions options; InternalErrorReporter iceHandler; @@ -227,10 +221,11 @@ public: ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options); + const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options); ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& globalScope, FrontendOptions options, bool recordJsonLog); + const ScopePtr& globalScope, std::function prepareModuleScope, FrontendOptions options, + bool recordJsonLog); } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index efcb510..6c80828 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -348,10 +348,13 @@ public: bool intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there); bool intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); bool intersectNormalWithTy(NormalizedType& here, TypeId there); + bool normalizeIntersections(const std::vector& intersections, NormalizedType& outType); // Check for inhabitance bool isInhabited(TypeId ty, std::unordered_set seen = {}); bool isInhabited(const NormalizedType* norm, std::unordered_set seen = {}); + // Check for intersections being inhabited + bool isIntersectionInhabited(TypeId left, TypeId right); // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 24fb7db..5d92cbd 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -301,7 +301,7 @@ struct MagicFunctionCallContext TypePackId result; }; -using DcrMagicFunction = bool (*)(MagicFunctionCallContext); +using DcrMagicFunction = std::function; struct MagicRefinementContext { @@ -379,12 +379,39 @@ struct TableIndexer struct Property { - TypeId type; + static Property readonly(TypeId ty); + static Property writeonly(TypeId ty); + static Property rw(TypeId ty); // Shared read-write type. + static Property rw(TypeId read, TypeId write); // Separate read-write type. + static std::optional create(std::optional read, std::optional write); + bool deprecated = false; std::string deprecatedSuggestion; std::optional location = std::nullopt; Tags tags; std::optional documentationSymbol; + + // DEPRECATED + // TODO: Kill all constructors in favor of `Property::rw(TypeId read, TypeId write)` and friends. 
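A minimal sketch of how the new read/write helpers declared above are meant to be used. It is illustrative only: it assumes `FFlag::DebugLuauReadWriteProperties` is enabled (the helpers assert it) and that the `TypeId` arguments come from an existing arena; `describeExampleProps` is a made-up name.

#include "Luau/Type.h"
#include <optional>

Luau::Property describeExampleProps(Luau::TypeId numberTy, Luau::TypeId stringTy)
{
    // Read-only: only readType() yields a value; writeType() stays nullopt.
    Luau::Property ro = Luau::Property::readonly(numberTy);
    LUAU_ASSERT(ro.readType() == numberTy && !ro.writeType());

    // create() picks readonly/writeonly/rw based on which sides are present,
    // and returns nullopt when both are missing.
    std::optional<Luau::Property> maybe = Luau::Property::create(std::nullopt, stringTy);
    LUAU_ASSERT(maybe && maybe->writeType() == stringTy);

    // Shared or split read/write types.
    return Luau::Property::rw(numberTy, stringTy);
}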
+ Property(); + Property(TypeId readTy, bool deprecated = false, const std::string& deprecatedSuggestion = "", std::optional location = std::nullopt, + const Tags& tags = {}, const std::optional& documentationSymbol = std::nullopt); + + // DEPRECATED: Should only be called in non-RWP! We assert that the `readTy` is not nullopt. + // TODO: Kill once we don't have non-RWP. + TypeId type() const; + void setType(TypeId ty); + + // Should only be called in RWP! + // We do not assert that `readTy` nor `writeTy` are nullopt or not. + // The invariant is that at least one of them mustn't be nullopt, which we do assert here. + // TODO: Kill this in favor of exposing `readTy`/`writeTy` directly? If we do, we'll lose the asserts which will be useful while debugging. + std::optional readType() const; + std::optional writeType() const; + +private: + std::optional readTy; + std::optional writeTy; }; struct TableType @@ -552,7 +579,7 @@ struct IntersectionType struct LazyType { LazyType() = default; - LazyType(std::function thunk_DEPRECATED, std::function unwrap) + LazyType(std::function thunk_DEPRECATED, std::function unwrap) : thunk_DEPRECATED(thunk_DEPRECATED) , unwrap(unwrap) { @@ -593,7 +620,7 @@ struct LazyType std::function thunk_DEPRECATED; - std::function unwrap; + std::function unwrap; std::atomic unwrapped = nullptr; }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index b5db3f5..cceff0d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -11,6 +11,7 @@ #include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/UnifierSharedState.h" @@ -58,12 +59,6 @@ public: } }; -enum class ValueContext -{ - LValue, - RValue -}; - struct GlobalTypes { GlobalTypes(NotNull builtinTypes); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 42ba405..86f20f3 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -15,6 +15,12 @@ namespace Luau struct TxnLog; struct TypeArena; +enum class ValueContext +{ + LValue, + RValue +}; + using ScopePtr = std::shared_ptr; std::optional findMetatableEntry( diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index c7dcdcc..663627d 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -9,7 +9,7 @@ #include "Luau/Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) -LUAU_FASTFLAG(LuauBoundLazyTypes) +LUAU_FASTFLAG(LuauBoundLazyTypes2) namespace Luau { @@ -242,7 +242,7 @@ struct GenericTypeVisitor else { for (auto& [_name, prop] : ttv->props) - traverse(prop.type); + traverse(prop.type()); if (ttv->indexer) { @@ -265,7 +265,7 @@ struct GenericTypeVisitor if (visit(ty, *ctv)) { for (const auto& [name, prop] : ctv->props) - traverse(prop.type); + traverse(prop.type()); if (ctv->parent) traverse(*ctv->parent); @@ -294,7 +294,7 @@ struct GenericTypeVisitor } else if (auto ltv = get(ty)) { - if (FFlag::LuauBoundLazyTypes) + if (FFlag::LuauBoundLazyTypes2) { if (TypeId unwrapped = ltv->unwrapped) traverse(unwrapped); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index cb3efe6..38f3bdf 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -501,12 +501,12 @@ std::optional getDocumentationSymbolAtPosition(const Source if (const TableType* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) - return 
checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); } else if (const ClassType* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); } } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 42fc9a7..4b66568 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -260,7 +260,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul // already populated, it takes precedence over the property we found just now. if (result.count(name) == 0 && name != kParseNameError) { - Luau::TypeId type = Luau::follow(prop.type); + Luau::TypeId type = Luau::follow(prop.type()); TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); @@ -287,7 +287,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - TypeId followed = follow(indexIt->second.type); + TypeId followed = follow(indexIt->second.type()); if (get(followed) || get(followed)) { autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 8988b33..c55a88e 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -52,6 +52,7 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types) TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t) { + LUAU_ASSERT(t); return makeUnion(arena, {builtinTypes->nilType, t}); } @@ -236,7 +237,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC auto it = stringMetatableTable->props.find("__index"); LUAU_ASSERT(it != stringMetatableTable->props.end()); - addGlobalBinding(globals, "string", it->second.type, "@luau"); + addGlobalBinding(globals, "string", it->second.type(), "@luau"); // next(t: Table, i: K?) 
-> (K?, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); @@ -301,8 +302,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC ttv->props["foreach"].deprecated = true; ttv->props["foreachi"].deprecated = true; - attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); - attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); + attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); + attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); } attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index f510265..450b84a 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -8,6 +8,7 @@ LUAU_FASTFLAG(DebugLuauCopyBeforeNormalizing) LUAU_FASTFLAG(LuauClonePublicInterfaceLess2) +LUAU_FASTFLAG(DebugLuauReadWriteProperties) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) @@ -17,6 +18,40 @@ namespace Luau namespace { +Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState) +{ + if (FFlag::DebugLuauReadWriteProperties) + { + std::optional cloneReadTy; + if (auto ty = prop.readType()) + cloneReadTy = clone(*ty, dest, cloneState); + + std::optional cloneWriteTy; + if (auto ty = prop.writeType()) + cloneWriteTy = clone(*ty, dest, cloneState); + + std::optional cloned = Property::create(cloneReadTy, cloneWriteTy); + LUAU_ASSERT(cloned); + cloned->deprecated = prop.deprecated; + cloned->deprecatedSuggestion = prop.deprecatedSuggestion; + cloned->location = prop.location; + cloned->tags = prop.tags; + cloned->documentationSymbol = prop.documentationSymbol; + return *cloned; + } + else + { + return Property{ + clone(prop.type(), dest, cloneState), + prop.deprecated, + prop.deprecatedSuggestion, + prop.location, + prop.tags, + prop.documentationSymbol, + }; + } +} + struct TypePackCloner; /* @@ -251,7 +286,7 @@ void TypeCloner::operator()(const TableType& t) ttv->boundTo = clone(*t.boundTo, dest, cloneState); for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ttv->props[name] = clone(prop, dest, cloneState); if (t.indexer) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, cloneState), clone(t.indexer->indexResultType, dest, cloneState)}; @@ -285,7 +320,7 @@ void TypeCloner::operator()(const ClassType& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, cloneState), prop.deprecated, {}, prop.location, prop.tags}; + ctv->props[name] = clone(prop, dest, cloneState); if (t.parent) ctv->parent = clone(*t.parent, dest, cloneState); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index ad7cff9..611f420 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -134,8 +134,8 @@ void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const Con } // namespace ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* arena, NotNull moduleResolver, - NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, - NotNull dfg) + NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, + std::function prepareModuleScope, DcrLogger* logger, NotNull dfg) : module(module) , builtinTypes(builtinTypes) , arena(arena) @@ -144,6 
+144,7 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(ModulePtr module, TypeArena* aren , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) + , prepareModuleScope(std::move(prepareModuleScope)) , logger(logger) { LUAU_ASSERT(module); @@ -510,7 +511,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* l if (hasAnnotation) expectedType = varTypes.at(i); - TypeId exprType = check(scope, value, expectedType).ty; + TypeId exprType = check(scope, value, ValueContext::RValue, expectedType).ty; if (i < varTypes.size()) { if (varTypes[i]) @@ -898,7 +899,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompound ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { - RefinementId refinement = check(scope, ifStatement->condition, std::nullopt).refinement; + RefinementId refinement = check(scope, ifStatement->condition, ValueContext::RValue, std::nullopt).refinement; ScopePtr thenScope = childScope(ifStatement->thenbody, scope); applyRefinements(thenScope, ifStatement->condition->location, refinement); @@ -1081,7 +1082,7 @@ ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareC } else { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type; + TypeId currentTy = assignToMetatable ? metatable->props[propName].type() : ctv->props[propName].type(); // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -1182,7 +1183,7 @@ InferencePack ConstraintGraphBuilder::checkPack( std::optional expectedType; if (i < expectedTypes.size()) expectedType = expectedTypes[i]; - head.push_back(check(scope, expr, expectedType).ty); + head.push_back(check(scope, expr, ValueContext::RValue, expectedType).ty); } else { @@ -1225,7 +1226,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* std::optional expectedType; if (!expectedTypes.empty()) expectedType = expectedTypes[0]; - TypeId t = check(scope, expr, expectedType).ty; + TypeId t = check(scope, expr, ValueContext::RValue, expectedType).ty; result = InferencePack{arena->addTypePack({t})}; } @@ -1332,7 +1333,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } else if (i < exprArgs.size() - 1 || !(arg->is() || arg->is())) { - auto [ty, refinement] = check(scope, arg, expectedType); + auto [ty, refinement] = check(scope, arg, ValueContext::RValue, expectedType); args.push_back(ty); argumentRefinements.push_back(refinement); } @@ -1434,7 +1435,8 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton) +Inference ConstraintGraphBuilder::check( + const ScopePtr& scope, AstExpr* expr, ValueContext context, std::optional expectedType, bool forceSingleton) { RecursionCounter counter{&recursionCount}; @@ -1447,7 +1449,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr, st Inference result; if (auto group = expr->as()) - result = check(scope, group->expr, expectedType, forceSingleton); + result = check(scope, group->expr, ValueContext::RValue, expectedType, forceSingleton); else if (auto stringExpr = expr->as()) result = check(scope, stringExpr, expectedType, forceSingleton); else if (expr->is()) @@ -1457,7 +1459,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& 
scope, AstExpr* expr, st else if (expr->is()) result = Inference{builtinTypes->nilType}; else if (auto local = expr->as()) - result = check(scope, local); + result = check(scope, local, context); else if (auto global = expr->as()) result = check(scope, global); else if (expr->is()) @@ -1566,11 +1568,11 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprConstantBo return Inference{builtinTypes->booleanType}; } -Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local) +Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprLocal* local, ValueContext context) { BreadcrumbId bc = dfg->getBreadcrumb(local); - if (auto ty = scope->lookup(bc->def)) + if (auto ty = scope->lookup(bc->def); ty && context == ValueContext::RValue) return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; else if (auto ty = scope->lookup(local->local)) return Inference{*ty, refinementArena.proposition(bc, builtinTypes->truthyType)}; @@ -1676,18 +1678,18 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* if ScopePtr thenScope = childScope(ifElse->trueExpr, scope); applyRefinements(thenScope, ifElse->trueExpr->location, refinement); - TypeId thenType = check(thenScope, ifElse->trueExpr, expectedType).ty; + TypeId thenType = check(thenScope, ifElse->trueExpr, ValueContext::RValue, expectedType).ty; ScopePtr elseScope = childScope(ifElse->falseExpr, scope); applyRefinements(elseScope, ifElse->falseExpr->location, refinementArena.negation(refinement)); - TypeId elseType = check(elseScope, ifElse->falseExpr, expectedType).ty; + TypeId elseType = check(elseScope, ifElse->falseExpr, ValueContext::RValue, expectedType).ty; return Inference{expectedType ? *expectedType : arena->addType(UnionType{{thenType, elseType}})}; } Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) { - check(scope, typeAssert->expr, std::nullopt); + check(scope, typeAssert->expr, ValueContext::RValue, std::nullopt); return Inference{resolveType(scope, typeAssert->annotation, /* inTypeArguments */ false)}; } @@ -1704,21 +1706,31 @@ std::tuple ConstraintGraphBuilder::checkBinary( { if (binary->op == AstExprBinary::And) { - auto [leftType, leftRefinement] = check(scope, binary->left, expectedType); + std::optional relaxedExpectedLhs; + + if (expectedType) + relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); + + auto [leftType, leftRefinement] = check(scope, binary->left, ValueContext::RValue, relaxedExpectedLhs); ScopePtr rightScope = childScope(binary->right, scope); applyRefinements(rightScope, binary->right->location, leftRefinement); - auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + auto [rightType, rightRefinement] = check(rightScope, binary->right, ValueContext::RValue, expectedType); return {leftType, rightType, refinementArena.conjunction(leftRefinement, rightRefinement)}; } else if (binary->op == AstExprBinary::Or) { - auto [leftType, leftRefinement] = check(scope, binary->left, expectedType); + std::optional relaxedExpectedLhs; + + if (expectedType) + relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); + + auto [leftType, leftRefinement] = check(scope, binary->left, ValueContext::RValue, relaxedExpectedLhs); ScopePtr rightScope = childScope(binary->right, scope); applyRefinements(rightScope, binary->right->location, refinementArena.negation(leftRefinement)); - auto 
[rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + auto [rightType, rightRefinement] = check(rightScope, binary->right, ValueContext::RValue, expectedType); return {leftType, rightType, refinementArena.disjunction(leftRefinement, rightRefinement)}; } @@ -1774,8 +1786,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( } else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) { - TypeId leftType = check(scope, binary->left, expectedType, true).ty; - TypeId rightType = check(scope, binary->right, expectedType, true).ty; + TypeId leftType = check(scope, binary->left, ValueContext::RValue, expectedType, true).ty; + TypeId rightType = check(scope, binary->right, ValueContext::RValue, expectedType, true).ty; RefinementId leftRefinement = nullptr; if (auto bc = dfg->getBreadcrumb(binary->left)) @@ -1795,8 +1807,8 @@ std::tuple ConstraintGraphBuilder::checkBinary( } else { - TypeId leftType = check(scope, binary->left, expectedType).ty; - TypeId rightType = check(scope, binary->right, expectedType).ty; + TypeId leftType = check(scope, binary->left, ValueContext::RValue).ty; + TypeId rightType = check(scope, binary->right, ValueContext::RValue).ty; return {leftType, rightType, nullptr}; } } @@ -1859,7 +1871,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) return propType; } else if (!isIndexNameEquivalent(expr)) - return check(scope, expr).ty; + return check(scope, expr, ValueContext::LValue).ty; Symbol sym; std::vector segments; @@ -1894,11 +1906,11 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) } else { - return check(scope, expr).ty; + return check(scope, expr, ValueContext::LValue).ty; } } else - return check(scope, expr).ty; + return check(scope, expr, ValueContext::LValue).ty; } LUAU_ASSERT(!segments.empty()); @@ -1908,7 +1920,7 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) auto lookupResult = scope->lookupEx(sym); if (!lookupResult) - return check(scope, expr).ty; + return check(scope, expr, ValueContext::LValue).ty; const auto [subjectBinding, symbolScope] = std::move(*lookupResult); TypeId subjectType = subjectBinding->typeId; @@ -2029,7 +2041,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp checkExpectedIndexResultType = pinnedIndexResultType; } - TypeId itemTy = check(scope, item.value, checkExpectedIndexResultType).ty; + TypeId itemTy = check(scope, item.value, ValueContext::RValue, checkExpectedIndexResultType).ty; if (isIndexedResultType && !pinnedIndexResultType) pinnedIndexResultType = itemTy; @@ -2039,7 +2051,7 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp // Even though we don't need to use the type of the item's key if // it's a string constant, we still want to check it to populate // astTypes. 
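The `relaxedExpectedLhs` computation above deserves a note: when `lhs and rhs` (or `lhs or rhs`) is checked against an expected type `T`, the left operand may legitimately be any falsy value, so it is checked against `falsy | T` rather than `T` itself. A sketch of just that construction; the helper name and passing `falsyType` directly (it stands in for `builtinTypes->falsyType`) are illustrative simplifications.

#include "Luau/Type.h"
#include "Luau/TypeArena.h"
#include <optional>

std::optional<Luau::TypeId> relaxExpectedLhs(
    Luau::TypeArena& arena, Luau::TypeId falsyType, std::optional<Luau::TypeId> expectedType)
{
    if (!expectedType)
        return std::nullopt; // no expectation to relax

    // falsy | expectedType, mirroring what checkBinary does for And/Or.
    return arena.addType(Luau::UnionType{{falsyType, *expectedType}});
}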
- TypeId keyTy = check(scope, item.key, annotatedKeyType).ty; + TypeId keyTy = check(scope, item.key, ValueContext::RValue, annotatedKeyType).ty; if (AstExprConstantString* key = item.key->as()) { @@ -2646,6 +2658,9 @@ void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, { GlobalPrepopulator gp{NotNull{globalScope.get()}, arena}; + if (prepareModuleScope) + prepareModuleScope(module->name, globalScope); + program->visit(&gp); } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 558ad2d..ec63b25 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -472,16 +472,20 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { + TypeId generalizedType = follow(c.generalizedType); + if (isBlocked(c.sourceType)) return block(c.sourceType, constraint); + else if (get(generalizedType)) + return block(generalizedType, constraint); std::optional generalized = quantify(arena, c.sourceType, constraint->scope); if (generalized) { - if (isBlocked(c.generalizedType)) - asMutable(c.generalizedType)->ty.emplace(*generalized); + if (get(generalizedType)) + asMutable(generalizedType)->ty.emplace(*generalized); else - unify(c.generalizedType, *generalized, constraint->scope); + unify(generalizedType, *generalized, constraint->scope); } else { @@ -505,10 +509,8 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - if (isBlocked(c.subType)) - asMutable(c.subType)->ty.emplace(*instantiated); - else - unify(c.subType, *instantiated, constraint->scope); + LUAU_ASSERT(get(c.subType)); + asMutable(c.subType)->ty.emplace(*instantiated); unblock(c.subType); @@ -586,6 +588,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull(resultType)); + bool isLogical = c.op == AstExprBinary::Op::And || c.op == AstExprBinary::Op::Or; /* Compound assignments create constraints of the form @@ -979,6 +983,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul } auto bindResult = [this, &c](TypeId result) { + LUAU_ASSERT(get(c.target)); asMutable(c.target)->ty.emplace(result); unblock(c.target); }; @@ -1280,6 +1285,8 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull(expectedType)) return block(expectedType, constraint); + LUAU_ASSERT(get(c.resultType)); + TypeId bindTo = maybeSingleton(expectedType) ? c.singletonType : c.multitonType; asMutable(c.resultType)->ty.emplace(bindTo); unblock(c.resultType); @@ -1291,6 +1298,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNull(c.resultType)); + if (isBlocked(subjectType) || get(subjectType)) return block(subjectType, constraint); @@ -1351,7 +1360,7 @@ static void updateTheTableType( if (it == tbl->props.end()) return; - t = follow(it->second.type); + t = follow(it->second.type()); } // The last path segment should not be a property of the table at all. 
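The `prepareModuleScope` call added to `prepopulateGlobalScope` above is the hook behind the "`Frontend::prepareModuleScope` now works in the new solver" item. Below is a sketch of the kind of callback a client might supply; the exact `std::function` template arguments were lost in the flattened hunk, so the two-argument form matching `prepareModuleScope(module->name, globalScope)` is assumed, `ModuleName` is spelled as its `std::string` alias, and the `"game"` binding is a made-up example.

#include "Luau/Scope.h"
#include "Luau/Ast.h"
#include "Luau/Type.h"
#include <functional>
#include <string>

using PrepareModuleScope = std::function<void(const std::string&, const Luau::ScopePtr&)>;

PrepareModuleScope makePrepareModuleScope(Luau::TypeId anyTy)
{
    return [anyTy](const std::string& moduleName, const Luau::ScopePtr& scope) {
        // Inject an environment-specific ambient global into every module scope.
        // (A real client would intern the name through an AstNameTable.)
        scope->bindings[Luau::AstName{"game"}].typeId = anyTy;
    };
}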
@@ -1388,7 +1397,7 @@ static void updateTheTableType( if (!tt) return; - tt->props[lastSegment].type = replaceTy; + tt->props[lastSegment].setType(replaceTy); } bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint, bool force) @@ -1853,7 +1862,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto ttv = getMutable(subjectType)) { if (auto prop = ttv->props.find(propName); prop != ttv->props.end()) - return {{}, prop->second.type}; + return {{}, prop->second.type()}; else if (ttv->indexer && maybeString(ttv->indexer->indexType)) return {{}, ttv->indexer->indexResultType}; else if (ttv->state == TableState::Free) @@ -1881,7 +1890,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa // TODO: __index can be an overloaded function. - TypeId indexType = follow(indexProp->second.type); + TypeId indexType = follow(indexProp->second.type()); if (auto ft = get(indexType)) { @@ -1902,7 +1911,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa else if (auto ct = get(subjectType)) { if (auto p = lookupClassProp(ct, propName)) - return {{}, p->type}; + return {{}, p->type()}; } else if (auto pt = get(subjectType); pt && pt->metatable) { @@ -1913,7 +1922,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa if (indexProp == metatable->props.end()) return {{}, std::nullopt}; - return lookupTableProp(indexProp->second.type, propName, seen); + return lookupTableProp(indexProp->second.type(), propName, seen); } else if (auto ft = get(subjectType)) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 916dd1d..486ef69 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -32,8 +32,8 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) -LUAU_FASTFLAGVARIABLE(LuauOnDemandTypecheckers, false) LUAU_FASTFLAG(LuauRequirePathTrueModuleName) +LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) namespace Luau { @@ -133,10 +133,6 @@ static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, S LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName, bool captureComments, bool typeCheckForAutocomplete) { - if (!FFlag::DebugLuauDeferredConstraintResolution && !FFlag::LuauOnDemandTypecheckers) - return Luau::loadDefinitionFileNoDCR(typeCheckForAutocomplete ? typeCheckerForAutocomplete_DEPRECATED : typeChecker_DEPRECATED, - typeCheckForAutocomplete ? 
globalsForAutocomplete : globals, targetScope, source, packageName, captureComments); - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); Luau::SourceModule sourceModule; @@ -154,28 +150,6 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(GlobalTypes& globals, Scop return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } -LoadDefinitionFileResult loadDefinitionFileNoDCR(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, - const std::string& packageName, bool captureComments) -{ - LUAU_ASSERT(!FFlag::LuauOnDemandTypecheckers); - LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - - Luau::SourceModule sourceModule; - Luau::ParseResult parseResult = parseSourceForModule(source, sourceModule, captureComments); - - if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - - ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); - - if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; - - persistCheckedTypes(checkedModule, globals, targetScope, packageName); - - return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; -} - std::vector parsePathExpr(const AstExpr& pathExpr) { const AstExprIndexName* indexName = pathExpr.as(); @@ -409,8 +383,6 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , moduleResolverForAutocomplete(this) , globals(builtinTypes) , globalsForAutocomplete(builtinTypes) - , typeChecker_DEPRECATED(globals.globalScope, &moduleResolver, builtinTypes, &iceHandler) - , typeCheckerForAutocomplete_DEPRECATED(globalsForAutocomplete.globalScope, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) , configResolver(configResolver) , options(options) { @@ -479,68 +451,32 @@ CheckResult Frontend::check(const ModuleName& name, std::optional 0) - typeCheckerForAutocomplete_DEPRECATED.instantiationChildLimit = - std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete_DEPRECATED.instantiationChildLimit = std::nullopt; - - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckerForAutocomplete_DEPRECATED.unifierIterationLimit = - std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckerForAutocomplete_DEPRECATED.unifierIterationLimit = std::nullopt; - - moduleForAutocomplete = - FFlag::DebugLuauDeferredConstraintResolution - ? 
check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, /*recordJsonLog*/ false, {}) - : typeCheckerForAutocomplete_DEPRECATED.check(sourceModule, Mode::Strict, environmentScope); - } + if (autocompleteTimeLimit != 0.0) + typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; else - { - // The autocomplete typecheck is always in strict mode with DM awareness - // to provide better type information for IDE features - TypeCheckLimits typeCheckLimits; + typeCheckLimits.finishTime = std::nullopt; - if (autocompleteTimeLimit != 0.0) - typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; - else - typeCheckLimits.finishTime = std::nullopt; + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; - // TODO: This is a dirty ad hoc solution for autocomplete timeouts - // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit - // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle - if (FInt::LuauTarjanChildLimit > 0) - typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.instantiationChildLimit = std::nullopt; + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.unifierIterationLimit = std::nullopt; - - moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, - /*recordJsonLog*/ false, typeCheckLimits); - } + ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, + /*recordJsonLog*/ false, typeCheckLimits); resolver.setModule(moduleName, moduleForAutocomplete); @@ -565,21 +501,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalget(global.c_str()); if (name.value) - result->bindings[name].typeId = FFlag::LuauOnDemandTypecheckers ? 
builtinTypes->anyType : typeChecker_DEPRECATED.anyType; + result->bindings[name].typeId = builtinTypes->anyType; } } @@ -856,15 +778,17 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& parentScope, FrontendOptions options) + const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options) { const bool recordJsonLog = FFlag::DebugLuauLogSolverToJson; - return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope, options, recordJsonLog); + return check(sourceModule, requireCycles, builtinTypes, iceHandler, moduleResolver, fileResolver, parentScope, std::move(prepareModuleScope), + options, recordJsonLog); } ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, - const ScopePtr& parentScope, FrontendOptions options, bool recordJsonLog) + const ScopePtr& parentScope, std::function prepareModuleScope, FrontendOptions options, + bool recordJsonLog) { ModulePtr result = std::make_shared(); result->name = sourceModule.name; @@ -897,6 +821,7 @@ ModulePtr check(const SourceModule& sourceModule, const std::vector seen) { for (const auto& [_, prop] : ttv->props) { - if (!isInhabited(prop.type, seen)) + if (!isInhabited(prop.type(), seen)) return false; } return true; @@ -316,6 +316,20 @@ bool Normalizer::isInhabited(TypeId ty, std::unordered_set seen) return isInhabited(norm, seen); } +bool Normalizer::isIntersectionInhabited(TypeId left, TypeId right) +{ + left = follow(left); + right = follow(right); + std::unordered_set seen = {}; + seen.insert(left); + seen.insert(right); + + NormalizedType norm{builtinTypes}; + if (!normalizeIntersections({left, right}, norm)) + return false; + return isInhabited(&norm, seen); +} + static int tyvarIndex(TypeId ty) { if (const GenericType* gtv = get(ty)) @@ -593,6 +607,23 @@ const NormalizedType* Normalizer::normalize(TypeId ty) return result; } +bool Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType) +{ + if (!arena) + sharedState->iceHandler->ice("Normalizing types outside a module"); + NormalizedType norm{builtinTypes}; + norm.tops = builtinTypes->anyType; + // Now we need to intersect the two types + for (auto ty : intersections) + if (!intersectNormalWithTy(norm, ty)) + return false; + + if (!unionNormals(outType, norm)) + return false; + + return true; +} + void Normalizer::clearNormal(NormalizedType& norm) { norm.tops = builtinTypes->neverType; @@ -2134,9 +2165,9 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there { const auto& [_name, tprop] = *tfound; // TODO: variance issues here, which can't be fixed until we have read/write property types - prop.type = intersectionType(hprop.type, tprop.type); - hereSubThere &= (prop.type == hprop.type); - thereSubHere &= (prop.type == tprop.type); + prop.setType(intersectionType(hprop.type(), tprop.type())); + hereSubThere &= (prop.type() == hprop.type()); + thereSubHere &= (prop.type() == tprop.type()); } // TODO: string indexers result.props[name] = prop; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 0b8f462..0a7975f 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -116,7 +116,7 @@ void 
quantify(TypeId ty, TypeLevel level) for (const auto& [_, prop] : ttv->props) { - auto ftv = getMutable(follow(prop.type)); + auto ftv = getMutable(follow(prop.type())); if (!ftv || !ftv->hasSelf) continue; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 9621721..6a600b6 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -219,7 +219,7 @@ void Tarjan::visitChildren(TypeId ty, int index) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) - visitChild(prop.type); + visitChild(prop.type()); if (ttv->indexer) { visitChild(ttv->indexer->indexType); @@ -258,7 +258,7 @@ void Tarjan::visitChildren(TypeId ty, int index) else if (const ClassType* ctv = get(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { for (const auto& [name, prop] : ctv->props) - visitChild(prop.type); + visitChild(prop.type()); if (ctv->parent) visitChild(*ctv->parent); @@ -750,7 +750,7 @@ void Substitution::replaceChildren(TypeId ty) { LUAU_ASSERT(!ttv->boundTo); for (auto& [name, prop] : ttv->props) - prop.type = replace(prop.type); + prop.setType(replace(prop.type())); if (ttv->indexer) { ttv->indexer->indexType = replace(ttv->indexer->indexType); @@ -789,7 +789,7 @@ void Substitution::replaceChildren(TypeId ty) else if (ClassType* ctv = getMutable(ty); FFlag::LuauClassTypeVarsInSubstitution && ctv) { for (auto& [name, prop] : ctv->props) - prop.type = replace(prop.type); + prop.setType(replace(prop.type())); if (ctv->parent) ctv->parent = replace(*ctv->parent); diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 117d39d..8d889cb 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -171,7 +171,7 @@ void StateDot::visitChildren(TypeId ty, int index) return visitChild(*ttv->boundTo, index, "boundTo"); for (const auto& [name, prop] : ttv->props) - visitChild(prop.type, index, name.c_str()); + visitChild(prop.type(), index, name.c_str()); if (ttv->indexer) { visitChild(ttv->indexer->indexType, index, "[index]"); @@ -250,7 +250,7 @@ void StateDot::visitChildren(TypeId ty, int index) finishNode(); for (const auto& [name, prop] : ctv->props) - visitChild(prop.type, index, name.c_str()); + visitChild(prop.type(), index, name.c_str()); if (ctv->parent) visitChild(*ctv->parent, index, "[parent]"); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 46d2e8f..ea3ab57 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -660,7 +660,7 @@ struct TypeStringifier state.emit("\"]"); } state.emit(": "); - stringify(prop.type); + stringify(prop.type()); comma = true; ++index; } diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index e4d9ab3..2ca39b4 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -26,7 +26,8 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) -LUAU_FASTFLAGVARIABLE(LuauBoundLazyTypes, false) +LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAGVARIABLE(LuauBoundLazyTypes2, false) namespace Luau { @@ -57,7 +58,7 @@ TypeId follow(TypeId t) TypeId follow(TypeId t, std::function mapper) { auto advance = [&mapper](TypeId ty) -> std::optional { - if (FFlag::LuauBoundLazyTypes) + if (FFlag::LuauBoundLazyTypes2) { TypeId mapped = mapper(ty); @@ -74,7 +75,8 @@ TypeId follow(TypeId t, std::function mapper) if (unwrapped) return unwrapped; - unwrapped = ltv->unwrap(*ltv); + ltv->unwrap(*ltv); + unwrapped = 
ltv->unwrapped.load(); if (!unwrapped) throw InternalCompilerError("Lazy Type didn't fill in unwrapped type field"); @@ -109,7 +111,7 @@ TypeId follow(TypeId t, std::function mapper) } }; - if (!FFlag::LuauBoundLazyTypes) + if (!FFlag::LuauBoundLazyTypes2) force(t); TypeId cycleTester = t; // Null once we've determined that there is no cycle @@ -120,7 +122,7 @@ TypeId follow(TypeId t, std::function mapper) while (true) { - if (!FFlag::LuauBoundLazyTypes) + if (!FFlag::LuauBoundLazyTypes2) force(t); auto a1 = advance(t); @@ -622,6 +624,92 @@ FunctionType::FunctionType(TypeLevel level, Scope* scope, std::vector ge { } +Property::Property() {} + +Property::Property(TypeId readTy, bool deprecated, const std::string& deprecatedSuggestion, std::optional location, const Tags& tags, + const std::optional& documentationSymbol) + : deprecated(deprecated) + , deprecatedSuggestion(deprecatedSuggestion) + , location(location) + , tags(tags) + , documentationSymbol(documentationSymbol) + , readTy(readTy) + , writeTy(readTy) +{ + LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties); +} + +Property Property::readonly(TypeId ty) +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + + Property p; + p.readTy = ty; + return p; +} + +Property Property::writeonly(TypeId ty) +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + + Property p; + p.writeTy = ty; + return p; +} + +Property Property::rw(TypeId ty) +{ + return Property::rw(ty, ty); +} + +Property Property::rw(TypeId read, TypeId write) +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + + Property p; + p.readTy = read; + p.writeTy = write; + return p; +} + +std::optional Property::create(std::optional read, std::optional write) +{ + if (read && !write) + return Property::readonly(*read); + else if (!read && write) + return Property::writeonly(*write); + else if (read && write) + return Property::rw(*read, *write); + else + return std::nullopt; +} + +TypeId Property::type() const +{ + LUAU_ASSERT(!FFlag::DebugLuauReadWriteProperties); + LUAU_ASSERT(readTy); + return *readTy; +} + +void Property::setType(TypeId ty) +{ + readTy = ty; +} + +std::optional Property::readType() const +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + LUAU_ASSERT(!(bool(readTy) && bool(writeTy))); + return readTy; +} + +std::optional Property::writeType() const +{ + LUAU_ASSERT(FFlag::DebugLuauReadWriteProperties); + LUAU_ASSERT(!(bool(readTy) && bool(writeTy))); + return writeTy; +} + TableType::TableType(TableState state, TypeLevel level, Scope* scope) : state(state) , level(level) @@ -709,7 +797,7 @@ bool areEqual(SeenSet& seen, const TableType& lhs, const TableType& rhs) if (l->first != r->first) return false; - if (!areEqual(seen, *l->second.type, *r->second.type)) + if (!areEqual(seen, *l->second.type(), *r->second.type())) return false; ++l; ++r; @@ -1011,7 +1099,7 @@ void persist(TypeId ty) LUAU_ASSERT(ttv->state != TableState::Free && ttv->state != TableState::Unsealed); for (const auto& [_name, prop] : ttv->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); if (ttv->indexer) { @@ -1022,7 +1110,7 @@ void persist(TypeId ty) else if (auto ctv = get(t)) { for (const auto& [_name, prop] : ctv->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); } else if (auto utv = get(t)) { diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 7ed4eb4..86f7816 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -180,7 +180,7 @@ public: char* name = allocateString(*allocator, propName); 
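For the `LazyType` change summarized at the top: `unwrap` no longer returns the resolved type; `follow` calls `ltv->unwrap(*ltv)` and then reads `ltv->unwrapped`, throwing an `InternalCompilerError` if the callback left it null. A sketch of a conforming callback; the `std::function<void(LazyType&)>` signature is inferred from that call site, since the template arguments are not visible in the flattened text.

#include "Luau/Type.h"
#include <functional>

// A callback compatible with the new LazyType::unwrap contract: resolve the type
// however you like, then publish it by writing the `unwrapped` field; return nothing.
std::function<void(Luau::LazyType&)> makeUnwrapCallback(Luau::TypeId resolvedTy)
{
    return [resolvedTy](Luau::LazyType& ltv) {
        ltv.unwrapped = resolvedTy; // leaving this unset makes follow() throw
    };
}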
props.data[idx].name = AstName(name); - props.data[idx].type = Luau::visit(*this, prop.type->ty); + props.data[idx].type = Luau::visit(*this, prop.type()->ty); props.data[idx].location = Location(); idx++; } @@ -221,7 +221,7 @@ public: char* name = allocateString(*allocator, propName); props.data[idx].name = AstName{name}; - props.data[idx].type = Luau::visit(*this, prop.type->ty); + props.data[idx].type = Luau::visit(*this, prop.type()->ty); props.data[idx].location = Location(); idx++; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 893f51d..a103df1 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -171,6 +171,22 @@ struct TypeChecker2 return follow(*tp); } + TypeId lookupExpectedType(AstExpr* expr) + { + if (TypeId* ty = module->astExpectedTypes.find(expr)) + return follow(*ty); + + return builtinTypes->anyType; + } + + TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena) + { + if (TypeId* ty = module->astExpectedTypes.find(expr)) + return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt}); + + return builtinTypes->anyTypePack; + } + TypePackId reconstructPack(AstArray exprs, TypeArena& arena) { if (exprs.size == 0) @@ -208,12 +224,6 @@ struct TypeChecker2 return bestScope; } - enum ValueContext - { - LValue, - RValue - }; - void visit(AstStat* stat) { auto pusher = pushStack(stat); @@ -272,7 +282,7 @@ struct TypeChecker2 void visit(AstStatIf* ifStatement) { - visit(ifStatement->condition, RValue); + visit(ifStatement->condition, ValueContext::RValue); visit(ifStatement->thenbody); if (ifStatement->elsebody) visit(ifStatement->elsebody); @@ -280,14 +290,14 @@ struct TypeChecker2 void visit(AstStatWhile* whileStatement) { - visit(whileStatement->condition, RValue); + visit(whileStatement->condition, ValueContext::RValue); visit(whileStatement->body); } void visit(AstStatRepeat* repeatStatement) { visit(repeatStatement->body); - visit(repeatStatement->condition, RValue); + visit(repeatStatement->condition, ValueContext::RValue); } void visit(AstStatBreak*) {} @@ -314,12 +324,12 @@ struct TypeChecker2 } for (AstExpr* expr : ret->list) - visit(expr, RValue); + visit(expr, ValueContext::RValue); } void visit(AstStatExpr* expr) { - visit(expr->expr, RValue); + visit(expr->expr, ValueContext::RValue); } void visit(AstStatLocal* local) @@ -331,7 +341,7 @@ struct TypeChecker2 const bool isPack = value && (value->is() || value->is()); if (value) - visit(value, RValue); + visit(value, ValueContext::RValue); if (i != local->values.size - 1 || !isPack) { @@ -412,7 +422,7 @@ struct TypeChecker2 if (!expr) return; - visit(expr, RValue); + visit(expr, ValueContext::RValue); reportErrors(tryUnify(scope, expr->location, lookupType(expr), builtinTypes->numberType)); }; @@ -432,7 +442,7 @@ struct TypeChecker2 } for (AstExpr* expr : forInStatement->values) - visit(expr, RValue); + visit(expr, ValueContext::RValue); visit(forInStatement->body); @@ -643,11 +653,11 @@ struct TypeChecker2 for (size_t i = 0; i < count; ++i) { AstExpr* lhs = assign->vars.data[i]; - visit(lhs, LValue); + visit(lhs, ValueContext::LValue); TypeId lhsType = lookupType(lhs); AstExpr* rhs = assign->values.data[i]; - visit(rhs, RValue); + visit(rhs, ValueContext::RValue); TypeId rhsType = lookupType(rhs); if (get(lhsType)) @@ -671,7 +681,7 @@ struct TypeChecker2 void visit(AstStatFunction* stat) { - visit(stat->name, LValue); + visit(stat->name, ValueContext::LValue); visit(stat->func); } @@ -724,7 +734,7 @@ struct TypeChecker2 void visit(AstStatError* 
stat) { for (AstExpr* expr : stat->expressions) - visit(expr, RValue); + visit(expr, ValueContext::RValue); for (AstStat* s : stat->statements) visit(s); @@ -926,7 +936,7 @@ struct TypeChecker2 TypeArena* arena = &testArena; Instantiation instantiation{TxnLog::empty(), arena, TypeLevel{}, stack.back()}; - TypePackId expectedRetType = lookupPack(call); + TypePackId expectedRetType = lookupExpectedPack(call, *arena); TypeId functionType = lookupType(call->func); TypeId testFunctionType = functionType; TypePack args; @@ -1105,10 +1115,10 @@ struct TypeChecker2 void visit(AstExprCall* call) { - visit(call->func, RValue); + visit(call->func, ValueContext::RValue); for (AstExpr* arg : call->args) - visit(arg, RValue); + visit(arg, ValueContext::RValue); visitCall(call); } @@ -1158,7 +1168,7 @@ struct TypeChecker2 void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) { - visit(expr, RValue); + visit(expr, ValueContext::RValue); TypeId leftType = stripFromNilAndReport(lookupType(expr), location); checkIndexTypeFromType(leftType, propName, location, context); @@ -1179,8 +1189,8 @@ struct TypeChecker2 } // TODO! - visit(indexExpr->expr, LValue); - visit(indexExpr->index, RValue); + visit(indexExpr->expr, ValueContext::LValue); + visit(indexExpr->index, ValueContext::RValue); NotNull scope = stack.back(); @@ -1242,14 +1252,14 @@ struct TypeChecker2 for (const AstExprTable::Item& item : expr->items) { if (item.key) - visit(item.key, LValue); - visit(item.value, RValue); + visit(item.key, ValueContext::LValue); + visit(item.value, ValueContext::RValue); } } void visit(AstExprUnary* expr) { - visit(expr->expr, RValue); + visit(expr->expr, ValueContext::RValue); NotNull scope = stack.back(); TypeId operandType = lookupType(expr->expr); @@ -1330,8 +1340,8 @@ struct TypeChecker2 TypeId visit(AstExprBinary* expr, AstNode* overrideKey = nullptr) { - visit(expr->left, LValue); - visit(expr->right, LValue); + visit(expr->left, ValueContext::LValue); + visit(expr->right, ValueContext::LValue); NotNull scope = stack.back(); @@ -1363,11 +1373,14 @@ struct TypeChecker2 return leftType; } + bool typesHaveIntersection = normalizer.isIntersectionInhabited(leftType, rightType); if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) { std::optional leftMt = getMetatable(leftType, builtinTypes); std::optional rightMt = getMetatable(rightType, builtinTypes); bool matches = leftMt == rightMt; + + if (isEquality && !matches) { auto testUnion = [&matches, builtinTypes = this->builtinTypes](const UnionType* utv, std::optional otherMt) { @@ -1390,6 +1403,13 @@ struct TypeChecker2 { testUnion(utv, leftMt); } + + // If either left or right has no metatable (or both), we need to consider if + // there are values in common that could possibly inhabit the type (and thus equality could be considered) + if (!leftMt.has_value() || !rightMt.has_value()) + { + matches = matches || typesHaveIntersection; + } } if (!matches && isComparison) @@ -1584,7 +1604,7 @@ struct TypeChecker2 void visit(AstExprTypeAssertion* expr) { - visit(expr->expr, RValue); + visit(expr->expr, ValueContext::RValue); visit(expr->annotation); TypeId annotationType = lookupAnnotation(expr->annotation); @@ -1603,22 +1623,22 @@ struct TypeChecker2 void visit(AstExprIfElse* expr) { // TODO! 
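The `typesHaveIntersection` flag above is what the "cofinite strings are now comparable" item refers to: when at least one side of `==`/`~=` has no metatable, the comparison is accepted if the two operand types could share a value, i.e. their intersection normalizes to something inhabited. The new entry point is `Normalizer::isIntersectionInhabited`; a trivial sketch of calling it (the wrapper name is illustrative, and the `Normalizer` is assumed to be set up by the caller).

#include "Luau/Normalize.h"

// Mirrors the equality check in TypeChecker2: two types are considered comparable
// here if at least one value could inhabit both of them.
bool typesAreComparable(Luau::Normalizer& normalizer, Luau::TypeId leftTy, Luau::TypeId rightTy)
{
    return normalizer.isIntersectionInhabited(leftTy, rightTy);
}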
- visit(expr->condition, RValue); - visit(expr->trueExpr, RValue); - visit(expr->falseExpr, RValue); + visit(expr->condition, ValueContext::RValue); + visit(expr->trueExpr, ValueContext::RValue); + visit(expr->falseExpr, ValueContext::RValue); } void visit(AstExprInterpString* interpString) { for (AstExpr* expr : interpString->expressions) - visit(expr, RValue); + visit(expr, ValueContext::RValue); } void visit(AstExprError* expr) { // TODO! for (AstExpr* e : expr->expressions) - visit(e, RValue); + visit(e, ValueContext::RValue); } /** Extract a TypeId for the first type of the provided pack. @@ -1858,7 +1878,7 @@ struct TypeChecker2 void visit(AstTypeTypeof* ty) { - visit(ty->expr, RValue); + visit(ty->expr, ValueContext::RValue); } void visit(AstTypeUnion* ty) @@ -2109,7 +2129,7 @@ struct TypeChecker2 // because classes come into being with full knowledge of their // shape. We instead want to report the unknown property error of // the `else` branch. - else if (context == LValue && !get(tableTy)) + else if (context == ValueContext::LValue && !get(tableTy)) reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); else reportError(UnknownProperty{tableTy, prop}, location); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index a8c093a..8f9e185 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1786,7 +1786,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& } else { - TypeId currentTy = assignTo[propName].type; + TypeId currentTy = assignTo[propName].type(); // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -2076,7 +2076,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( if (TableType* tableType = getMutableTableType(type)) { if (auto it = tableType->props.find(name); it != tableType->props.end()) - return it->second.type; + return it->second.type(); else if (auto indexer = tableType->indexer) { // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. 
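// Illustrative sketch, not part of the upstream diff: the hunks around this point
// repeatedly migrate direct reads and writes of a table Property's type to the
// accessor pair visible in this patch. Assuming a table property `prop` and a
// TypeId `newTy` (hypothetical names for illustration):
//
//     TypeId current = prop.type();   // read access, previously `prop.type`
//     prop.setType(newTy);            // write access, previously `prop.type = newTy`
//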
@@ -2104,7 +2104,7 @@ std::optional TypeChecker::getIndexTypeFromTypeImpl( { const Property* prop = lookupClassProp(cls, name); if (prop) - return prop->type; + return prop->type(); } else if (const UnionType* utv = get(type)) { @@ -2294,9 +2294,9 @@ TypeId TypeChecker::checkExprTable( if (it != expectedTable->props.end()) { Property expectedProp = it->second; - ErrorVec errors = tryUnify(exprType, expectedProp.type, scope, k->location); + ErrorVec errors = tryUnify(exprType, expectedProp.type(), scope, k->location); if (errors.empty()) - exprType = expectedProp.type; + exprType = expectedProp.type(); } else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) { @@ -2390,7 +2390,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (expectedTable) { if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) - expectedResultType = prop->second.type; + expectedResultType = prop->second.type(); else if (expectedIndexType && maybeString(*expectedIndexType)) expectedResultType = expectedIndexResultType; } @@ -2402,7 +2402,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (const TableType* ttv = get(follow(expectedOption))) { if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) - expectedResultTypes.push_back(prop->second.type); + expectedResultTypes.push_back(prop->second.type()); else if (ttv->indexer && maybeString(ttv->indexer->indexType)) expectedResultTypes.push_back(ttv->indexer->indexResultType); } @@ -3257,13 +3257,13 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex const auto& it = lhsTable->props.find(name); if (it != lhsTable->props.end()) { - return it->second.type; + return it->second.type(); } else if ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; - property.type = theType; + property.setType(theType); property.location = expr.indexLocation; return theType; } @@ -3303,7 +3303,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex return errorRecoveryType(scope); } - return prop->type; + return prop->type(); } else if (get(lhs)) { @@ -3351,7 +3351,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); return errorRecoveryType(scope); } - return prop->type; + return prop->type(); } } else if (FFlag::LuauAllowIndexClassParameters) @@ -3378,13 +3378,13 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex const auto& it = exprTable->props.find(value->value.data); if (it != exprTable->props.end()) { - return it->second.type; + return it->second.type(); } else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) { TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; - property.type = resultType; + property.setType(resultType); property.location = expr.index->location; return resultType; } @@ -3467,13 +3467,12 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T Name name = indexName->index.value; if (ttv->props.count(name)) - return ttv->props[name].type; + return ttv->props[name].type(); Property& property = ttv->props[name]; - - property.type = 
freshTy(); + property.setType(freshTy()); property.location = indexName->indexLocation; - return property.type; + return property.type(); } else if (funName.is()) return errorRecoveryType(scope); diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index 031c599..b81cca7 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -70,7 +70,7 @@ TypeId TypeReductionMemoization::memoize(TypeId ty, TypeId reducedTy) else if (auto tt = get(reducedTy)) { for (auto& [k, p] : tt->props) - irreducible &= isIrreducible(p.type); + irreducible &= isIrreducible(p.type()); if (tt->indexer) { @@ -539,7 +539,7 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) // even if we have the corresponding property in the other one. if (auto other = t2->props.find(name); other != t2->props.end()) { - TypeId propTy = apply(&TypeReducer::intersectionType, prop.type, other->second.type); + TypeId propTy = apply(&TypeReducer::intersectionType, prop.type(), other->second.type()); if (get(propTy)) return builtinTypes->neverType; // { p : string } & { p : number } ~ { p : string & number } ~ { p : never } ~ never else @@ -554,7 +554,7 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) // TODO: And vice versa, t2 properties against t1 indexer if it exists, // even if we have the corresponding property in the other one. if (!t1->props.count(name)) - table->props[name] = {reduce(prop.type)}; // {} & { p : string & string } ~ { p : string } + table->props[name] = {reduce(prop.type())}; // {} & { p : string & string } ~ { p : string } } if (t1->indexer && t2->indexer) @@ -966,11 +966,11 @@ TypeId TypeReducer::tableType(TypeId ty) for (auto& [name, prop] : copied->props) { - TypeId propTy = reduce(prop.type); + TypeId propTy = reduce(prop.type()); if (get(propTy)) return builtinTypes->neverType; else - prop.type = propTy; + prop.setType(propTy); } if (copied->indexer) diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index e5029e5..9124e2f 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -34,7 +34,7 @@ std::optional findMetatableEntry( auto it = mtt->props.find(entry); if (it != mtt->props.end()) - return it->second.type; + return it->second.type(); else return std::nullopt; } @@ -49,7 +49,7 @@ std::optional findTablePropertyRespectingMeta( { const auto& it = tableType->props.find(name); if (it != tableType->props.end()) - return it->second.type; + return it->second.type(); } std::optional mtIndex = findMetatableEntry(builtinTypes, errors, ty, "__index", location); @@ -67,7 +67,7 @@ std::optional findTablePropertyRespectingMeta( { const auto& fit = itt->props.find(name); if (fit != itt->props.end()) - return fit->second.type; + return fit->second.type(); } else if (const auto& itf = get(index)) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 3f4e34f..3ca9359 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -351,7 +351,7 @@ static std::optional> getTableMatchT { for (auto&& [name, prop] : ttv->props) { - if (auto sing = get(follow(prop.type))) + if (auto sing = get(follow(prop.type()))) return {{name, sing}}; } } @@ -2003,7 +2003,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && subTable->state == TableState::Unsealed && !isOptional(superProp.type)) + if (subIter == subTable->props.end() && subTable->state == 
TableState::Unsealed && !isOptional(superProp.type())) missingProperties.push_back(propName); } @@ -2044,7 +2044,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(r->second.type, prop.type); + innerState.tryUnify_(r->second.type(), prop.type()); checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); @@ -2060,7 +2060,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTable->indexer->indexResultType, prop.type); + innerState.tryUnify_(subTable->indexer->indexResultType, prop.type()); checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); @@ -2068,7 +2068,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) log.concat(std::move(innerState.log)); failure |= innerState.failure; } - else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) + else if (subTable->state == TableState::Unsealed && isOptional(prop.type())) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) @@ -2123,7 +2123,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTable->indexer->indexResultType, prop.type); + innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); @@ -2137,7 +2137,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // TODO: file a JIRA // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; - clone.type = deeplyOptional(clone.type); + clone.setType(deeplyOptional(clone.type())); PendingType* pendingSuper = log.queue(superTy); TableType* pendingSuperTtv = getMutable(pendingSuper); @@ -2297,7 +2297,7 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (auto it = mttv->props.find("__index"); it != mttv->props.end()) { - TypeId ty = it->second.type; + TypeId ty = it->second.type(); Unifier child = makeChildUnifier(); child.tryUnify_(ty, superTy); @@ -2349,7 +2349,7 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see result = types->addType(*ttv); TableType* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) - prop.type = deeplyOptional(prop.type, seen); + prop.setType(deeplyOptional(prop.type(), seen)); return types->addType(UnionType{{builtinTypes->nilType, result}}); } else @@ -2394,7 +2394,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { if (std::optional mtPropTy = findTablePropertyRespectingMeta(superTy, propName)) { - innerState.tryUnify(prop.type, *mtPropTy); + innerState.tryUnify(prop.type(), *mtPropTy); } else { @@ -2505,7 +2505,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) else { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(classProp->type, prop.type); + innerState.tryUnify_(classProp->type(), prop.type()); checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? 
superTy : subTy); @@ -2674,7 +2674,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas else if (auto table = state.log.getMutable(ty)) { for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); + queue.push_back(prop.type()); if (table->indexer) { diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index bcf70f2..4303364 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -703,10 +703,26 @@ struct CompileStats size_t lines; size_t bytecode; size_t codegen; + + double readTime; + double miscTime; + double parseTime; + double compileTime; + double codegenTime; }; +static double recordDeltaTime(double& timer) +{ + double now = Luau::TimeTrace::getClock(); + double delta = now - timer; + timer = now; + return delta; +} + static bool compileFile(const char* name, CompileFormat format, CompileStats& stats) { + double currts = Luau::TimeTrace::getClock(); + std::optional source = readFile(name); if (!source) { @@ -714,6 +730,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st return false; } + stats.readTime += recordDeltaTime(currts); + // NOTE: Normally, you should use Luau::compile or luau_compile (see lua_require as an example) // This function is much more complicated because it supports many output human-readable formats through internal interfaces @@ -753,6 +771,8 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st bcb.setDumpSource(*source); } + stats.miscTime += recordDeltaTime(currts); + Luau::Allocator allocator; Luau::AstNameTable names(allocator); Luau::ParseResult result = Luau::Parser::parse(source->c_str(), source->size(), names, allocator); @@ -761,9 +781,11 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st throw Luau::ParseErrors(result.errors); stats.lines += result.lines; + stats.parseTime += recordDeltaTime(currts); Luau::compileOrThrow(bcb, result, names, copts()); stats.bytecode += bcb.getBytecode().size(); + stats.compileTime += recordDeltaTime(currts); switch (format) { @@ -784,6 +806,7 @@ static bool compileFile(const char* name, CompileFormat format, CompileStats& st break; case CompileFormat::CodegenNull: stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); + stats.codegenTime += recordDeltaTime(currts); break; case CompileFormat::Null: break; @@ -998,18 +1021,18 @@ int replMain(int argc, char** argv) CompileStats stats = {}; int failed = 0; - double startTime = Luau::TimeTrace::getClock(); for (const std::string& path : files) failed += !compileFile(path.c_str(), compileFormat, stats); - double duration = Luau::TimeTrace::getClock() - startTime; - if (compileFormat == CompileFormat::Null) - printf("Compiled %d KLOC into %d KB bytecode in %.2fs\n", int(stats.lines / 1000), int(stats.bytecode / 1024), duration); + printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), + int(stats.bytecode / 1024), stats.readTime, stats.parseTime, stats.compileTime); else if (compileFormat == CompileFormat::CodegenNull) - printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) in %.2fs\n", int(stats.lines / 1000), int(stats.bytecode / 1024), - int(stats.codegen / 1024), stats.bytecode == 0 ? 
0.0 : double(stats.codegen) / double(stats.bytecode), duration); + printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", + int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024), + stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime, + stats.codegenTime); return failed ? 1 : 0; } diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 1a5f513..26be11c 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -136,6 +136,7 @@ public: void frinta(RegisterA64 dst, RegisterA64 src); void frintm(RegisterA64 dst, RegisterA64 src); void frintp(RegisterA64 dst, RegisterA64 src); + void fcvt(RegisterA64 dst, RegisterA64 src); void fcvtzs(RegisterA64 dst, RegisterA64 src); void fcvtzu(RegisterA64 dst, RegisterA64 src); void scvtf(RegisterA64 dst, RegisterA64 src); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index bb3ebb2..e162cd3 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -154,7 +154,7 @@ public: // Run final checks - void finalize(); + bool finalize(); // Places a label at current location and returns it Label setLabel(); diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index e6202c7..3b09359 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -36,6 +36,8 @@ struct IrBuilder // Source block that is cloned cannot use values coming in from a predecessor void clone(const IrBlock& source, bool removeCurrentTerminator); + IrOp undef(); + IrOp constBool(bool value); IrOp constInt(int value); IrOp constUint(unsigned value); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 47a9733..addd18f 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -10,6 +10,7 @@ #include #include +#include struct Proto; @@ -18,6 +19,12 @@ namespace Luau namespace CodeGen { +// IR extensions to LuauBuiltinFunction enum (these only exist inside IR, and start from 256 to avoid collisions) +enum +{ + LBF_IR_MATH_LOG2 = 256, +}; + // IR instruction command. 
// In the command description, following abbreviations are used: // * Rn - VM stack register slot, n in 0..254 @@ -112,7 +119,7 @@ enum class IrCmd : uint8_t ADD_INT, SUB_INT, - // Add/Sub/Mul/Div/Mod/Pow two double numbers + // Add/Sub/Mul/Div/Mod two double numbers // A, B: double // In final x64 lowering, B can also be Rn or Kn ADD_NUM, @@ -120,7 +127,6 @@ enum class IrCmd : uint8_t MUL_NUM, DIV_NUM, MOD_NUM, - POW_NUM, // Get the minimum/maximum of two numbers // If one of the values is NaN, 'B' is returned as the result @@ -192,8 +198,8 @@ enum class IrCmd : uint8_t // D: block (if false) JUMP_LT_INT, - // Jump if A >= B - // A, B: uint + // Jump if unsigned(A) >= unsigned(B) + // A, B: int // C: condition // D: block (if true) // E: block (if false) @@ -543,17 +549,17 @@ enum class IrCmd : uint8_t // A: operand of any type // Performs bitwise and/xor/or on two unsigned integers - // A, B: uint + // A, B: int BITAND_UINT, BITXOR_UINT, BITOR_UINT, // Performs bitwise not on an unsigned integer - // A: uint + // A: int BITNOT_UINT, // Performs bitwise shift/rotate on an unsigned integer - // A: uint (source) + // A: int (source) // B: int (shift amount) BITLSHIFT_UINT, BITRSHIFT_UINT, @@ -562,7 +568,7 @@ enum class IrCmd : uint8_t BITRROTATE_UINT, // Returns the number of consecutive zero bits in A starting from the left-most (most significant) bit. - // A: uint + // A: int BITCOUNTLZ_UINT, BITCOUNTRZ_UINT, @@ -621,6 +627,8 @@ enum class IrOpKind : uint32_t { None, + Undef, + // To reference a constant value Constant, @@ -710,6 +718,63 @@ struct IrInst // When IrInst operands are used, current instruction index is often required to track lifetime constexpr uint32_t kInvalidInstIdx = ~0u; +struct IrInstHash +{ + static const uint32_t m = 0x5bd1e995; + static const int r = 24; + + static uint32_t mix(uint32_t h, uint32_t k) + { + // MurmurHash2 step + k *= m; + k ^= k >> r; + k *= m; + + h *= m; + h ^= k; + + return h; + } + + static uint32_t mix(uint32_t h, IrOp op) + { + static_assert(sizeof(op) == sizeof(uint32_t)); + uint32_t k; + memcpy(&k, &op, sizeof(op)); + + return mix(h, k); + } + + size_t operator()(const IrInst& key) const + { + // MurmurHash2 unrolled + uint32_t h = 25; + + h = mix(h, uint32_t(key.cmd)); + h = mix(h, key.a); + h = mix(h, key.b); + h = mix(h, key.c); + h = mix(h, key.d); + h = mix(h, key.e); + h = mix(h, key.f); + + // MurmurHash2 tail + h ^= h >> 13; + h *= m; + h ^= h >> 15; + + return h; + } +}; + +struct IrInstEq +{ + bool operator()(const IrInst& a, const IrInst& b) const + { + return a.cmd == b.cmd && a.a == b.a && a.b == b.b && a.c == b.c && a.d == b.d && a.e == b.e && a.f == b.f; + } +}; + enum class IrBlockKind : uint8_t { Bytecode, diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index ed9dc91..3cf18cd 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -135,7 +135,6 @@ inline bool hasResult(IrCmd cmd) case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: @@ -231,7 +230,7 @@ bool compare(double a, double b, IrCondition cond); // But it can also be successful on conditional control-flow, replacing it with an unconditional IrCmd::JUMP void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIdx); -uint32_t getNativeContextOffset(LuauBuiltinFunction bfid); +uint32_t getNativeContextOffset(int bfid); } // namespace CodeGen } // namespace Luau diff --git 
a/CodeGen/include/Luau/OptimizeConstProp.h b/CodeGen/include/Luau/OptimizeConstProp.h index 619165d..74ae131 100644 --- a/CodeGen/include/Luau/OptimizeConstProp.h +++ b/CodeGen/include/Luau/OptimizeConstProp.h @@ -10,8 +10,8 @@ namespace CodeGen struct IrBuilder; -void constPropInBlockChains(IrBuilder& build); -void createLinearBlocks(IrBuilder& build); +void constPropInBlockChains(IrBuilder& build, bool useValueNumbering); +void createLinearBlocks(IrBuilder& build, bool useValueNumbering); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index c3a9ae0..d50369e 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -17,6 +17,7 @@ enum class KindA64 : uint8_t none, w, // 32-bit GPR x, // 64-bit GPR + s, // 32-bit SIMD&FP scalar d, // 64-bit SIMD&FP scalar q, // 128-bit SIMD&FP vector }; @@ -128,6 +129,39 @@ constexpr RegisterA64 xzr{KindA64::x, 31}; constexpr RegisterA64 sp{KindA64::none, 31}; +constexpr RegisterA64 s0{KindA64::s, 0}; +constexpr RegisterA64 s1{KindA64::s, 1}; +constexpr RegisterA64 s2{KindA64::s, 2}; +constexpr RegisterA64 s3{KindA64::s, 3}; +constexpr RegisterA64 s4{KindA64::s, 4}; +constexpr RegisterA64 s5{KindA64::s, 5}; +constexpr RegisterA64 s6{KindA64::s, 6}; +constexpr RegisterA64 s7{KindA64::s, 7}; +constexpr RegisterA64 s8{KindA64::s, 8}; +constexpr RegisterA64 s9{KindA64::s, 9}; +constexpr RegisterA64 s10{KindA64::s, 10}; +constexpr RegisterA64 s11{KindA64::s, 11}; +constexpr RegisterA64 s12{KindA64::s, 12}; +constexpr RegisterA64 s13{KindA64::s, 13}; +constexpr RegisterA64 s14{KindA64::s, 14}; +constexpr RegisterA64 s15{KindA64::s, 15}; +constexpr RegisterA64 s16{KindA64::s, 16}; +constexpr RegisterA64 s17{KindA64::s, 17}; +constexpr RegisterA64 s18{KindA64::s, 18}; +constexpr RegisterA64 s19{KindA64::s, 19}; +constexpr RegisterA64 s20{KindA64::s, 20}; +constexpr RegisterA64 s21{KindA64::s, 21}; +constexpr RegisterA64 s22{KindA64::s, 22}; +constexpr RegisterA64 s23{KindA64::s, 23}; +constexpr RegisterA64 s24{KindA64::s, 24}; +constexpr RegisterA64 s25{KindA64::s, 25}; +constexpr RegisterA64 s26{KindA64::s, 26}; +constexpr RegisterA64 s27{KindA64::s, 27}; +constexpr RegisterA64 s28{KindA64::s, 28}; +constexpr RegisterA64 s29{KindA64::s, 29}; +constexpr RegisterA64 s30{KindA64::s, 30}; +constexpr RegisterA64 s31{KindA64::s, 31}; + constexpr RegisterA64 d0{KindA64::d, 0}; constexpr RegisterA64 d1{KindA64::d, 1}; constexpr RegisterA64 d2{KindA64::d, 2}; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index ab84bef..33e3c96 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -282,7 +282,7 @@ void AssemblyBuilderA64::ror(RegisterA64 dst, RegisterA64 src1, uint8_t src2) void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src) { - LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w || dst.kind == KindA64::d || dst.kind == KindA64::q); + LUAU_ASSERT(dst.kind == KindA64::x || dst.kind == KindA64::w || dst.kind == KindA64::s || dst.kind == KindA64::d || dst.kind == KindA64::q); switch (dst.kind) { @@ -292,6 +292,9 @@ void AssemblyBuilderA64::ldr(RegisterA64 dst, AddressA64 src) case KindA64::x: placeA("ldr", dst, src, 0b11100001, 0b11, /* sizelog= */ 3); break; + case KindA64::s: + placeA("ldr", dst, src, 0b11110001, 0b10, /* sizelog= */ 2); + break; case KindA64::d: placeA("ldr", dst, src, 0b11110001, 0b11, /* sizelog= */ 3); break; @@ -348,7 +351,7 @@ void 
AssemblyBuilderA64::ldp(RegisterA64 dst1, RegisterA64 dst2, AddressA64 src) void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst) { - LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w || src.kind == KindA64::d || src.kind == KindA64::q); + LUAU_ASSERT(src.kind == KindA64::x || src.kind == KindA64::w || src.kind == KindA64::s || src.kind == KindA64::d || src.kind == KindA64::q); switch (src.kind) { @@ -358,6 +361,9 @@ void AssemblyBuilderA64::str(RegisterA64 src, AddressA64 dst) case KindA64::x: placeA("str", src, dst, 0b11100000, 0b11, /* sizelog= */ 3); break; + case KindA64::s: + placeA("str", src, dst, 0b11110000, 0b10, /* sizelog= */ 2); + break; case KindA64::d: placeA("str", src, dst, 0b11110000, 0b11, /* sizelog= */ 3); break; @@ -570,6 +576,16 @@ void AssemblyBuilderA64::frintp(RegisterA64 dst, RegisterA64 src) placeR1("frintp", dst, src, 0b000'11110'01'1'001'001'10000); } +void AssemblyBuilderA64::fcvt(RegisterA64 dst, RegisterA64 src) +{ + if (dst.kind == KindA64::s && src.kind == KindA64::d) + placeR1("fcvt", dst, src, 0b11110'01'1'0001'00'10000); + else if (dst.kind == KindA64::d && src.kind == KindA64::s) + placeR1("fcvt", dst, src, 0b11110'00'1'0001'01'10000); + else + LUAU_ASSERT(!"Unexpected register kind"); +} + void AssemblyBuilderA64::fcvtzs(RegisterA64 dst, RegisterA64 src) { LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); @@ -1229,6 +1245,10 @@ void AssemblyBuilderA64::log(RegisterA64 reg) logAppend("x%d", reg.index); break; + case KindA64::s: + logAppend("s%d", reg.index); + break; + case KindA64::d: logAppend("d%d", reg.index); break; diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index b8f3940..4c9ad6d 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -831,7 +831,7 @@ void AssemblyBuilderX64::vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 placeAvx("vblendvpd", dst, src1, mask, src3.index << 4, 0x4b, false, AVX_0F3A, AVX_66); } -void AssemblyBuilderX64::finalize() +bool AssemblyBuilderX64::finalize() { code.resize(codePos - code.data()); @@ -853,6 +853,8 @@ void AssemblyBuilderX64::finalize() data.resize(dataSize); finalized = true; + + return true; } Label AssemblyBuilderX64::setLabel() diff --git a/CodeGen/src/BitUtils.h b/CodeGen/src/BitUtils.h index 93f7cc8..31fc4bf 100644 --- a/CodeGen/src/BitUtils.h +++ b/CodeGen/src/BitUtils.h @@ -32,5 +32,25 @@ inline int countrz(uint32_t n) #endif } +inline int lrotate(uint32_t u, int s) +{ + // MSVC doesn't recognize the rotate form that is UB-safe +#ifdef _MSC_VER + return _rotl(u, s); +#else + return (u << (s & 31)) | (u >> ((32 - s) & 31)); +#endif +} + +inline int rrotate(uint32_t u, int s) +{ + // MSVC doesn't recognize the rotate form that is UB-safe +#ifdef _MSC_VER + return _rotr(u, s); +#else + return (u >> (s & 31)) | (u << ((32 - s) & 31)); +#endif +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index a86d5a2..f0be5b3 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -51,6 +51,7 @@ LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) LUAU_FASTFLAGVARIABLE(DebugCodegenOptSize, false) +LUAU_FASTFLAGVARIABLE(DebugCodegenSkipNumbering, false) namespace Luau { @@ -59,21 +60,33 @@ namespace CodeGen static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) { - NativeProto* result = new NativeProto(); + int sizecode = proto->sizecode; + int sizecodeAlloc = (sizecode + 1) & ~1; // align uint32_t array to 8 
bytes so that NativeProto is aligned to 8 bytes + void* memory = ::operator new(sizeof(NativeProto) + sizecodeAlloc * sizeof(uint32_t)); + NativeProto* result = new (static_cast(memory) + sizecodeAlloc * sizeof(uint32_t)) NativeProto; result->proto = proto; - result->instTargets = new uintptr_t[proto->sizecode]; - for (int i = 0; i < proto->sizecode; i++) + uint32_t* instOffsets = result->instOffsets; + + for (int i = 0; i < sizecode; i++) { - auto [irLocation, asmLocation] = ir.function.bcMapping[i]; - - result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation; + // instOffsets uses negative indexing for optimal codegen for RETURN opcode + instOffsets[-i] = ir.function.bcMapping[i].asmLocation; } return result; } +static void destroyNativeProto(NativeProto* nativeProto) +{ + int sizecode = nativeProto->proto->sizecode; + int sizecodeAlloc = (sizecode + 1) & ~1; // align uint32_t array to 8 bytes so that NativeProto is aligned to 8 bytes + void* memory = reinterpret_cast(nativeProto) - sizecodeAlloc * sizeof(uint32_t); + + ::operator delete(memory); +} + template static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) { @@ -95,30 +108,19 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& return a.start < b.start; }); - DenseHashMap bcLocations{~0u}; + // For each IR instruction that begins a bytecode instruction, which bytecode instruction is it? + std::vector bcLocations(function.instructions.size() + 1, ~0u); - // Create keys for IR assembly locations that original bytecode instruction are interested in - for (const auto& [irLocation, asmLocation] : function.bcMapping) + for (size_t i = 0; i < function.bcMapping.size(); ++i) { + uint32_t irLocation = function.bcMapping[i].irLocation; + if (irLocation != ~0u) - bcLocations[irLocation] = 0; + bcLocations[irLocation] = uint32_t(i); } - DenseHashMap indexIrToBc{~0u}; bool outputEnabled = options.includeAssembly || options.includeIr; - if (outputEnabled && options.annotator) - { - // Create reverse mapping from IR location to bytecode location - for (size_t i = 0; i < function.bcMapping.size(); ++i) - { - uint32_t irLocation = function.bcMapping[i].irLocation; - - if (irLocation != ~0u) - indexIrToBc[irLocation] = uint32_t(i); - } - } - IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; // We use this to skip outlined fallback blocks from IR/asm text output @@ -164,18 +166,19 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& { LUAU_ASSERT(index < function.instructions.size()); + uint32_t bcLocation = bcLocations[index]; + // If IR instruction is the first one for the original bytecode, we can annotate it with source code text - if (outputEnabled && options.annotator) + if (outputEnabled && options.annotator && bcLocation != ~0u) { - if (uint32_t* bcIndex = indexIrToBc.find(index)) - options.annotator(options.annotatorContext, build.text, bytecodeid, *bcIndex); + options.annotator(options.annotatorContext, build.text, bytecodeid, bcLocation); } // If bytecode needs the location of this instruction for jumps, record it - if (uint32_t* bcLocation = bcLocations.find(index)) + if (bcLocation != ~0u) { Label label = (index == block.start) ? 
block.label : build.setLabel(); - *bcLocation = build.getLabelOffset(label); + function.bcMapping[bcLocation].asmLocation = build.getLabelOffset(label); } IrInst& inst = function.instructions[index]; @@ -227,13 +230,6 @@ static bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& build.logAppend("; skipping %u bytes of outlined code\n", unsigned((build.getCodeSize() - codeSize) * sizeof(build.code[0]))); } - // Copy assembly locations of IR instructions that are mapped to bytecode instructions - for (auto& [irLocation, asmLocation] : function.bcMapping) - { - if (irLocation != ~0u) - asmLocation = bcLocations[irLocation]; - } - return true; } @@ -293,10 +289,12 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, if (!FFlag::DebugCodegenNoOpt) { - constPropInBlockChains(ir); + bool useValueNumbering = !FFlag::DebugCodegenSkipNumbering; + + constPropInBlockChains(ir, useValueNumbering); if (!FFlag::DebugCodegenOptSize) - createLinearBlocks(ir); + createLinearBlocks(ir, useValueNumbering); } if (!lowerIr(build, ir, data, helpers, proto, options)) @@ -313,12 +311,6 @@ static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, return createNativeProto(proto, ir); } -static void destroyNativeProto(NativeProto* nativeProto) -{ - delete[] nativeProto->instTargets; - delete nativeProto; -} - static void onCloseState(lua_State* L) { destroyNativeState(L); @@ -347,7 +339,9 @@ static int onEnter(lua_State* L, Proto* proto) bool (*gate)(lua_State*, Proto*, uintptr_t, NativeContext*) = (bool (*)(lua_State*, Proto*, uintptr_t, NativeContext*))data->context.gateEntry; NativeProto* nativeProto = getProtoExecData(proto); - uintptr_t target = nativeProto->instTargets[L->ci->savedpc - proto->code]; + + // instOffsets uses negative indexing for optimal codegen for RETURN opcode + uintptr_t target = nativeProto->instBase + nativeProto->instOffsets[-(L->ci->savedpc - proto->code)]; // Returns 1 to finish the function in the VM return gate(L, proto, target, &data->context); @@ -517,7 +511,14 @@ void compile(lua_State* L, int idx) if (NativeProto* np = assembleFunction(build, *data, helpers, p, {})) results.push_back(np); - build.finalize(); + // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module + if (!build.finalize()) + { + for (NativeProto* result : results) + destroyNativeProto(result); + + return; + } // If no functions were assembled, we don't need to allocate/copy executable pages for helpers if (results.empty()) @@ -535,14 +536,11 @@ void compile(lua_State* L, int idx) return; } - // Relocate instruction offsets + // Record instruction base address; at runtime, instOffsets[] will be used as offsets from instBase for (NativeProto* result : results) { - for (int i = 0; i < result->proto->sizecode; i++) - result->instTargets[i] += uintptr_t(codeStart); - - LUAU_ASSERT(result->proto->sizecode); - result->entryTarget = result->instTargets[0]; + result->instBase = uintptr_t(codeStart); + result->entryTarget = uintptr_t(codeStart) + result->instOffsets[0]; } // Link native proto objects to Proto; the memory is now managed by VM and will be freed via onDestroyFunction @@ -579,7 +577,8 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) if (NativeProto* np = assembleFunction(build, data, helpers, p, options)) destroyNativeProto(np); - build.finalize(); + if (!build.finalize()) + return std::string(); if (options.outputBinary) return 
std::string(reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 7f29beb..415cfc9 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -35,11 +35,17 @@ static void emitInterrupt(AssemblyBuilderA64& build) { // x0 = pc offset // x1 = return address in native code - // x2 = interrupt + + Label skip; // Stash return address in rBase; we need to reload rBase anyway build.mov(rBase, x1); + // Load interrupt handler; it may be nullptr in case the update raced with the check before we got here + build.ldr(x2, mem(rState, offsetof(lua_State, global))); + build.ldr(x2, mem(x2, offsetof(global_State, cb.interrupt))); + build.cbz(x2, skip); + // Update savedpc; required in case interrupt errors build.add(x0, rCode, x0); build.ldr(x1, mem(rState, offsetof(lua_State, ci))); @@ -51,7 +57,6 @@ static void emitInterrupt(AssemblyBuilderA64& build) build.blr(x2); // Check if we need to exit - Label skip; build.ldrb(w0, mem(rState, offsetof(lua_State, status))); build.cbz(w0, skip); @@ -92,11 +97,11 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) // Get instruction index from instruction pointer // To get instruction index from instruction pointer, we need to divide byte offset by 4 - // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out + // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out + // Note that we're computing negative offset here (code-savedpc) so that we can add it to NativeProto address, as we use reverse indexing build.ldr(x2, mem(rState, offsetof(lua_State, ci))); // L->ci build.ldr(x2, mem(x2, offsetof(CallInfo, savedpc))); // L->ci->savedpc - build.sub(x2, x2, rCode); - build.add(x2, x2, x2); // TODO: this would not be necessary if we supported shifted register offsets in loads + build.sub(x2, rCode, x2); // We need to check if the new function can be executed natively // TODO: This can be done earlier in the function flow, to reduce the JIT->VM transition penalty @@ -104,8 +109,10 @@ static void emitReentry(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.cbz(x1, helpers.exitContinueVm); // Get new instruction location and jump to it - build.ldr(x1, mem(x1, offsetof(NativeProto, instTargets))); - build.ldr(x1, mem(x1, x2)); + LUAU_ASSERT(offsetof(NativeProto, instOffsets) == 0); + build.ldr(w2, mem(x1, x2)); + build.ldr(x1, mem(x1, offsetof(NativeProto, instBase))); + build.add(x1, x1, x2); build.br(x1); } diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 3e6d26b..af4c529 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -18,24 +18,10 @@ namespace CodeGen namespace X64 { -void emitBuiltinMathLog(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) -{ - regs.assertAllFree(); - build.vmovsd(xmm0, luauRegValue(arg)); - - // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that - if (nparams == 2) - build.call(qword[rNativeContext + offsetof(NativeContext, libm_log2)]); - else - build.call(qword[rNativeContext + offsetof(NativeContext, libm_log)]); - - build.vmovsd(luauRegValue(ra), xmm0); -} - -void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int 
nresults) +static void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, OperandX64 arg2) { ScopedRegX64 tmp{regs, SizeX64::qword}; - build.vcvttsd2si(tmp.reg, qword[args + offsetof(TValue, value)]); + build.vcvttsd2si(tmp.reg, arg2); IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); @@ -45,7 +31,7 @@ void emitBuiltinMathLdexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int np build.vmovsd(luauRegValue(ra), xmm0); } -void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int nresults) { IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); @@ -61,7 +47,7 @@ void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, int np } } -void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg, int nresults) { IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::xmmword, luauRegValue(arg)); @@ -75,7 +61,7 @@ void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa build.vmovsd(luauRegValue(ra + 1), xmm0); } -void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) { ScopedRegX64 tmp0{regs, SizeX64::xmmword}; ScopedRegX64 tmp1{regs, SizeX64::xmmword}; @@ -102,7 +88,7 @@ void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa build.vmovsd(luauRegValue(ra), tmp0.reg); } -void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) { ScopedRegX64 tmp0{regs, SizeX64::qword}; ScopedRegX64 tag{regs, SizeX64::dword}; @@ -115,7 +101,7 @@ void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams build.mov(luauRegValue(ra), tmp0.reg); } -void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +static void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int arg) { IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::qword, rState); @@ -125,38 +111,28 @@ void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npara build.mov(luauRegValue(ra), rax); } -void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults) +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults) { - OperandX64 argsOp = 0; - - if (args.kind == IrOpKind::VmReg) - argsOp = luauRegAddress(vmRegOp(args)); - else if (args.kind == IrOpKind::VmConst) - argsOp = luauConstantAddress(vmConstOp(args)); - switch (bfid) { - case LBF_MATH_LOG: - LUAU_ASSERT((nparams == 1 || nparams == 2) && nresults == 1); - return emitBuiltinMathLog(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_LDEXP: LUAU_ASSERT(nparams == 2 && nresults == 1); - return 
emitBuiltinMathLdexp(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinMathLdexp(regs, build, ra, arg, arg2); case LBF_MATH_FREXP: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - return emitBuiltinMathFrexp(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinMathFrexp(regs, build, ra, arg, nresults); case LBF_MATH_MODF: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); - return emitBuiltinMathModf(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinMathModf(regs, build, ra, arg, nresults); case LBF_MATH_SIGN: LUAU_ASSERT(nparams == 1 && nresults == 1); - return emitBuiltinMathSign(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinMathSign(regs, build, ra, arg); case LBF_TYPE: LUAU_ASSERT(nparams == 1 && nresults == 1); - return emitBuiltinType(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinType(regs, build, ra, arg); case LBF_TYPEOF: LUAU_ASSERT(nparams == 1 && nresults == 1); - return emitBuiltinTypeof(regs, build, nparams, ra, arg, argsOp, nresults); + return emitBuiltinTypeof(regs, build, ra, arg); default: LUAU_ASSERT(!"Missing x64 lowering"); break; diff --git a/CodeGen/src/EmitBuiltinsX64.h b/CodeGen/src/EmitBuiltinsX64.h index 5925a2b..cd8b525 100644 --- a/CodeGen/src/EmitBuiltinsX64.h +++ b/CodeGen/src/EmitBuiltinsX64.h @@ -16,7 +16,7 @@ class AssemblyBuilderX64; struct OperandX64; struct IrRegAllocX64; -void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults); +void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, OperandX64 arg2, int nparams, int nresults); } // namespace X64 } // namespace CodeGen diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index f590df9..9e89b1c 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -13,8 +13,9 @@ // Arguments: x0-x7, v0-v7 // Return: x0, v0 (or x8 that points to the address of the resulting structure) // Volatile: x9-x15, v16-v31 ("caller-saved", any call may change them) +// Intra-procedure-call temporary: x16-x17 (any call or relocated jump may change them, as linker may point branches to veneers to perform long jumps) // Non-volatile: x19-x28, v8-v15 ("callee-saved", preserved after calls, only bottom half of SIMD registers is preserved!) 
-// Reserved: x16-x18: reserved for linker/platform use; x29: frame pointer (unless omitted); x30: link register; x31: stack pointer +// Reserved: x18: reserved for platform use; x29: frame pointer (unless omitted); x30: link register; x31: stack pointer namespace Luau { diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 9a10bfd..19f0cb8 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -308,12 +308,15 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i build.mov(rax, qword[cip + offsetof(CallInfo, savedpc)]); // To get instruction index from instruction pointer, we need to divide byte offset by 4 - // But we will actually need to scale instruction index by 8 back to byte offset later so it cancels out - build.sub(rax, rdx); + // But we will actually need to scale instruction index by 4 back to byte offset later so it cancels out + // Note that we're computing negative offset here (code-savedpc) so that we can add it to NativeProto address, as we use reverse indexing + build.sub(rdx, rax); // Get new instruction location and jump to it - build.mov(rdx, qword[execdata + offsetof(NativeProto, instTargets)]); - build.jmp(qword[rdx + rax * 2]); + LUAU_ASSERT(offsetof(NativeProto, instOffsets) == 0); + build.mov(edx, dword[execdata + rdx]); + build.add(rdx, qword[execdata + offsetof(NativeProto, instBase)]); + build.jmp(rdx); } void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index f3870e9..efe9fcc 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -529,7 +529,8 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) RegisterSet& outRs = info.out[blockIdx]; // Current block has to provide all registers in successor blocks - for (uint32_t succIdx : successors(info, blockIdx)) + BlockIteratorWrapper successorsIt = successors(info, blockIdx); + for (uint32_t succIdx : successorsIt) { IrBlock& succ = function.blocks[succIdx]; @@ -538,7 +539,11 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) // This is because fallback blocks define an alternative implementation of the same operations // This can cause the current block to define more registers that actually were available at fallback entry if (curr.kind != IrBlockKind::Fallback && succ.kind == IrBlockKind::Fallback) + { + // If this is the only successor, this skip will not be valid + LUAU_ASSERT(successorsIt.size() != 1); continue; + } const RegisterSet& succRs = info.in[succIdx]; diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 48c0e25..86986fe 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -30,7 +30,7 @@ void IrBuilder::buildFunctionIr(Proto* proto) // Rebuild original control flow blocks rebuildBytecodeBasicBlocks(proto); - function.bcMapping.resize(proto->sizecode, {~0u, 0}); + function.bcMapping.resize(proto->sizecode, {~0u, ~0u}); // Translate all instructions to IR inside blocks for (int i = 0; i < proto->sizecode;) @@ -41,7 +41,7 @@ void IrBuilder::buildFunctionIr(Proto* proto) int nexti = i + getOpLength(op); LUAU_ASSERT(nexti <= proto->sizecode); - function.bcMapping[i] = {uint32_t(function.instructions.size()), 0}; + function.bcMapping[i] = {uint32_t(function.instructions.size()), ~0u}; // Begin new block at this instruction if it was in the bytecode or requested during translation if 
(instIndexToBlock[i] != kNoAssociatedBlockIndex) @@ -293,7 +293,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) int skip = LUAU_INSN_C(*pc); IrOp next = blockAtInst(i + skip + 2); - translateFastCallN(*this, pc, i, true, 1, constBool(false), next); + translateFastCallN(*this, pc, i, true, 1, undef(), next); activeFastcallFallback = true; fastcallFallbackReturn = next; @@ -496,6 +496,11 @@ void IrBuilder::clone(const IrBlock& source, bool removeCurrentTerminator) } } +IrOp IrBuilder::undef() +{ + return {IrOpKind::Undef, 0}; +} + IrOp IrBuilder::constBool(bool value) { IrConst constant; diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 084a727..50c1848 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -120,8 +120,6 @@ const char* getCmdName(IrCmd cmd) return "DIV_NUM"; case IrCmd::MOD_NUM: return "MOD_NUM"; - case IrCmd::POW_NUM: - return "POW_NUM"; case IrCmd::MIN_NUM: return "MIN_NUM"; case IrCmd::MAX_NUM: @@ -359,6 +357,9 @@ void toString(IrToStringContext& ctx, IrOp op) { case IrOpKind::None: break; + case IrOpKind::Undef: + append(ctx.result, "undef"); + break; case IrOpKind::Constant: toString(ctx.result, ctx.constants[op.index]); break; @@ -398,7 +399,10 @@ void toString(std::string& result, IrConst constant) append(result, "%uu", constant.valueUint); break; case IrConstKind::Double: - append(result, "%.17g", constant.valueDouble); + if (constant.valueDouble != constant.valueDouble) + append(result, "nan"); + else + append(result, "%.17g", constant.valueDouble); break; case IrConstKind::Tag: result.append(getTagName(constant.valueTag)); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 18df26b..7fd684b 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -109,46 +109,6 @@ static void emitFallback(AssemblyBuilderA64& build, int op, int pcpos) emitUpdateBase(build); } -static void emitInvokeLibm1(AssemblyBuilderA64& build, size_t func, int res, int arg) -{ - build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); - build.ldr(x0, mem(rNativeContext, uint32_t(func))); - build.blr(x0); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); -} - -static void emitInvokeLibm2(AssemblyBuilderA64& build, size_t func, int res, int arg, IrOp args, bool argsInt = false) -{ - if (args.kind == IrOpKind::VmReg) - build.ldr(d1, mem(rBase, args.index * sizeof(TValue) + offsetof(TValue, value.n))); - else if (args.kind == IrOpKind::VmConst) - { - size_t constantOffset = args.index * sizeof(TValue) + offsetof(TValue, value.n); - - // Note: cumulative offset is guaranteed to be divisible by 8 (since we're loading a double); we can use that to expand the useful range that - // doesn't require temporaries - if (constantOffset / 8 <= AddressA64::kMaxOffset) - { - build.ldr(d1, mem(rConstants, int(constantOffset))); - } - else - { - emitAddOffset(build, x0, rConstants, constantOffset); - build.ldr(d1, x0); - } - } - else - LUAU_ASSERT(!"Unsupported instruction form"); - - if (argsInt) - build.fcvtzs(w0, d1); - - build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); - build.ldr(x1, mem(rNativeContext, uint32_t(func))); - build.blr(x1); - build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); -} - static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) { build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); @@ -157,21 +117,46 @@ static void 
emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) build.blr(x1); } -static bool emitBuiltin(AssemblyBuilderA64& build, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults) +static bool emitBuiltin( + AssemblyBuilderA64& build, IrFunction& function, IrRegAllocA64& regs, int bfid, int res, int arg, IrOp args, int nparams, int nresults) { switch (bfid) { - case LBF_MATH_LOG: - LUAU_ASSERT((nparams == 1 || nparams == 2) && nresults == 1); - // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that - if (nparams == 2) - emitInvokeLibm1(build, offsetof(NativeContext, libm_log2), res, arg); - else - emitInvokeLibm1(build, offsetof(NativeContext, libm_log), res, arg); - return true; case LBF_MATH_LDEXP: LUAU_ASSERT(nparams == 2 && nresults == 1); - emitInvokeLibm2(build, offsetof(NativeContext, libm_ldexp), res, arg, args, /* argsInt= */ true); + + if (args.kind == IrOpKind::VmReg) + { + build.ldr(d1, mem(rBase, args.index * sizeof(TValue) + offsetof(TValue, value.n))); + build.fcvtzs(w0, d1); + } + else if (args.kind == IrOpKind::VmConst) + { + size_t constantOffset = args.index * sizeof(TValue) + offsetof(TValue, value.n); + + // Note: cumulative offset is guaranteed to be divisible by 8 (since we're loading a double); we can use that to expand the useful range + // that doesn't require temporaries + if (constantOffset / 8 <= AddressA64::kMaxOffset) + { + build.ldr(d1, mem(rConstants, int(constantOffset))); + } + else + { + emitAddOffset(build, x0, rConstants, constantOffset); + build.ldr(d1, x0); + } + + build.fcvtzs(w0, d1); + } + else if (args.kind == IrOpKind::Constant) + build.mov(w0, int(function.doubleOp(args))); + else if (args.kind != IrOpKind::Undef) + LUAU_ASSERT(!"Unsupported instruction form"); + + build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); + build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, libm_ldexp))); + build.blr(x1); + build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); return true; case LBF_MATH_FREXP: LUAU_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); @@ -233,14 +218,22 @@ IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, , data(data) , proto(proto) , function(function) - , regs(function, {{x0, x15}, {q0, q7}, {q16, q31}}) + , regs(function, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) + , valueTracker(function) { // In order to allocate registers during lowering, we need to know where instruction results are last used updateLastUseLocations(function); + + valueTracker.setRestoreCallack(this, [](void* context, IrInst& inst) { + IrLoweringA64* self = static_cast(context); + self->regs.restoreReg(self->build, inst); + }); } void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { + valueTracker.beforeInstLowering(inst); + switch (inst.cmd) { case IrCmd::LOAD_TAG: @@ -299,7 +292,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else if (inst.b.kind == IrOpKind::Constant) { - // TODO: refactor into a common helper? 
can't use emitAddOffset because we need a temp register if (intOp(inst.b) * sizeof(TValue) <= AssemblyBuilderA64::kMaxImmediate) { build.add(inst.regA64, inst.regA64, uint16_t(intOp(inst.b) * sizeof(TValue))); @@ -387,6 +379,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.str(temp, addr); break; } + case IrCmd::STORE_VECTOR: + { + RegisterA64 temp1 = tempDouble(inst.b); + RegisterA64 temp2 = tempDouble(inst.c); + RegisterA64 temp3 = tempDouble(inst.d); + RegisterA64 temp4 = regs.allocTemp(KindA64::s); + + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); + LUAU_ASSERT(addr.kind == AddressKindA64::imm && addr.data % 4 == 0 && unsigned(addr.data + 8) / 4 <= AddressA64::kMaxOffset); + + build.fcvt(temp4, temp1); + build.str(temp4, AddressA64(addr.base, addr.data + 0)); + build.fcvt(temp4, temp2); + build.str(temp4, AddressA64(addr.base, addr.data + 4)); + build.fcvt(temp4, temp3); + build.str(temp4, AddressA64(addr.base, addr.data + 8)); + break; + } case IrCmd::STORE_TVALUE: { AddressA64 addr = tempAddr(inst.a, 0); @@ -400,6 +410,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant && unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) build.add(inst.regA64, regOp(inst.a), uint16_t(intOp(inst.b))); + else if (inst.a.kind == IrOpKind::Constant && unsigned(intOp(inst.a)) <= AssemblyBuilderA64::kMaxImmediate) + build.add(inst.regA64, regOp(inst.b), uint16_t(intOp(inst.a))); else { RegisterA64 temp = tempInt(inst.b); @@ -459,21 +471,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fsub(inst.regA64, temp1, inst.regA64); break; } - case IrCmd::POW_NUM: - { - RegisterA64 temp1 = tempDouble(inst.a); - RegisterA64 temp2 = tempDouble(inst.b); - build.fmov(d0, temp1); // TODO: aliasing hazard - build.fmov(d1, temp2); // TODO: aliasing hazard - regs.spill(build, index, {d0, d1}); - build.ldr(x0, mem(rNativeContext, offsetof(NativeContext, libm_pow))); - build.blr(x0); - - // TODO: we could takeReg d0 but it's unclear if we will be able to keep d0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); - build.fmov(inst.regA64, d0); - break; - } case IrCmd::MIN_NUM: { inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); @@ -635,8 +632,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; case IrCmd::JUMP_GE_UINT: { - LUAU_ASSERT(uintOp(inst.b) <= AssemblyBuilderA64::kMaxImmediate); - build.cmp(regOp(inst.a), uint16_t(uintOp(inst.b))); + LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); + build.cmp(regOp(inst.a), uint16_t(unsigned(intOp(inst.b)))); build.b(ConditionA64::CarrySet, labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.d), next); break; @@ -723,8 +720,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::TABLE_LEN: { - build.mov(x0, regOp(inst.a)); // TODO: aliasing hazard - regs.spill(build, index, {x0}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + regs.spill(build, index, {reg}); + build.mov(x0, reg); build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaH_getn))); build.blr(x1); inst.regA64 = regs.allocReg(KindA64::d, index); @@ -739,21 +737,18 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) 
build.mov(x2, uintOp(inst.b)); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaH_new))); build.blr(x3); - // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReg(KindA64::x, index); - build.mov(inst.regA64, x0); + inst.regA64 = regs.takeReg(x0, index); break; } case IrCmd::DUP_TABLE: { - build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard - regs.spill(build, index, {x1}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + regs.spill(build, index, {reg}); + build.mov(x1, reg); build.mov(x0, rState); build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaH_clone))); build.blr(x2); - // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); - build.mov(inst.regA64, x0); + inst.regA64 = regs.takeReg(x0, index); break; } case IrCmd::TRY_NUM_TO_INDEX: @@ -789,17 +784,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.tst(temp2, 1 << intOp(inst.b)); // can't use tbz/tbnz because their jump offsets are too short build.b(ConditionA64::NotEqual, labelOp(inst.c)); // Equal = Zero after tst; tmcache caches *absence* of metamethods + regs.spill(build, index, {temp1}); build.mov(x0, temp1); - regs.spill(build, index, {x0}); build.mov(w1, intOp(inst.b)); build.ldr(x2, mem(rState, offsetof(lua_State, global))); build.ldr(x2, mem(x2, offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaT_gettm))); build.blr(x3); - - // TODO: we could takeReg x0 but it's unclear if we will be able to keep x0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); - build.mov(inst.regA64, x0); + inst.regA64 = regs.takeReg(x0, index); break; } case IrCmd::INT_TO_NUM: @@ -861,9 +853,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::FASTCALL: regs.spill(build, index); - // TODO: emitBuiltin should be exhaustive - if (!emitBuiltin(build, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f))) - error = true; + error |= emitBuiltin(build, function, regs, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); break; case IrCmd::INVOKE_FASTCALL: { @@ -878,7 +868,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else if (inst.d.kind == IrOpKind::VmConst) emitAddOffset(build, x4, rConstants, vmConstOp(inst.d) * sizeof(TValue)); else - LUAU_ASSERT(boolOp(inst.d) == false); + LUAU_ASSERT(inst.d.kind == IrOpKind::Undef); // nparams if (intOp(inst.e) == LUA_MULTRET) @@ -1047,10 +1037,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; checkObjectBarrierConditions(build, temp1, temp2, vmRegOp(inst.b), skip); - build.mov(x1, temp1); // TODO: aliasing hazard - - size_t spills = regs.spill(build, index, {x1}); + size_t spills = regs.spill(build, index, {temp1}); + build.mov(x1, temp1); build.mov(x0, rState); build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); @@ -1108,7 +1097,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(temp, regOp(inst.b)); else if (inst.b.kind 
== IrOpKind::Constant) { - // TODO: refactor into a common helper? if (size_t(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) { build.cmp(temp, uint16_t(intOp(inst.b))); @@ -1159,17 +1147,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::INTERRUPT: { - unsigned int pcpos = uintOp(inst.a); + RegisterA64 temp = regs.allocTemp(KindA64::x); Label skip, next; - build.ldr(x2, mem(rState, offsetof(lua_State, global))); - build.ldr(x2, mem(x2, offsetof(global_State, cb.interrupt))); - build.cbz(x2, skip); + build.ldr(temp, mem(rState, offsetof(lua_State, global))); + build.ldr(temp, mem(temp, offsetof(global_State, cb.interrupt))); + build.cbz(temp, skip); - size_t spills = regs.spill(build, index, {x2}); + size_t spills = regs.spill(build, index); // Jump to outlined interrupt handler, it will give back control to x1 - build.mov(x0, (pcpos + 1) * sizeof(Instruction)); + build.mov(x0, (uintOp(inst.a) + 1) * sizeof(Instruction)); build.adr(x1, next); build.b(helpers.interrupt); @@ -1182,7 +1170,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::CHECK_GC: { - regs.spill(build, index); RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp2 = regs.allocTemp(KindA64::x); @@ -1193,12 +1180,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(temp1, temp2); build.b(ConditionA64::UnsignedGreater, skip); + size_t spills = regs.spill(build, index); + build.mov(x0, rState); build.mov(w1, 1); build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaC_step))); build.blr(x1); emitUpdateBase(build); + + regs.restore(build, spills); // need to restore before skip so that registers are in a consistent state + build.setLabel(skip); break; } @@ -1209,8 +1201,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), skip); - build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard - size_t spills = regs.spill(build, index, {x1}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + size_t spills = regs.spill(build, index, {reg}); + build.mov(x1, reg); build.mov(x0, rState); build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); @@ -1231,8 +1224,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.ldrb(temp, mem(regOp(inst.a), offsetof(GCheader, marked))); build.tbz(temp, BLACKBIT, skip); - build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard - size_t spills = regs.spill(build, index, {x1}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + size_t spills = regs.spill(build, index, {reg}); + build.mov(x1, reg); build.mov(x0, rState); build.add(x2, x1, uint16_t(offsetof(Table, gclist))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierback))); @@ -1251,8 +1245,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), skip); - build.mov(x1, regOp(inst.a)); // TODO: aliasing hazard - size_t spills = regs.spill(build, index, {x1}); + RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + size_t spills = 
regs.spill(build, index, {reg}); + build.mov(x1, reg); build.mov(x0, rState); build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barriertable))); @@ -1290,8 +1285,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(temp2, temp1); build.b(ConditionA64::UnsignedGreater, skip); - build.mov(x1, temp2); // TODO: aliasing hazard - size_t spills = regs.spill(build, index, {x1}); + size_t spills = regs.spill(build, index, {temp2}); + build.mov(x1, temp2); build.mov(x0, rState); build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaF_close))); build.blr(x2); @@ -1484,8 +1479,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITAND_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(uintOp(inst.b))) - build.and_(inst.regA64, regOp(inst.a), uintOp(inst.b)); + if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + build.and_(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1497,8 +1492,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITXOR_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(uintOp(inst.b))) - build.eor(inst.regA64, regOp(inst.a), uintOp(inst.b)); + if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + build.eor(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1510,8 +1505,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITOR_UINT: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); - if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(uintOp(inst.b))) - build.orr(inst.regA64, regOp(inst.a), uintOp(inst.b)); + if (inst.b.kind == IrOpKind::Constant && AssemblyBuilderA64::isMaskSupported(unsigned(intOp(inst.b)))) + build.orr(inst.regA64, regOp(inst.a), unsigned(intOp(inst.b))); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1531,7 +1526,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant) - build.lsl(inst.regA64, regOp(inst.a), uint8_t(uintOp(inst.b) & 31)); + build.lsl(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1544,7 +1539,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant) - build.lsr(inst.regA64, regOp(inst.a), uint8_t(uintOp(inst.b) & 31)); + build.lsr(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1557,7 +1552,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant) - build.asr(inst.regA64, regOp(inst.a), uint8_t(uintOp(inst.b) & 31)); + build.asr(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { 
RegisterA64 temp1 = tempUint(inst.a); @@ -1571,7 +1566,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.b.kind == IrOpKind::Constant) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a}); - build.ror(inst.regA64, regOp(inst.a), uint8_t((32 - uintOp(inst.b)) & 31)); + build.ror(inst.regA64, regOp(inst.a), uint8_t((32 - unsigned(intOp(inst.b))) & 31)); } else { @@ -1587,7 +1582,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant) - build.ror(inst.regA64, regOp(inst.a), uint8_t(uintOp(inst.b) & 31)); + build.ror(inst.regA64, regOp(inst.a), uint8_t(unsigned(intOp(inst.b)) & 31)); else { RegisterA64 temp1 = tempUint(inst.a); @@ -1613,39 +1608,51 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::INVOKE_LIBM: { - RegisterA64 temp1 = tempDouble(inst.b); - - build.fmov(d0, temp1); // TODO: aliasing hazard - if (inst.c.kind != IrOpKind::None) { + RegisterA64 temp1 = tempDouble(inst.b); RegisterA64 temp2 = tempDouble(inst.c); - build.fmov(d1, temp2); // TODO: aliasing hazard - regs.spill(build, index, {d0, d1}); + RegisterA64 temp3 = regs.allocTemp(KindA64::d); // note: spill() frees all registers so we need to avoid alloc after spill + regs.spill(build, index, {temp1, temp2}); + + if (d0 != temp2) + { + build.fmov(d0, temp1); + build.fmov(d1, temp2); + } + else + { + build.fmov(temp3, d0); + build.fmov(d0, temp1); + build.fmov(d1, temp3); + } } else - regs.spill(build, index, {d0}); + { + RegisterA64 temp1 = tempDouble(inst.b); + regs.spill(build, index, {temp1}); + build.fmov(d0, temp1); + } - build.ldr(x0, mem(rNativeContext, getNativeContextOffset(LuauBuiltinFunction(uintOp(inst.a))))); + build.ldr(x0, mem(rNativeContext, getNativeContextOffset(uintOp(inst.a)))); build.blr(x0); - // TODO: we could takeReg d0 but it's unclear if we will be able to keep d0 allocatable due to aliasing concerns - inst.regA64 = regs.allocReg(KindA64::d, index); - build.fmov(inst.regA64, d0); + inst.regA64 = regs.takeReg(d0, index); break; } - // Unsupported instructions - // Note: when adding implementations for these, please move the case: label so that implemented instructions match the order in IrData.h - case IrCmd::STORE_VECTOR: - error = true; - break; + // To handle unsupported instructions, add "case IrCmd::OP" and make sure to set error = true! 
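// --- Reference sketch (illustration only; `rotr32`/`rotl32` are hypothetical names) ---
// The constant-amount BITLROTATE_UINT case above emits a single ror, relying on
// the identity rotl(x, n) == rotr(x, (32 - n) & 31); a plain C++ restatement:
#include <cstdint>

static uint32_t rotr32(uint32_t x, unsigned n)
{
    n &= 31;
    return (x >> n) | (x << ((32 - n) & 31));
}

static uint32_t rotl32(uint32_t x, unsigned n)
{
    return rotr32(x, (32 - n) & 31); // same (32 - n) & 31 form as the emitted ror amount
}
// ----------------------------------------------------------------------------------------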
} + valueTracker.afterInstLowering(inst, index); + regs.freeLastUseRegs(inst, index); regs.freeTempRegs(); } -void IrLoweringA64::finishBlock() {} +void IrLoweringA64::finishBlock() +{ + regs.assertNoSpills(); +} bool IrLoweringA64::hasError() const { @@ -1717,7 +1724,7 @@ RegisterA64 IrLoweringA64::tempUint(IrOp op) else if (op.kind == IrOpKind::Constant) { RegisterA64 temp = regs.allocTemp(KindA64::w); - build.mov(temp, uintOp(op)); + build.mov(temp, unsigned(intOp(op))); return temp; } else @@ -1762,7 +1769,7 @@ RegisterA64 IrLoweringA64::regOp(IrOp op) { IrInst& inst = function.instOp(op); - if (inst.spilled) + if (inst.spilled || inst.needsReload) regs.restoreReg(build, inst); LUAU_ASSERT(inst.regA64 != noreg); diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index d5f3a55..9eda897 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -5,6 +5,7 @@ #include "Luau/IrData.h" #include "IrRegAllocA64.h" +#include "IrValueLocationTracking.h" #include @@ -64,6 +65,8 @@ struct IrLoweringA64 IrRegAllocA64 regs; + IrValueLocationTracking valueTracker; + bool error = false; }; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 9af3f73..bc61757 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -237,17 +237,38 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.vmovups(luauNodeValue(regOp(inst.a)), regOp(inst.b)); break; case IrCmd::ADD_INT: + { inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.b.kind == IrOpKind::Inst) - build.lea(inst.regX64, addr[regOp(inst.a) + regOp(inst.b)]); - else if (inst.regX64 == regOp(inst.a) && intOp(inst.b) == 1) - build.inc(inst.regX64); - else if (inst.regX64 == regOp(inst.a)) - build.add(inst.regX64, intOp(inst.b)); + if (inst.a.kind == IrOpKind::Constant) + { + build.lea(inst.regX64, addr[regOp(inst.b) + intOp(inst.a)]); + } + else if (inst.a.kind == IrOpKind::Inst) + { + if (inst.regX64 == regOp(inst.a)) + { + if (inst.b.kind == IrOpKind::Inst) + build.add(inst.regX64, regOp(inst.b)); + else if (intOp(inst.b) == 1) + build.inc(inst.regX64); + else + build.add(inst.regX64, intOp(inst.b)); + } + else + { + if (inst.b.kind == IrOpKind::Inst) + build.lea(inst.regX64, addr[regOp(inst.a) + regOp(inst.b)]); + else + build.lea(inst.regX64, addr[regOp(inst.a) + intOp(inst.b)]); + } + } else - build.lea(inst.regX64, addr[regOp(inst.a) + intOp(inst.b)]); + { + LUAU_ASSERT(!"Unsupported instruction form"); + } break; + } case IrCmd::SUB_INT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); @@ -359,15 +380,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; } - case IrCmd::POW_NUM: - { - IrCallWrapperX64 callWrap(regs, build, index); - callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.a), inst.a); - callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); - inst.regX64 = regs.takeReg(xmm0, index); - break; - } case IrCmd::MIN_NUM: inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); @@ -537,7 +549,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.d), next); break; case IrCmd::JUMP_GE_UINT: - build.cmp(regOp(inst.a), uintOp(inst.b)); + build.cmp(regOp(inst.a), unsigned(intOp(inst.b))); build.jcc(ConditionX64::AboveEqual, labelOp(inst.c)); 
jumpOrFallthrough(blockOp(inst.d), next); @@ -690,8 +702,12 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::FASTCALL: - emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), inst.d, intOp(inst.e), intOp(inst.f)); + { + OperandX64 arg2 = inst.d.kind != IrOpKind::Undef ? memRegDoubleOp(inst.d) : OperandX64{0}; + + emitBuiltin(regs, build, uintOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c), arg2, intOp(inst.e), intOp(inst.f)); break; + } case IrCmd::INVOKE_FASTCALL: { unsigned bfid = uintOp(inst.a); @@ -703,7 +719,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else if (inst.d.kind == IrOpKind::VmConst) args = luauConstantAddress(vmConstOp(inst.d)); else - LUAU_ASSERT(boolOp(inst.d) == false); + LUAU_ASSERT(inst.d.kind == IrOpKind::Undef); int ra = vmRegOp(inst.b); int arg = vmRegOp(inst.c); @@ -1141,32 +1157,32 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::BITAND_UINT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.and_(inst.regX64, memRegUintOp(inst.b)); break; case IrCmd::BITXOR_UINT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.xor_(inst.regX64, memRegUintOp(inst.b)); break; case IrCmd::BITOR_UINT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.or_(inst.regX64, memRegUintOp(inst.b)); break; case IrCmd::BITNOT_UINT: inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); - if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.not_(inst.regX64); break; @@ -1179,10 +1195,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.shl(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1196,10 +1210,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.shr(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1213,10 +1225,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != 
regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.sar(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1230,10 +1240,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.rol(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1247,10 +1255,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(shiftTmp.reg, memRegUintOp(inst.b)); - if (inst.a.kind == IrOpKind::Constant) - build.mov(inst.regX64, uintOp(inst.a)); - else if (inst.regX64 != regOp(inst.a)) - build.mov(inst.regX64, regOp(inst.a)); + if (inst.a.kind != IrOpKind::Inst || inst.regX64 != regOp(inst.a)) + build.mov(inst.regX64, memRegUintOp(inst.a)); build.ror(inst.regX64, byteReg(shiftTmp.reg)); break; @@ -1294,15 +1300,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::INVOKE_LIBM: { - LuauBuiltinFunction bfid = LuauBuiltinFunction(uintOp(inst.a)); - IrCallWrapperX64 callWrap(regs, build, index); callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.b), inst.b); if (inst.c.kind != IrOpKind::None) callWrap.addArgument(SizeX64::xmmword, memRegDoubleOp(inst.c), inst.c); - callWrap.call(qword[rNativeContext + getNativeContextOffset(bfid)]); + callWrap.call(qword[rNativeContext + getNativeContextOffset(uintOp(inst.a))]); inst.regX64 = regs.takeReg(xmm0, index); break; } @@ -1370,7 +1374,7 @@ OperandX64 IrLoweringX64::memRegUintOp(IrOp op) case IrOpKind::Inst: return regOp(op); case IrOpKind::Constant: - return OperandX64(uintOp(op)); + return OperandX64(unsigned(intOp(op))); default: LUAU_ASSERT(!"Unsupported operand kind"); } diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index 17ae70f..a4cfeae 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -2,12 +2,15 @@ #include "IrRegAllocA64.h" #include "Luau/AssemblyBuilderA64.h" +#include "Luau/IrUtils.h" #include "BitUtils.h" #include "EmitCommonA64.h" #include +LUAU_FASTFLAGVARIABLE(DebugLuauCodegenChaosA64, false) + namespace Luau { namespace CodeGen @@ -41,6 +44,68 @@ static void freeSpill(uint32_t& free, KindA64 kind, uint8_t slot) free |= mask; } +static int getReloadOffset(IrCmd cmd) +{ + switch (getCmdValueKind(cmd)) + { + case IrValueKind::Unknown: + case IrValueKind::None: + LUAU_ASSERT(!"Invalid operand restore value kind"); + break; + case IrValueKind::Tag: + return offsetof(TValue, tt); + case IrValueKind::Int: + return offsetof(TValue, value); + case IrValueKind::Pointer: + return offsetof(TValue, value.gc); + case IrValueKind::Double: + return offsetof(TValue, value.n); + case IrValueKind::Tvalue: + return 0; + } + + LUAU_ASSERT(!"Invalid operand restore value kind"); + LUAU_UNREACHABLE(); +} + +static AddressA64 getReloadAddress(const IrFunction& function, const IrInst& inst) +{ + IrOp location = function.findRestoreOp(inst); + + if (location.kind == IrOpKind::VmReg) + return mem(rBase, vmRegOp(location) * sizeof(TValue) + getReloadOffset(inst.cmd)); + + // loads are 4/8/16 bytes; we conservatively limit the offset to fit assuming a 4b index + if 
(location.kind == IrOpKind::VmConst && vmConstOp(location) * sizeof(TValue) <= AddressA64::kMaxOffset * 4) + return mem(rConstants, vmConstOp(location) * sizeof(TValue) + getReloadOffset(inst.cmd)); + + return AddressA64(xzr); // dummy +} + +static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrFunction& function, const IrRegAllocA64::Spill& s, RegisterA64 reg) +{ + IrInst& inst = function.instructions[s.inst]; + LUAU_ASSERT(inst.regA64 == noreg); + + if (s.slot >= 0) + { + build.ldr(reg, mem(sp, sSpillArea.data + s.slot * 8)); + + freeSpill(freeSpillSlots, reg.kind, s.slot); + } + else + { + LUAU_ASSERT(!inst.spilled && inst.needsReload); + AddressA64 addr = getReloadAddress(function, function.instructions[s.inst]); + LUAU_ASSERT(addr.base != xzr); + build.ldr(reg, addr); + } + + inst.spilled = false; + inst.needsReload = false; + inst.regA64 = reg; +} + IrRegAllocA64::IrRegAllocA64(IrFunction& function, std::initializer_list> regs) : function(function) { @@ -70,11 +135,16 @@ RegisterA64 IrRegAllocA64::allocReg(KindA64 kind, uint32_t index) if (set.free == 0) { + // TODO: remember the error and fail lowering LUAU_ASSERT(!"Out of registers to allocate"); return noreg; } int reg = 31 - countlz(set.free); + + if (FFlag::DebugLuauCodegenChaosA64) + reg = countrz(set.free); // allocate from low end; this causes extra conflicts for calls + set.free &= ~(1u << reg); set.defs[reg] = index; @@ -87,12 +157,16 @@ RegisterA64 IrRegAllocA64::allocTemp(KindA64 kind) if (set.free == 0) { + // TODO: remember the error and fail lowering LUAU_ASSERT(!"Out of registers to allocate"); return noreg; } int reg = 31 - countlz(set.free); + if (FFlag::DebugLuauCodegenChaosA64) + reg = countrz(set.free); // allocate from low end; this causes extra conflicts for calls + set.free &= ~(1u << reg); set.temp |= 1u << reg; LUAU_ASSERT(set.defs[reg] == kInvalidInstIdx); @@ -109,8 +183,9 @@ RegisterA64 IrRegAllocA64::allocReuse(KindA64 kind, uint32_t index, std::initial IrInst& source = function.instructions[op.index]; - if (source.lastUse == index && !source.reusedReg && !source.spilled && source.regA64 != noreg) + if (source.lastUse == index && !source.reusedReg && source.regA64 != noreg) { + LUAU_ASSERT(!source.spilled && !source.needsReload); LUAU_ASSERT(source.regA64.kind == kind); Set& set = getSet(kind); @@ -154,7 +229,7 @@ void IrRegAllocA64::freeLastUseReg(IrInst& target, uint32_t index) { if (target.lastUse == index && !target.reusedReg) { - LUAU_ASSERT(!target.spilled); + LUAU_ASSERT(!target.spilled && !target.needsReload); // Register might have already been freed if it had multiple uses inside a single instruction if (target.regA64 == noreg) @@ -197,13 +272,19 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init size_t start = spills.size(); - for (RegisterA64 reg : live) - { - Set& set = getSet(reg.kind); + uint32_t poisongpr = 0; + uint32_t poisonsimd = 0; - // make sure registers that we expect to survive past spill barrier are not allocated - // TODO: we need to handle this condition somehow in the future; if this fails, this likely means the caller has an aliasing hazard - LUAU_ASSERT(set.free & (1u << reg.index)); + if (FFlag::DebugLuauCodegenChaosA64) + { + poisongpr = gpr.base & ~gpr.free; + poisonsimd = simd.base & ~simd.free; + + for (RegisterA64 reg : live) + { + Set& set = getSet(reg.kind); + (&set == &simd ? 
poisonsimd : poisongpr) &= ~(1u << reg.index); + } } for (KindA64 kind : sets) @@ -231,26 +312,38 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init IrInst& def = function.instructions[inst]; LUAU_ASSERT(def.regA64.index == reg); - LUAU_ASSERT(!def.spilled); LUAU_ASSERT(!def.reusedReg); + LUAU_ASSERT(!def.spilled); + LUAU_ASSERT(!def.needsReload); if (def.lastUse == index) { // instead of spilling the register to never reload it, we assume the register is not needed anymore - def.regA64 = noreg; + } + else if (getReloadAddress(function, def).base != xzr) + { + // instead of spilling the register to stack, we can reload it from VM stack/constants + // we still need to record the spill for restore(start) to work + Spill s = {inst, def.regA64, -1}; + spills.push_back(s); + + def.needsReload = true; } else { int slot = allocSpill(freeSpillSlots, def.regA64.kind); LUAU_ASSERT(slot >= 0); // TODO: remember the error and fail lowering - Spill s = {inst, def.regA64, uint8_t(slot)}; + build.str(def.regA64, mem(sp, sSpillArea.data + slot * 8)); + + Spill s = {inst, def.regA64, int8_t(slot)}; spills.push_back(s); def.spilled = true; - def.regA64 = noreg; } + def.regA64 = noreg; + regs &= ~(1u << reg); set.free |= 1u << reg; set.defs[reg] = kInvalidInstIdx; @@ -259,11 +352,15 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init LUAU_ASSERT(set.free == set.base); } - if (start < spills.size()) + if (FFlag::DebugLuauCodegenChaosA64) { - // TODO: use stp for consecutive slots - for (size_t i = start; i < spills.size(); ++i) - build.str(spills[i].origin, mem(sp, sSpillArea.data + spills[i].slot * 8)); + for (int reg = 0; reg < 32; ++reg) + { + if (poisongpr & (1u << reg)) + build.mov(RegisterA64{KindA64::x, uint8_t(reg)}, 0xdead); + if (poisonsimd & (1u << reg)) + build.fmov(RegisterA64{KindA64::d, uint8_t(reg)}, -0.125); + } } return start; @@ -275,22 +372,12 @@ void IrRegAllocA64::restore(AssemblyBuilderA64& build, size_t start) if (start < spills.size()) { - // TODO: use ldp for consecutive slots - for (size_t i = start; i < spills.size(); ++i) - build.ldr(spills[i].origin, mem(sp, sSpillArea.data + spills[i].slot * 8)); - for (size_t i = start; i < spills.size(); ++i) { Spill s = spills[i]; // copy in case takeReg reallocates spills + RegisterA64 reg = takeReg(s.origin, s.inst); - IrInst& def = function.instructions[s.inst]; - LUAU_ASSERT(def.spilled); - LUAU_ASSERT(def.regA64 == noreg); - - def.spilled = false; - def.regA64 = takeReg(s.origin, s.inst); - - freeSpill(freeSpillSlots, s.origin.kind, s.slot); + restoreInst(build, freeSpillSlots, function, s, reg); } spills.resize(start); @@ -301,9 +388,6 @@ void IrRegAllocA64::restoreReg(AssemblyBuilderA64& build, IrInst& inst) { uint32_t index = function.getInstIndex(inst); - LUAU_ASSERT(inst.spilled); - LUAU_ASSERT(inst.regA64 == noreg); - for (size_t i = 0; i < spills.size(); ++i) { if (spills[i].inst == index) @@ -311,12 +395,7 @@ void IrRegAllocA64::restoreReg(AssemblyBuilderA64& build, IrInst& inst) Spill s = spills[i]; // copy in case allocReg reallocates spills RegisterA64 reg = allocReg(s.origin.kind, index); - build.ldr(reg, mem(sp, sSpillArea.data + s.slot * 8)); - - inst.spilled = false; - inst.regA64 = reg; - - freeSpill(freeSpillSlots, reg.kind, s.slot); + restoreInst(build, freeSpillSlots, function, s, reg); spills[i] = spills.back(); spills.pop_back(); @@ -340,6 +419,7 @@ IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind) case KindA64::w: return gpr; + case KindA64::s: 
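// --- Reference sketch (illustration only; `SpillKind`/`chooseSpill` are hypothetical names) ---
// The spill() changes above pick one of three outcomes for each live value at a
// call site, assuming "reloadable" means getReloadAddress found a VM register or
// constant-table slot from which the value can be re-read:
enum class SpillKind
{
    Drop,          // last use is this instruction, so the value is not kept at all
    Rematerialize, // reload later from rBase/rConstants; the spill record keeps slot == -1
    StackSlot      // otherwise store to a frame spill slot now and ldr it back on restore
};

static SpillKind chooseSpill(bool lastUseIsHere, bool reloadableFromVmMemory)
{
    if (lastUseIsHere)
        return SpillKind::Drop;
    if (reloadableFromVmMemory)
        return SpillKind::Rematerialize;
    return SpillKind::StackSlot;
}
// ------------------------------------------------------------------------------------------------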
case KindA64::d: case KindA64::q: return simd; diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h index 940b511..6897437 100644 --- a/CodeGen/src/IrRegAllocA64.h +++ b/CodeGen/src/IrRegAllocA64.h @@ -65,7 +65,7 @@ struct IrRegAllocA64 uint32_t inst; RegisterA64 origin; - uint8_t slot; + int8_t slot; }; Set& getSet(KindA64 kind); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index a6da1e4..e58d0a1 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -6,6 +6,8 @@ #include "lstate.h" +#include + // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results static const int kMinMaxUnrolledParams = 5; @@ -16,16 +18,32 @@ namespace Luau namespace CodeGen { +static void builtinCheckDouble(IrBuilder& build, IrOp arg, IrOp fallback) +{ + if (arg.kind == IrOpKind::Constant) + LUAU_ASSERT(build.function.constOp(arg).kind == IrConstKind::Double); + else + build.loadAndCheckTag(arg, LUA_TNUMBER, fallback); +} + +static IrOp builtinLoadDouble(IrBuilder& build, IrOp arg) +{ + if (arg.kind == IrOpKind::Constant) + return arg; + + return build.inst(IrCmd::LOAD_DOUBLE, arg); +} + // Wrapper code for all builtins with a fixed signature and manual assembly lowering of the body // (number, ...) -> number -BuiltinImplResult translateBuiltinNumberToNumber( +static BuiltinImplResult translateBuiltinNumberToNumber( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); if (ra != arg) @@ -34,14 +52,14 @@ BuiltinImplResult translateBuiltinNumberToNumber( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinNumberToNumberLibm( +static BuiltinImplResult translateBuiltinNumberToNumberLibm( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + builtinCheckDouble(build, build.vmReg(arg), fallback); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va); @@ -54,14 +72,14 @@ BuiltinImplResult translateBuiltinNumberToNumberLibm( } // (number, number, ...) 
-> number -BuiltinImplResult translateBuiltin2NumberToNumber( +static BuiltinImplResult translateBuiltin2NumberToNumber( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(2), build.constInt(1)); if (ra != arg) @@ -70,17 +88,17 @@ BuiltinImplResult translateBuiltin2NumberToNumber( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltin2NumberToNumberLibm( +static BuiltinImplResult translateBuiltin2NumberToNumberLibm( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(bfid), va, vb); @@ -93,13 +111,13 @@ BuiltinImplResult translateBuiltin2NumberToNumberLibm( } // (number, ...) -> (number, number) -BuiltinImplResult translateBuiltinNumberTo2Number( +static BuiltinImplResult translateBuiltinNumberTo2Number( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 2) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); build.inst( IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(nresults == 1 ? 
1 : 2)); @@ -112,7 +130,7 @@ BuiltinImplResult translateBuiltinNumberTo2Number( return {BuiltinImplType::UsesFallback, 2}; } -BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults != 0) return {BuiltinImplType::None, -1}; @@ -126,16 +144,16 @@ BuiltinImplResult translateBuiltinAssert(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 0}; } -BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); const double rpd = (3.14159265358979323846 / 180.0); - IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg = builtinLoadDouble(build, build.vmReg(arg)); IrOp value = build.inst(IrCmd::DIV_NUM, varg, build.constDouble(rpd)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); @@ -145,16 +163,16 @@ BuiltinImplResult translateBuiltinMathDeg(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); const double rpd = (3.14159265358979323846 / 180.0); - IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg = builtinLoadDouble(build, build.vmReg(arg)); IrOp value = build.inst(IrCmd::MUL_NUM, varg, build.constDouble(rpd)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); @@ -164,48 +182,40 @@ BuiltinImplResult translateBuiltinMathRad(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathLog( +static BuiltinImplResult translateBuiltinMathLog( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - LuauBuiltinFunction fcId = bfid; - int fcParams = 1; + int libmId = bfid; + std::optional denom; if (nparams != 1) { - if (args.kind != IrOpKind::VmConst) + std::optional y = build.function.asDoubleOp(args); + + if (!y) return {BuiltinImplType::None, -1}; - LUAU_ASSERT(build.function.proto); - TValue protok = build.function.proto->k[vmConstOp(args)]; - - if (protok.tt != LUA_TNUMBER) - return {BuiltinImplType::None, -1}; - - // TODO: IR builtin lowering assumes that the only valid 2-argument call is log2; ideally, we use a less hacky way to indicate that - if (protok.value.n == 2.0) - fcParams = 2; - else if (protok.value.n == 10.0) - fcId = LBF_MATH_LOG10; + if (*y == 2.0) + libmId = LBF_IR_MATH_LOG2; + else if (*y == 10.0) + libmId = LBF_MATH_LOG10; else - // TODO: We can precompute log(args) and divide by it, but that 
requires extra LOAD/STORE so for now just fall back as this is rare - return {BuiltinImplType::None, -1}; + denom = log(*y); } - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); - if (fcId == LBF_MATH_LOG10) - { - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); - IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(fcId), va); + IrOp res = build.inst(IrCmd::INVOKE_LIBM, build.constUint(libmId), va); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); - } - else - build.inst(IrCmd::FASTCALL, build.constUint(fcId), build.vmReg(ra), build.vmReg(arg), args, build.constInt(fcParams), build.constInt(1)); + if (denom) + res = build.inst(IrCmd::DIV_NUM, res, build.constDouble(*denom)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), res); if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -213,25 +223,25 @@ BuiltinImplResult translateBuiltinMathLog( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); for (int i = 3; i <= nparams; ++i) - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + (i - 2)), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); - IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp varg1 = builtinLoadDouble(build, build.vmReg(arg)); + IrOp varg2 = builtinLoadDouble(build, args); IrOp res = build.inst(IrCmd::MIN_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins for (int i = 3; i <= nparams; ++i) { - IrOp arg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + (i - 2))); + IrOp arg = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); res = build.inst(IrCmd::MIN_NUM, arg, res); } @@ -243,25 +253,25 @@ BuiltinImplResult translateBuiltinMathMin(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nparams > kMinMaxUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); for (int i = 3; i <= nparams; ++i) - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + (i - 2)), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); - IrOp varg1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp varg2 = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp varg1 = 
builtinLoadDouble(build, build.vmReg(arg)); + IrOp varg2 = builtinLoadDouble(build, args); IrOp res = build.inst(IrCmd::MAX_NUM, varg2, varg1); // Swapped arguments are required for consistency with VM builtins for (int i = 3; i <= nparams; ++i) { - IrOp arg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + (i - 2))); + IrOp arg = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); res = build.inst(IrCmd::MAX_NUM, arg, res); } @@ -273,7 +283,7 @@ BuiltinImplResult translateBuiltinMathMax(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -282,17 +292,17 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r LUAU_ASSERT(args.kind == IrOpKind::VmReg); - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), fallback); - IrOp min = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp max = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); + IrOp min = builtinLoadDouble(build, args); + IrOp max = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::JUMP_CMP_NUM, min, max, build.cond(IrCondition::NotLessEqual), fallback, block); build.beginBlock(block); - IrOp v = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp v = builtinLoadDouble(build, build.vmReg(arg)); IrOp r = build.inst(IrCmd::MAX_NUM, min, v); IrOp clamped = build.inst(IrCmd::MIN_NUM, max, r); @@ -304,14 +314,14 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); - IrOp varg = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp varg = builtinLoadDouble(build, build.vmReg(arg)); IrOp result = build.inst(cmd, varg); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); @@ -322,27 +332,7 @@ BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int npa return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinMathBinary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) -{ - if (nparams < 2 || nresults > 1) - return {BuiltinImplType::None, -1}; - - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - - IrOp lhs = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp rhs = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp result = build.inst(cmd, lhs, rhs); - - 
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), result); - - if (ra != arg) - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); - - return {BuiltinImplType::UsesFallback, 1}; -} - -BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -354,7 +344,7 @@ BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, in return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; @@ -366,20 +356,20 @@ BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32BinaryOp( +static BuiltinImplResult translateBuiltinBit32BinaryOp( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nparams > kBit32BinaryOpUnrolledParams || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); for (int i = 3; i <= nparams; ++i) - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + (i - 2)), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + (i - 2)), fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); IrOp vbui = build.inst(IrCmd::NUM_TO_UINT, vb); @@ -399,7 +389,7 @@ BuiltinImplResult translateBuiltinBit32BinaryOp( for (int i = 3; i <= nparams; ++i) { - IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + (i - 2))); + IrOp vc = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + (i - 2))); IrOp arg = build.inst(IrCmd::NUM_TO_UINT, vc); res = build.inst(cmd, res, arg); @@ -436,14 +426,14 @@ BuiltinImplResult translateBuiltinBit32BinaryOp( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Bnot( +static BuiltinImplResult translateBuiltinBit32Bnot( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + builtinCheckDouble(build, build.vmReg(arg), fallback); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); IrOp not_ = build.inst(IrCmd::BITNOT_UINT, vaui); @@ -457,7 +447,7 @@ BuiltinImplResult translateBuiltinBit32Bnot( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Shift( +static BuiltinImplResult translateBuiltinBit32Shift( IrBuilder& build, 
LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) @@ -465,16 +455,16 @@ BuiltinImplResult translateBuiltinBit32Shift( IrOp block = build.block(IrBlockKind::Internal); - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb); - build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constUint(32), fallback, block); + build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constInt(32), fallback, block); build.beginBlock(block); IrCmd cmd = IrCmd::NOP; @@ -498,17 +488,17 @@ BuiltinImplResult translateBuiltinBit32Shift( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Rotate( +static BuiltinImplResult translateBuiltinBit32Rotate( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); IrOp vbi = build.inst(IrCmd::NUM_TO_INT, vb); @@ -525,17 +515,17 @@ BuiltinImplResult translateBuiltinBit32Rotate( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Extract( +static BuiltinImplResult translateBuiltinBit32Extract( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, args); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); @@ -544,17 +534,17 @@ BuiltinImplResult translateBuiltinBit32Extract( if (nparams == 2) { IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, f, build.constUint(32), fallback, block); + build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); build.beginBlock(block); // TODO: this can be optimized using a bit-select instruction (bt on x86) IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f); - value = build.inst(IrCmd::BITAND_UINT, shift, build.constUint(1)); + value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1)); } else { - build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); - IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, 
build.vmReg(args.index + 1)); + builtinCheckDouble(build, build.vmReg(args.index + 1), fallback); + IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vc); IrOp block1 = build.block(IrBlockKind::Internal); @@ -570,7 +560,7 @@ BuiltinImplResult translateBuiltinBit32Extract( build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); build.beginBlock(block3); - IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constUint(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); + IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); IrOp m = build.inst(IrCmd::BITNOT_UINT, shift); IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, f); @@ -585,15 +575,15 @@ BuiltinImplResult translateBuiltinBit32Extract( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32ExtractK( +static BuiltinImplResult translateBuiltinBit32ExtractK( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 2 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); double a2 = build.function.doubleOp(args); @@ -604,8 +594,8 @@ BuiltinImplResult translateBuiltinBit32ExtractK( uint32_t m = ~(0xfffffffeu << w1); - IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constUint(f)); - IrOp and_ = build.inst(IrCmd::BITAND_UINT, nf, build.constUint(m)); + IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f)); + IrOp and_ = build.inst(IrCmd::BITAND_UINT, nf, build.constInt(m)); IrOp value = build.inst(IrCmd::UINT_TO_NUM, and_); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); @@ -616,14 +606,14 @@ BuiltinImplResult translateBuiltinBit32ExtractK( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Countz( +static BuiltinImplResult translateBuiltinBit32Countz( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 1 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); + builtinCheckDouble(build, build.vmReg(arg), fallback); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); IrOp vaui = build.inst(IrCmd::NUM_TO_UINT, va); @@ -640,19 +630,19 @@ BuiltinImplResult translateBuiltinBit32Countz( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinBit32Replace( +static BuiltinImplResult translateBuiltinBit32Replace( IrBuilder& build, LuauBuiltinFunction bfid, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(args.index + 1), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(args.index + 1), fallback); - IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp vb = 
build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 1)); + IrOp va = builtinLoadDouble(build, build.vmReg(arg)); + IrOp vb = builtinLoadDouble(build, args); + IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); IrOp v = build.inst(IrCmd::NUM_TO_UINT, vb); @@ -662,11 +652,11 @@ BuiltinImplResult translateBuiltinBit32Replace( if (nparams == 3) { IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, f, build.constUint(32), fallback, block); + build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); build.beginBlock(block); // TODO: this can be optimized using a bit-select instruction (btr on x86) - IrOp m = build.constUint(1); + IrOp m = build.constInt(1); IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, m, f); IrOp not_ = build.inst(IrCmd::BITNOT_UINT, shift); IrOp lhs = build.inst(IrCmd::BITAND_UINT, n, not_); @@ -678,8 +668,8 @@ BuiltinImplResult translateBuiltinBit32Replace( } else { - build.loadAndCheckTag(build.vmReg(args.index + 2), LUA_TNUMBER, fallback); - IrOp vd = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(args.index + 2)); + builtinCheckDouble(build, build.vmReg(args.index + 2), fallback); + IrOp vd = builtinLoadDouble(build, build.vmReg(args.index + 2)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vd); IrOp block1 = build.block(IrBlockKind::Internal); @@ -695,7 +685,7 @@ BuiltinImplResult translateBuiltinBit32Replace( build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); build.beginBlock(block3); - IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constUint(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); + IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); IrOp m = build.inst(IrCmd::BITNOT_UINT, shift1); IrOp shift2 = build.inst(IrCmd::BITLSHIFT_UINT, m, f); @@ -716,20 +706,20 @@ BuiltinImplResult translateBuiltinBit32Replace( return {BuiltinImplType::UsesFallback, 1}; } -BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) { if (nparams < 3 || nresults > 1) return {BuiltinImplType::None, -1}; LUAU_ASSERT(LUA_VECTOR_SIZE == 3); - build.loadAndCheckTag(build.vmReg(arg), LUA_TNUMBER, fallback); - build.loadAndCheckTag(args, LUA_TNUMBER, fallback); - build.loadAndCheckTag(build.vmReg(vmRegOp(args) + 1), LUA_TNUMBER, fallback); + builtinCheckDouble(build, build.vmReg(arg), fallback); + builtinCheckDouble(build, args, fallback); + builtinCheckDouble(build, build.vmReg(vmRegOp(args) + 1), fallback); - IrOp x = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(arg)); - IrOp y = build.inst(IrCmd::LOAD_DOUBLE, args); - IrOp z = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(vmRegOp(args) + 1)); + IrOp x = builtinLoadDouble(build, build.vmReg(arg)); + IrOp y = builtinLoadDouble(build, args); + IrOp z = builtinLoadDouble(build, build.vmReg(vmRegOp(args) + 1)); build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), x, y, z); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); @@ -769,8 +759,6 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinMathUnary(build, IrCmd::ABS_NUM, nparams, ra, arg, nresults, fallback); case LBF_MATH_ROUND: return 
translateBuiltinMathUnary(build, IrCmd::ROUND_NUM, nparams, ra, arg, nresults, fallback); - case LBF_MATH_POW: - return translateBuiltinMathBinary(build, IrCmd::POW_NUM, nparams, ra, arg, args, nresults, fallback); case LBF_MATH_EXP: case LBF_MATH_ASIN: case LBF_MATH_SIN: @@ -785,6 +773,7 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinNumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); case LBF_MATH_SIGN: return translateBuiltinNumberToNumber(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_MATH_POW: case LBF_MATH_FMOD: case LBF_MATH_ATAN2: return translateBuiltin2NumberToNumberLibm(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index a42a726..ebbcd87 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -342,7 +342,7 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, result = build.inst(IrCmd::MOD_NUM, vb, vc); break; case TM_POW: - result = build.inst(IrCmd::POW_NUM, vb, vc); + result = build.inst(IrCmd::INVOKE_LIBM, build.constUint(LBF_MATH_POW), vb, vc); break; default: LUAU_ASSERT(!"unsupported binary op"); @@ -498,8 +498,6 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool int bfid = LUAU_INSN_A(*pc); int skip = LUAU_INSN_C(*pc); - IrOp fallback = build.block(IrBlockKind::Fallback); - Instruction call = pc[skip + 1]; LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); @@ -509,15 +507,21 @@ void translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool int arg = customParams ? LUAU_INSN_B(*pc) : ra + 1; IrOp args = customParams ? 
customArgs : build.vmReg(ra + 2); - build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + IrOp builtinArgs = args; - if (bfid == LBF_BIT32_EXTRACTK) + if (customArgs.kind == IrOpKind::VmConst) { - TValue protok = build.function.proto->k[pc[1]]; - args = build.constDouble(protok.value.n); + TValue protok = build.function.proto->k[customArgs.index]; + + if (protok.tt == LUA_TNUMBER) + builtinArgs = build.constDouble(protok.value.n); } - BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, args, nparams, nresults, fallback); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + + BuiltinImplResult br = translateBuiltin(build, LuauBuiltinFunction(bfid), ra, arg, builtinArgs, nparams, nresults, fallback); if (br.type == BuiltinImplType::UsesFallback) { diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 8be9e1b..a3af434 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -3,6 +3,7 @@ #include "Luau/IrBuilder.h" +#include "BitUtils.h" #include "NativeState.h" #include "lua.h" @@ -54,7 +55,6 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: @@ -312,6 +312,8 @@ void substitute(IrFunction& function, IrInst& inst, IrOp replacement) inst.cmd = IrCmd::SUBSTITUTE; + addUse(function, replacement); + removeUse(function, inst.a); removeUse(function, inst.b); removeUse(function, inst.c); @@ -349,6 +351,9 @@ void applySubstitutions(IrFunction& function, IrOp& op) LUAU_ASSERT(src.useCount > 0); src.useCount--; + + if (src.useCount == 0) + removeUse(function, src.a); } } } @@ -444,17 +449,13 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(luai_nummod(function.doubleOp(inst.a), function.doubleOp(inst.b)))); break; - case IrCmd::POW_NUM: - if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) - substitute(function, inst, build.constDouble(pow(function.doubleOp(inst.a), function.doubleOp(inst.b)))); - break; case IrCmd::MIN_NUM: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { double a1 = function.doubleOp(inst.a); double a2 = function.doubleOp(inst.b); - substitute(function, inst, build.constDouble((a2 < a1) ? a2 : a1)); + substitute(function, inst, build.constDouble(a1 < a2 ? a1 : a2)); } break; case IrCmd::MAX_NUM: @@ -463,7 +464,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 double a1 = function.doubleOp(inst.a); double a2 = function.doubleOp(inst.b); - substitute(function, inst, build.constDouble((a2 > a1) ? a2 : a1)); + substitute(function, inst, build.constDouble(a1 > a2 ? 
a1 : a2)); } break; case IrCmd::UNM_NUM: @@ -533,7 +534,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 case IrCmd::JUMP_GE_UINT: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { - if (function.uintOp(inst.a) >= function.uintOp(inst.b)) + if (unsigned(function.intOp(inst.a)) >= unsigned(function.intOp(inst.b))) replace(function, block, index, {IrCmd::JUMP, inst.c}); else replace(function, block, index, {IrCmd::JUMP, inst.d}); @@ -573,6 +574,30 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(double(function.intOp(inst.a)))); break; + case IrCmd::UINT_TO_NUM: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(double(unsigned(function.intOp(inst.a))))); + break; + case IrCmd::NUM_TO_INT: + if (inst.a.kind == IrOpKind::Constant) + { + double value = function.doubleOp(inst.a); + + // To avoid undefined behavior of casting a value not representable in the target type, we check the range + if (value >= INT_MIN && value <= INT_MAX) + substitute(function, inst, build.constInt(int(value))); + } + break; + case IrCmd::NUM_TO_UINT: + if (inst.a.kind == IrOpKind::Constant) + { + double value = function.doubleOp(inst.a); + + // To avoid undefined behavior of casting a value not representable in the target type, we check the range + if (value >= 0 && value <= UINT_MAX) + substitute(function, inst, build.constInt(unsigned(function.doubleOp(inst.a)))); + } + break; case IrCmd::CHECK_TAG: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { @@ -582,12 +607,139 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 replace(function, block, index, {IrCmd::JUMP, inst.c}); // Shows a conflict in assumptions on this path } break; + case IrCmd::BITAND_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + unsigned op2 = unsigned(function.intOp(inst.b)); + substitute(function, inst, build.constInt(op1 & op2)); + } + else + { + if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == 0) // (0 & b) -> 0 + substitute(function, inst, build.constInt(0)); + else if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == -1) // (-1 & b) -> b + substitute(function, inst, inst.b); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) // (a & 0) -> 0 + substitute(function, inst, build.constInt(0)); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == -1) // (a & -1) -> a + substitute(function, inst, inst.a); + } + break; + case IrCmd::BITXOR_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + unsigned op2 = unsigned(function.intOp(inst.b)); + substitute(function, inst, build.constInt(op1 ^ op2)); + } + else + { + if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == 0) // (0 ^ b) -> b + substitute(function, inst, inst.b); + else if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == -1) // (-1 ^ b) -> ~b + replace(function, block, index, {IrCmd::BITNOT_UINT, inst.b}); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) // (a ^ 0) -> a + substitute(function, inst, inst.a); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == -1) // (a ^ -1) -> ~a + 
replace(function, block, index, {IrCmd::BITNOT_UINT, inst.a}); + } + break; + case IrCmd::BITOR_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + unsigned op2 = unsigned(function.intOp(inst.b)); + substitute(function, inst, build.constInt(op1 | op2)); + } + else + { + if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == 0) // (0 | b) -> b + substitute(function, inst, inst.b); + else if (inst.a.kind == IrOpKind::Constant && function.intOp(inst.a) == -1) // (-1 | b) -> -1 + substitute(function, inst, build.constInt(-1)); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) // (a | 0) -> a + substitute(function, inst, inst.a); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == -1) // (a | -1) -> -1 + substitute(function, inst, build.constInt(-1)); + } + break; + case IrCmd::BITNOT_UINT: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(~unsigned(function.intOp(inst.a)))); + break; + case IrCmd::BITLSHIFT_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + int op2 = function.intOp(inst.b); + + if (unsigned(op2) < 32) + substitute(function, inst, build.constInt(op1 << op2)); + } + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + { + substitute(function, inst, inst.a); + } + break; + case IrCmd::BITRSHIFT_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + unsigned op1 = unsigned(function.intOp(inst.a)); + int op2 = function.intOp(inst.b); + + if (unsigned(op2) < 32) + substitute(function, inst, build.constInt(op1 >> op2)); + } + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + { + substitute(function, inst, inst.a); + } + break; + case IrCmd::BITARSHIFT_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + { + int op1 = function.intOp(inst.a); + int op2 = function.intOp(inst.b); + + if (unsigned(op2) < 32) + { + // note: technically right shift of negative values is UB, but this behavior is getting defined in C++20 and all compilers do the + // right (shift) thing. 
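// As an aside to the note above: a fully portable arithmetic right shift (a minimal
// standalone sketch with a hypothetical name, not something this patch adds) avoids
// relying on the compiler's treatment of '>>' for negative operands altogether:

#include <cstdint>

// Well-defined for any int32_t value and 0 <= shift < 32.
inline int32_t arithmeticShiftRight(int32_t value, int shift)
{
    if (value >= 0)
        return value >> shift; // shifting a non-negative value is already well-defined

    // Shift the (non-negative) complement and complement back; this reproduces the
    // sign-extending behavior, e.g. arithmeticShiftRight(-8, 1) == -4.
    return ~(~value >> shift);
}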
+ substitute(function, inst, build.constInt(op1 >> op2)); + } + } + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + { + substitute(function, inst, inst.a); + } + break; + case IrCmd::BITLROTATE_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(lrotate(unsigned(function.intOp(inst.a)), function.intOp(inst.b)))); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + substitute(function, inst, inst.a); + break; + case IrCmd::BITRROTATE_UINT: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(rrotate(unsigned(function.intOp(inst.a)), function.intOp(inst.b)))); + else if (inst.b.kind == IrOpKind::Constant && function.intOp(inst.b) == 0) + substitute(function, inst, inst.a); + break; + case IrCmd::BITCOUNTLZ_UINT: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(countlz(unsigned(function.intOp(inst.a))))); + break; + case IrCmd::BITCOUNTRZ_UINT: + if (inst.a.kind == IrOpKind::Constant) + substitute(function, inst, build.constInt(countrz(unsigned(function.intOp(inst.a))))); + break; default: break; } } -uint32_t getNativeContextOffset(LuauBuiltinFunction bfid) +uint32_t getNativeContextOffset(int bfid) { switch (bfid) { @@ -607,6 +759,8 @@ uint32_t getNativeContextOffset(LuauBuiltinFunction bfid) return offsetof(NativeContext, libm_exp); case LBF_MATH_LOG10: return offsetof(NativeContext, libm_log10); + case LBF_MATH_LOG: + return offsetof(NativeContext, libm_log); case LBF_MATH_SINH: return offsetof(NativeContext, libm_sinh); case LBF_MATH_SIN: @@ -617,6 +771,10 @@ uint32_t getNativeContextOffset(LuauBuiltinFunction bfid) return offsetof(NativeContext, libm_tan); case LBF_MATH_FMOD: return offsetof(NativeContext, libm_fmod); + case LBF_MATH_POW: + return offsetof(NativeContext, libm_pow); + case LBF_IR_MATH_LOG2: + return offsetof(NativeContext, libm_log2); default: LUAU_ASSERT(!"Unsupported bfid"); } diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index b8220f5..be661a7 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -117,7 +117,6 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::JUMP_EQ_TAG: diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 9587a22..459aeaa 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -27,9 +27,12 @@ using FallbackFn = const Instruction* (*)(lua_State* L, const Instruction* pc, S struct NativeProto { - uintptr_t entryTarget = 0; - uintptr_t* instTargets = nullptr; // TODO: NativeProto should be variable-size with all target embedded + // This array is stored before NativeProto in reverse order, so to get offset of instruction i you need to index instOffsets[-i] + // This awkward layout is helpful for maximally efficient address computation on X64/A64 + uint32_t instOffsets[1]; + uintptr_t instBase = 0; + uintptr_t entryTarget = 0; // = instOffsets[0] + instBase Proto* proto = nullptr; }; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 37af5a3..e766366 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -8,6 +8,7 @@ #include "lua.h" +#include #include 
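// The IrUtils.cpp folds above lean on the small helpers declared in BitUtils.h
// (lrotate, rrotate, countlz, countrz). A minimal portable sketch of such helpers --
// assumed shapes for illustration, not necessarily the exact implementations -- looks
// like this:

#include <cstdint>

inline uint32_t lrotate(uint32_t u, int s)
{
    // masking the shift keeps the expression defined for s == 0 (avoids a shift by 32)
    // and makes other shift amounts behave modulo 32, as a rotation should
    return (u << (s & 31)) | (u >> ((32 - s) & 31));
}

inline uint32_t rrotate(uint32_t u, int s)
{
    return (u >> (s & 31)) | (u << ((32 - s) & 31));
}

inline int countlz(uint32_t v)
{
    // simple loop form; production code typically uses a compiler intrinsic
    for (int i = 0; i < 32; ++i)
        if (v & (0x80000000u >> i))
            return i;

    return 32; // countlz(0) == 32, matching the folded test expectations below
}

inline int countrz(uint32_t v)
{
    for (int i = 0; i < 32; ++i)
        if (v & (1u << i))
            return i;

    return 32;
}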
LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) @@ -42,8 +43,9 @@ struct RegisterLink // Data we know about the current VM state struct ConstPropState { - ConstPropState(const IrFunction& function) + ConstPropState(IrFunction& function) : function(function) + , valueMap({}) { } @@ -58,7 +60,13 @@ struct ConstPropState void saveTag(IrOp op, uint8_t tag) { if (RegisterInfo* info = tryGetRegisterInfo(op)) - info->tag = tag; + { + if (info->tag != tag) + { + info->tag = tag; + info->version++; + } + } } IrOp tryGetValue(IrOp op) @@ -74,7 +82,15 @@ struct ConstPropState LUAU_ASSERT(value.kind == IrOpKind::Constant); if (RegisterInfo* info = tryGetRegisterInfo(op)) - info->value = value; + { + if (info->value != value) + { + info->value = value; + info->knownNotReadonly = false; + info->knownNoMetatable = false; + info->version++; + } + } } void invalidate(RegisterInfo& reg, bool invalidateTag, bool invalidateValue) @@ -96,16 +112,22 @@ struct ConstPropState void invalidateTag(IrOp regOp) { + // TODO: use maxstacksize from Proto + maxReg = vmRegOp(regOp) > maxReg ? vmRegOp(regOp) : maxReg; invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ false); } void invalidateValue(IrOp regOp) { + // TODO: use maxstacksize from Proto + maxReg = vmRegOp(regOp) > maxReg ? vmRegOp(regOp) : maxReg; invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ false, /* invalidateValue */ true); } void invalidate(IrOp regOp) { + // TODO: use maxstacksize from Proto + maxReg = vmRegOp(regOp) > maxReg ? vmRegOp(regOp) : maxReg; invalidate(regs[vmRegOp(regOp)], /* invalidateTag */ true, /* invalidateValue */ true); } @@ -113,8 +135,6 @@ struct ConstPropState { for (int i = firstReg; i <= maxReg; ++i) invalidate(regs[i], /* invalidateTag */ true, /* invalidateValue */ true); - - maxReg = int(firstReg) - 1; } void invalidateRegisterRange(int firstReg, int count) @@ -191,9 +211,90 @@ struct ConstPropState return nullptr; } - const IrFunction& function; + // Attach register version number to the register operand in a load instruction + // This is used to allow instructions with register references to be compared for equality + IrInst versionedVmRegLoad(IrCmd loadCmd, IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::VmReg); + uint32_t version = regs[vmRegOp(op)].version; + LUAU_ASSERT(version <= 0xffffff); + op.index = vmRegOp(op) | (version << 8); + return IrInst{loadCmd, op}; + } - RegisterInfo regs[256]; + // Find existing value of the instruction that is exactly the same, or record current on for future lookups + void substituteOrRecord(IrInst& inst, uint32_t instIdx) + { + if (!useValueNumbering) + return; + + if (uint32_t* prevIdx = valueMap.find(inst)) + substitute(function, inst, IrOp{IrOpKind::Inst, *prevIdx}); + else + valueMap[inst] = instIdx; + } + + // Vm register load can be replaced by a previous load of the same version of the register + // If there is no previous load, we record the current one for future lookups + void substituteOrRecordVmRegLoad(IrInst& loadInst) + { + LUAU_ASSERT(loadInst.a.kind == IrOpKind::VmReg); + + if (!useValueNumbering) + return; + + // To avoid captured register invalidation tracking in lowering later, values from loads from captured registers are not propagated + // This prevents the case where load value location is linked to memory in case of a spill and is then cloberred in a user call + if (function.cfg.captured.regs.test(vmRegOp(loadInst.a))) + return; + + IrInst versionedLoad = versionedVmRegLoad(loadInst.cmd, loadInst.a); + + // Check if 
there is a value that already has this version of the register + if (uint32_t* prevIdx = valueMap.find(versionedLoad)) + { + // Previous value might not be linked to a register yet + // For example, it could be a NEW_TABLE stored into a register and we might need to track guards made with this value + if (!instLink.contains(*prevIdx)) + createRegLink(*prevIdx, loadInst.a); + + // Substitute load instructon with the previous value + substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); + } + else + { + uint32_t instIdx = function.getInstIndex(loadInst); + + // Record load of this register version for future substitution + valueMap[versionedLoad] = instIdx; + + createRegLink(instIdx, loadInst.a); + } + } + + // VM register loads can use the value that was stored in the same Vm register earlier + void forwardVmRegStoreToLoad(const IrInst& storeInst, IrCmd loadCmd) + { + LUAU_ASSERT(storeInst.a.kind == IrOpKind::VmReg); + LUAU_ASSERT(storeInst.b.kind == IrOpKind::Inst); + + if (!useValueNumbering) + return; + + // To avoid captured register invalidation tracking in lowering later, values from stores into captured registers are not propagated + // This prevents the case where store creates an alternative value location in case of a spill and is then cloberred in a user call + if (function.cfg.captured.regs.test(vmRegOp(storeInst.a))) + return; + + // Future loads of this register version can use the value we stored + valueMap[versionedVmRegLoad(loadCmd, storeInst.a)] = storeInst.b.index; + } + + IrFunction& function; + + bool useValueNumbering = false; + + std::array regs; // For range/full invalidations, we only want to visit a limited number of data that we have recorded int maxReg = 0; @@ -202,6 +303,8 @@ struct ConstPropState bool checkedGc = false; DenseHashMap instLink{~0u}; + + DenseHashMap valueMap; }; static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -277,6 +380,7 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid } // TODO: classify further using switch above, some fastcalls only modify the value, not the tag + // TODO: fastcalls are different from calls and it might be possible to not invalidate all register starting from return state.invalidateRegistersFrom(firstReturnReg); } @@ -292,45 +396,65 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& break; case IrCmd::LOAD_POINTER: if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::LOAD_DOUBLE: if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant) substitute(function, inst, value); else if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::LOAD_INT: if (IrOp value = state.tryGetValue(inst.a); value.kind == IrOpKind::Constant) substitute(function, inst, value); else if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::LOAD_TVALUE: if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + state.substituteOrRecordVmRegLoad(inst); break; case IrCmd::STORE_TAG: if (inst.a.kind == IrOpKind::VmReg) { + const IrOp source = inst.a; + uint32_t activeLoadDoubleValue = kInvalidInstIdx; + if (inst.b.kind == IrOpKind::Constant) { uint8_t value = function.tagOp(inst.b); - if (state.tryGetTag(inst.a) == value) + // 
STORE_TAG usually follows a store of the value, but it also bumps the version of the whole register + // To be able to propagate STORE_DOUBLE into LOAD_DOUBLE, we find active LOAD_DOUBLE value and recreate it with updated version + // Register in this optimization cannot be captured to avoid complications in lowering (IrValueLocationTracking doesn't model it) + // If stored tag is not a number, we can skip the lookup as there won't be future loads of this register as a number + if (value == LUA_TNUMBER && !function.cfg.captured.regs.test(vmRegOp(source))) + { + if (uint32_t* prevIdx = state.valueMap.find(state.versionedVmRegLoad(IrCmd::LOAD_DOUBLE, source))) + activeLoadDoubleValue = *prevIdx; + } + + if (state.tryGetTag(source) == value) kill(function, inst); else - state.saveTag(inst.a, value); + state.saveTag(source, value); } else { - state.invalidateTag(inst.a); + state.invalidateTag(source); } + + // Future LOAD_DOUBLE instructions can re-use previous register version load + if (activeLoadDoubleValue != kInvalidInstIdx) + state.valueMap[state.versionedVmRegLoad(IrCmd::LOAD_DOUBLE, source)] = activeLoadDoubleValue; } break; case IrCmd::STORE_POINTER: if (inst.a.kind == IrOpKind::VmReg) + { state.invalidateValue(inst.a); + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_POINTER); + } break; case IrCmd::STORE_DOUBLE: if (inst.a.kind == IrOpKind::VmReg) @@ -345,6 +469,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& else { state.invalidateValue(inst.a); + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_DOUBLE); } } break; @@ -361,6 +486,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& else { state.invalidateValue(inst.a); + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_INT); } } break; @@ -377,6 +503,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (IrOp value = state.tryGetValue(inst.b); value.kind != IrOpKind::None) state.saveValue(inst.a, value); + + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); } break; case IrCmd::JUMP_IF_TRUTHY: @@ -540,11 +668,12 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& // These instructions don't have an effect on register/memory state we are tracking case IrCmd::NOP: case IrCmd::LOAD_NODE_VALUE_TV: + case IrCmd::STORE_NODE_VALUE_TV: case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: case IrCmd::GET_HASH_NODE_ADDR: - case IrCmd::STORE_NODE_VALUE_TV: + break; case IrCmd::ADD_INT: case IrCmd::SUB_INT: case IrCmd::ADD_NUM: @@ -552,7 +681,6 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: case IrCmd::UNM_NUM: @@ -562,6 +690,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::NOT_ANY: + state.substituteOrRecord(inst, index); + break; case IrCmd::JUMP: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_SLOT_MATCH: @@ -581,7 +711,6 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::RETURN: case IrCmd::COVERAGE: case IrCmd::SET_UPVALUE: - case IrCmd::SETLIST: // We don't track table state that this can invalidate case IrCmd::SET_SAVEDPC: // TODO: we may be able to remove some updates to PC case IrCmd::CLOSE_UPVALS: // Doesn't change memory that we track case IrCmd::CAPTURE: @@ -642,12 
+771,21 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::INTERRUPT: state.invalidateUserCall(); break; + case IrCmd::SETLIST: + state.valueMap.clear(); // TODO: this can be relaxed when x64 emitInstSetList becomes aware of register allocator + break; case IrCmd::CALL: state.invalidateRegistersFrom(vmRegOp(inst.a)); state.invalidateUserCall(); + + // We cannot guarantee right now that all live values can be remeterialized from non-stack memory locations + // To prevent earlier values from being propagated to after the call, we have to clear the map + // TODO: remove only the values that don't have a guaranteed restore location + state.valueMap.clear(); break; case IrCmd::FORGLOOP: state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified + state.valueMap.clear(); // TODO: this can be relaxed when x64 emitInstForGLoop becomes aware of register allocator break; case IrCmd::FORGLOOP_FALLBACK: state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified @@ -656,6 +794,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FORGPREP_XNEXT_FALLBACK: // This fallback only conditionally throws an exception break; + + // Full fallback instructions case IrCmd::FALLBACK_GETGLOBAL: state.invalidate(inst.b); state.invalidateUserCall(); @@ -678,7 +818,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FALLBACK_PREPVARARGS: break; case IrCmd::FALLBACK_GETVARARGS: - state.invalidateRegistersFrom(vmRegOp(inst.b)); + state.invalidateRegisterRange(vmRegOp(inst.b), function.intOp(inst.c)); break; case IrCmd::FALLBACK_NEWCLOSURE: state.invalidate(inst.b); @@ -709,13 +849,17 @@ static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& s constPropInInst(state, build, function, block, inst, index); } + + // Value numbering and load/store propagation is not performed between blocks + state.valueMap.clear(); } -static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block) +static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block, bool useValueNumbering) { IrFunction& function = build.function; ConstPropState state{function}; + state.useValueNumbering = useValueNumbering; while (block) { @@ -792,7 +936,7 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st return path; } -static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock) +static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited, IrBlock& startingBlock, bool useValueNumbering) { IrFunction& function = build.function; @@ -822,6 +966,7 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited // Initialize state with the knowledge of our current block ConstPropState state{function}; + state.useValueNumbering = useValueNumbering; constPropInBlock(build, startingBlock, state); // Veryfy that target hasn't changed @@ -845,7 +990,7 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited constPropInBlock(build, linearBlock, state); } -void constPropInBlockChains(IrBuilder& build) +void constPropInBlockChains(IrBuilder& build, bool useValueNumbering) { IrFunction& function = build.function; @@ -859,11 +1004,11 @@ void constPropInBlockChains(IrBuilder& build) if (visited[function.getBlockIndex(block)]) continue; - constPropInBlockChain(build, visited, &block); + constPropInBlockChain(build, visited, 
&block, useValueNumbering); } } -void createLinearBlocks(IrBuilder& build) +void createLinearBlocks(IrBuilder& build, bool useValueNumbering) { // Go through internal block chains and outline them into a single new block. // Outlining will be able to linearize the execution, even if there was a jump to a block with multiple users, @@ -884,7 +1029,7 @@ void createLinearBlocks(IrBuilder& build) if (visited[function.getBlockIndex(block)]) continue; - tryCreateLinearBlock(build, visited, block); + tryCreateLinearBlock(build, visited, block, useValueNumbering); } } diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp index dd31fcc..5ee626a 100644 --- a/CodeGen/src/OptimizeFinalX64.cpp +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -40,7 +40,6 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: case IrCmd::MOD_NUM: - case IrCmd::POW_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: { diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 047c1b6..ba4232a 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -79,6 +79,8 @@ public: void setDebugLine(int line); void pushDebugLocal(StringRef name, uint8_t reg, uint32_t startpc, uint32_t endpc); void pushDebugUpval(StringRef name); + + size_t getInstructionCount() const; uint32_t getDebugPC() const; void addDebugRemark(const char* format, ...) LUAU_PRINTF_ATTR(2, 3); diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 8e450f4..b5690ac 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -556,6 +556,11 @@ void BytecodeBuilder::pushDebugUpval(StringRef name) debugUpvals.push_back(upval); } +size_t BytecodeBuilder::getInstructionCount() const +{ + return insns.size(); +} + uint32_t BytecodeBuilder::getDebugPC() const { return uint32_t(insns.size()); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9478404..9eda214 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,6 +25,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) +LUAU_FASTFLAGVARIABLE(LuauCompileLimitInsns, false) + namespace Luau { @@ -33,6 +35,7 @@ using namespace Luau::Compile; static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; +static const uint32_t kMaxInstructionCount = 1'000'000'000; static const uint8_t kInvalidReg = 255; @@ -247,6 +250,9 @@ struct Compiler popLocals(0); + if (FFlag::LuauCompileLimitInsns && bytecode.getInstructionCount() > kMaxInstructionCount) + CompileError::raise(func->location, "Exceeded function instruction limit; split the function into parts to compile"); + bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size())); Function& f = functions[func]; diff --git a/Makefile b/Makefile index bbc66c2..aead3d3 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ ifneq ($(opt),) TESTS_ARGS+=-O$(opt) endif -OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(FUZZ_OBJECTS) 
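The useValueNumbering flag threaded through constPropInBlockChains/createLinearBlocks above enables a simple local value numbering: pure instructions, and loads of a specific version of a VM register, are keyed in a map so that a repeated computation can be replaced with a reference to the first one, and a store bumps the register version to keep stale loads from being reused. Below is a self-contained sketch of that idea only; the names are hypothetical and std::map stands in for the DenseHashMap the real pass uses.

#include <cstdint>
#include <cstdio>
#include <map>
#include <tuple>
#include <vector>

enum class Op { LoadReg, AddNum };

struct Inst
{
    Op op;
    uint32_t a = 0; // register index for LoadReg, operand instruction index otherwise
    uint32_t b = 0;
};

struct ValueNumbering
{
    std::map<std::tuple<Op, uint32_t, uint32_t>, uint32_t> valueMap;
    std::vector<uint32_t> regVersion = std::vector<uint32_t>(256, 0);

    // Returns the index of an equivalent earlier instruction, or appends and records this one
    uint32_t numberInst(std::vector<Inst>& code, Inst inst)
    {
        // Loads are keyed by (register, current version); a store in between bumps the
        // version, which prevents reuse of a value that no longer matches the register
        uint32_t keyA = inst.op == Op::LoadReg ? (inst.a | (regVersion[inst.a] << 8)) : inst.a;

        auto key = std::make_tuple(inst.op, keyA, inst.b);

        if (auto it = valueMap.find(key); it != valueMap.end())
            return it->second; // substitute: reuse the earlier result

        code.push_back(inst);
        uint32_t idx = uint32_t(code.size() - 1);
        valueMap[key] = idx;
        return idx;
    }

    void storeToReg(uint32_t reg)
    {
        regVersion[reg]++; // future loads of this register are a different value
    }
};

int main()
{
    std::vector<Inst> code;
    ValueNumbering vn;

    uint32_t a = vn.numberInst(code, {Op::LoadReg, 1});
    uint32_t b = vn.numberInst(code, {Op::LoadReg, 1}); // same register version -> reuses 'a'
    vn.storeToReg(1);
    uint32_t c = vn.numberInst(code, {Op::LoadReg, 1}); // version bumped -> a fresh load

    printf("%u %u %u (emitted %zu instructions)\n", unsigned(a), unsigned(b), unsigned(c), code.size());
    return 0;
}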
EXECUTABLE_ALIASES = luau luau-analyze luau-tests # common flags diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 264388b..0f4df67 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBetterOOMHandling, false) - /* ** {====================================================== ** Error-recovery functions @@ -82,7 +80,7 @@ public: const char* what() const throw() override { // LUA_ERRRUN passes error object on the stack - if (status == LUA_ERRRUN || (status == LUA_ERRSYNTAX && !FFlag::LuauBetterOOMHandling)) + if (status == LUA_ERRRUN) if (const char* str = lua_tostring(L, -1)) return str; @@ -552,30 +550,21 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e // call user-defined error function (used in xpcall) if (ef) { - if (FFlag::LuauBetterOOMHandling) - { - // push error object to stack top if it's not already there - if (status != LUA_ERRRUN) - seterrorobj(L, status, L->top); + // push error object to stack top if it's not already there + if (status != LUA_ERRRUN) + seterrorobj(L, status, L->top); - // if errfunc fails, we fail with "error in error handling" or "not enough memory" - int err = luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)); + // if errfunc fails, we fail with "error in error handling" or "not enough memory" + int err = luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)); - // in general we preserve the status, except for cases when the error handler fails - // out of memory is treated specially because it's common for it to be cascading, in which case we preserve the code - if (err == 0) - errstatus = LUA_ERRRUN; - else if (status == LUA_ERRMEM && err == LUA_ERRMEM) - errstatus = LUA_ERRMEM; - else - errstatus = status = LUA_ERRERR; - } + // in general we preserve the status, except for cases when the error handler fails + // out of memory is treated specially because it's common for it to be cascading, in which case we preserve the code + if (err == 0) + errstatus = LUA_ERRRUN; + else if (status == LUA_ERRMEM && err == LUA_ERRMEM) + errstatus = LUA_ERRMEM; else - { - // if errfunc fails, we fail with "error in error handling" - if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) - status = LUA_ERRERR; - } + errstatus = status = LUA_ERRERR; } // since the call failed with an error, we might have to reset the 'active' thread state @@ -597,7 +586,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); // close eventual pending closures - seterrorobj(L, FFlag::LuauBetterOOMHandling ? 
errstatus : status, oldtop); + seterrorobj(L, errstatus, oldtop); L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 4443be3..9bc624e 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,8 +10,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauIntrosort, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -389,7 +387,7 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredic while (l < u) { // if the limit has been reached, quick sort is going over the permitted nlogn complexity, so we fall back to heap sort - if (FFlag::LuauIntrosort && limit == 0) + if (limit == 0) return sort_heap(L, t, l, u, pred); // sort elements a[l], a[(l+u)/2] and a[u] @@ -435,43 +433,20 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredic // swap pivot a[p] with a[i], which is the new midpoint sort_swap(L, t, p, i); - if (FFlag::LuauIntrosort) - { - // adjust limit to allow 1.5 log2N recursive steps - limit = (limit >> 1) + (limit >> 2); + // adjust limit to allow 1.5 log2N recursive steps + limit = (limit >> 1) + (limit >> 2); - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // sort smaller half recursively; the larger half is sorted in the next loop iteration - if (i - l < u - i) - { - sort_rec(L, t, l, i - 1, limit, pred); - l = i + 1; - } - else - { - sort_rec(L, t, i + 1, u, limit, pred); - u = i - 1; - } + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // sort smaller half recursively; the larger half is sorted in the next loop iteration + if (i - l < u - i) + { + sort_rec(L, t, l, i - 1, limit, pred); + l = i + 1; } else { - // a[l..i-1] <= a[i] == P <= a[i+1..u] - // adjust so that smaller half is in [j..i] and larger one in [l..u] - if (i - l < u - i) - { - j = l; - i = i - 1; - l = i + 2; - } - else - { - j = i + 1; - i = u; - u = j - 2; - } - - // sort smaller half recursively; the larger half is sorted in the next loop iteration - sort_rec(L, t, j, i, limit, pred); + sort_rec(L, t, i + 1, u, limit, pred); + u = i - 1; } } } diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 5dafb6b..082fe7a 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -354,6 +354,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath") SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); + SINGLE_COMPARE(fcvt(s1, d2), 0x1E624041); + SINGLE_COMPARE(fcvt(d1, s2), 0x1E22C041); + SINGLE_COMPARE(fcvtzs(w1, d2), 0x1E780041); SINGLE_COMPARE(fcvtzs(x1, d2), 0x9E780041); SINGLE_COMPARE(fcvtzu(w1, d2), 0x1E790041); @@ -384,16 +387,20 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPLoadStore") SINGLE_COMPARE(str(d0, mem(x1, -7)), 0xFC1F9020); // load/store sizes + SINGLE_COMPARE(ldr(s0, x1), 0xBD400020); SINGLE_COMPARE(ldr(d0, x1), 0xFD400020); SINGLE_COMPARE(ldr(q0, x1), 0x3DC00020); + SINGLE_COMPARE(str(s0, x1), 0xBD000020); SINGLE_COMPARE(str(d0, x1), 0xFD000020); SINGLE_COMPARE(str(q0, x1), 0x3D800020); // load/store sizes x offset scaling SINGLE_COMPARE(ldr(q0, mem(x1, 16)), 0x3DC00420); SINGLE_COMPARE(ldr(d0, mem(x1, 16)), 0xFD400820); + SINGLE_COMPARE(ldr(s0, mem(x1, 16)), 0xBD401020); SINGLE_COMPARE(str(q0, mem(x1, 16)), 0x3D800420); SINGLE_COMPARE(str(d0, mem(x1, 16)), 0xFD000820); + SINGLE_COMPARE(str(s0, mem(x1, 16)), 0xBD001020); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPCompare") @@ -471,6 +478,8 @@ 
TEST_CASE("LogTest") build.fmov(d0, 0.25); build.tbz(x0, 5, l); + build.fcvt(s1, d2); + build.setLabel(l); build.ret(); @@ -502,6 +511,7 @@ TEST_CASE("LogTest") fcmp d0,#0 fmov d0,#0.25 tbz x0,#5,.L1 + fcvt s1,d2 .L1: ret )"; diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 2464d32..fc802e1 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -408,8 +408,6 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { - ScopedFastFlag sff("LuauBetterOOMHandling", true); - runConformance( "pcall.lua", [](lua_State* L) { @@ -504,7 +502,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) for (const auto& [name, prop] : t->props) { - populateRTTI(L, prop.type); + populateRTTI(L, prop.type()); lua_setfield(L, -2, name.c_str()); } } @@ -1012,8 +1010,6 @@ TEST_CASE("ApiCalls") lua_pop(L, 1); } - ScopedFastFlag sff("LuauBetterOOMHandling", true); - // lua_pcall on OOM { lua_getfield(L, LUA_GLOBALSINDEX, "largealloc"); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 0823eab..7b93398 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -23,8 +23,8 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) { AstStatBlock* root = parse(code); dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); - cgb = std::make_unique( - mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), frontend.globals.globalScope, &logger, NotNull{dfg.get()}); + cgb = std::make_unique(mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), + frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, &logger, NotNull{dfg.get()}); cgb->visit(root); rootScope = cgb->rootScope; constraints = Luau::borrowConstraints(cgb->constraints); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index aba2891..c6fc475 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -21,7 +21,6 @@ static const char* mainModuleName = "MainModule"; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauOnDemandTypecheckers); extern std::optional randomSeed; // tests/main.cpp @@ -177,13 +176,7 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars if (FFlag::DebugLuauDeferredConstraintResolution) { ModulePtr module = Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, - frontend.globals.globalScope, frontend.options); - - Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); - } - else if (!FFlag::LuauOnDemandTypecheckers) - { - ModulePtr module = frontend.typeChecker_DEPRECATED.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + frontend.globals.globalScope, /*prepareModuleScope*/ nullptr, frontend.options); Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); } diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index f4b9f62..0b9c872 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1129,4 +1129,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_type_alias") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "module_scope_check") +{ + frontend.prepareModuleScope = [this](const ModuleName& name, const ScopePtr& scope, bool forAutocomplete) { + scope->bindings[Luau::AstName{"x"}] = 
Luau::Binding{frontend.globals.builtinTypes->numberType}; + }; + + fileResolver.source["game/A"] = R"( + local a = x + )"; + + CheckResult result = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + auto ty = requireType("game/A", "a"); + CHECK_EQ(toString(ty), "number"); +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 1419c95..f09f174 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -294,28 +294,31 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") build.beginBlock(block); build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::ADD_INT, build.constInt(10), build.constInt(20))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::ADD_INT, build.constInt(INT_MAX), build.constInt(1))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::ADD_INT, build.constInt(INT_MAX), build.constInt(1))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::SUB_INT, build.constInt(10), build.constInt(20))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::SUB_INT, build.constInt(INT_MIN), build.constInt(1))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::SUB_INT, build.constInt(10), build.constInt(20))); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::SUB_INT, build.constInt(INT_MIN), build.constInt(1))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::ADD_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::SUB_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MUL_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::POW_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.inst(IrCmd::ADD_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::SUB_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::MUL_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::DIV_NUM, build.constDouble(2), build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(8), build.inst(IrCmd::MOD_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(10), build.inst(IrCmd::MIN_NUM, build.constDouble(5), build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(11), build.inst(IrCmd::MAX_NUM, build.constDouble(5), build.constDouble(2))); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(12), build.inst(IrCmd::UNM_NUM, build.constDouble(5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(13), build.inst(IrCmd::FLOOR_NUM, build.constDouble(2.5))); + 
build.inst(IrCmd::STORE_DOUBLE, build.vmReg(14), build.inst(IrCmd::CEIL_NUM, build.constDouble(2.5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(15), build.inst(IrCmd::ROUND_NUM, build.constDouble(2.5))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(16), build.inst(IrCmd::SQRT_NUM, build.constDouble(16))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(17), build.inst(IrCmd::ABS_NUM, build.constDouble(-4))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tnil), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tnumber), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); - - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::INT_TO_NUM, build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(18), build.inst(IrCmd::NOT_ANY, build.constTag(tnil), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); + build.inst( + IrCmd::STORE_INT, build.vmReg(19), build.inst(IrCmd::NOT_ANY, build.constTag(tnumber), build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)))); + build.inst(IrCmd::STORE_INT, build.vmReg(20), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(21), build.inst(IrCmd::NOT_ANY, build.constTag(tboolean), build.constInt(1))); build.inst(IrCmd::RETURN, build.constUint(0)); @@ -325,23 +328,233 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 30i - STORE_INT R0, -2147483648i - STORE_INT R0, -10i - STORE_INT R0, 2147483647i - STORE_DOUBLE R0, 7 - STORE_DOUBLE R0, -3 - STORE_DOUBLE R0, 10 - STORE_DOUBLE R0, 0.40000000000000002 - STORE_DOUBLE R0, 1 - STORE_DOUBLE R0, 25 - STORE_DOUBLE R0, 2 - STORE_DOUBLE R0, 5 - STORE_DOUBLE R0, -5 - STORE_INT R0, 1i - STORE_INT R0, 0i - STORE_INT R0, 1i - STORE_INT R0, 0i + STORE_INT R1, -2147483648i + STORE_INT R2, -10i + STORE_INT R3, 2147483647i + STORE_DOUBLE R4, 7 + STORE_DOUBLE R5, -3 + STORE_DOUBLE R6, 10 + STORE_DOUBLE R7, 0.40000000000000002 + STORE_DOUBLE R8, 1 + STORE_DOUBLE R10, 2 + STORE_DOUBLE R11, 5 + STORE_DOUBLE R12, -5 + STORE_DOUBLE R13, 2 + STORE_DOUBLE R14, 3 + STORE_DOUBLE R15, 3 + STORE_DOUBLE R16, 4 + STORE_DOUBLE R17, 4 + STORE_INT R18, 1i + STORE_INT R19, 0i + STORE_INT R20, 1i + STORE_INT R21, 0i + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericConversions") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::INT_TO_NUM, build.constInt(8))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::UINT_TO_NUM, build.constInt(0xdeee0000u))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::NUM_TO_INT, build.constDouble(200.0))); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::NUM_TO_UINT, build.constDouble(3740139520.0))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: STORE_DOUBLE R0, 8 + STORE_DOUBLE R1, 3740139520 + STORE_INT R2, 200i + STORE_INT R3, -554827776i + RETURN 0u + +)"); +} + 
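// The NumericConversions test above and NumericConversionsBlocked below exercise the
// new conversion folds: NUM_TO_INT / NUM_TO_UINT are folded only when the constant is
// actually representable in the target type, because the cast is undefined behavior
// for NaN and out-of-range values. A standalone sketch of that guard (hypothetical
// helpers, std::optional in place of the IR substitution machinery):

#include <climits>
#include <cstdint>
#include <optional>

inline std::optional<int32_t> tryFoldNumToInt(double value)
{
    // NaN fails both comparisons, so it is rejected together with out-of-range values
    if (value >= INT_MIN && value <= INT_MAX)
        return int32_t(value);

    return std::nullopt; // leave the conversion to the runtime instruction
}

inline std::optional<uint32_t> tryFoldNumToUint(double value)
{
    if (value >= 0 && value <= UINT_MAX)
        return uint32_t(value);

    return std::nullopt;
}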
+TEST_CASE_FIXTURE(IrBuilderFixture, "NumericConversionsBlocked") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INT, build.constDouble(1e20))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::NUM_TO_UINT, build.constDouble(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::NUM_TO_INT, nan)); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::NUM_TO_UINT, nan)); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %1 = NUM_TO_INT 1e+20 + STORE_INT R0, %1 + %3 = NUM_TO_UINT -10 + STORE_INT R1, %3 + %5 = NUM_TO_INT nan + STORE_INT R2, %5 + %7 = NUM_TO_UINT nan + STORE_INT R3, %7 + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp unk = build.inst(IrCmd::LOAD_INT, build.vmReg(0)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::BITAND_UINT, build.constInt(0xfe), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.inst(IrCmd::BITAND_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.inst(IrCmd::BITAND_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(3), build.inst(IrCmd::BITAND_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(4), build.inst(IrCmd::BITAND_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(5), build.inst(IrCmd::BITXOR_UINT, build.constInt(0xfe), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(6), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(7), build.inst(IrCmd::BITXOR_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(8), build.inst(IrCmd::BITXOR_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(9), build.inst(IrCmd::BITXOR_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITOR_UINT, build.constInt(0xf0), build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(11), build.inst(IrCmd::BITOR_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(12), build.inst(IrCmd::BITOR_UINT, build.constInt(0), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(13), build.inst(IrCmd::BITOR_UINT, unk, build.constInt(~0u))); + build.inst(IrCmd::STORE_INT, build.vmReg(14), build.inst(IrCmd::BITOR_UINT, build.constInt(~0u), unk)); + build.inst(IrCmd::STORE_INT, build.vmReg(15), build.inst(IrCmd::BITNOT_UINT, build.constInt(0xe))); + build.inst(IrCmd::STORE_INT, build.vmReg(16), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf0), build.constInt(4))); + build.inst(IrCmd::STORE_INT, build.vmReg(17), build.inst(IrCmd::BITLSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(18), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(19), build.inst(IrCmd::BITRSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(20), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + 
build.inst(IrCmd::STORE_INT, build.vmReg(21), build.inst(IrCmd::BITARSHIFT_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(22), build.inst(IrCmd::BITLROTATE_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(23), build.inst(IrCmd::BITLROTATE_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(24), build.inst(IrCmd::BITRROTATE_UINT, build.constInt(0xdeee0000u), build.constInt(8))); + build.inst(IrCmd::STORE_INT, build.vmReg(25), build.inst(IrCmd::BITRROTATE_UINT, unk, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(26), build.inst(IrCmd::BITCOUNTLZ_UINT, build.constInt(0xff00))); + build.inst(IrCmd::STORE_INT, build.vmReg(27), build.inst(IrCmd::BITCOUNTLZ_UINT, build.constInt(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(28), build.inst(IrCmd::BITCOUNTRZ_UINT, build.constInt(0xff00))); + build.inst(IrCmd::STORE_INT, build.vmReg(29), build.inst(IrCmd::BITCOUNTRZ_UINT, build.constInt(0))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_INT R0 + STORE_INT R0, 14i + STORE_INT R1, 0i + STORE_INT R2, 0i + STORE_INT R3, %0 + STORE_INT R4, %0 + STORE_INT R5, 240i + STORE_INT R6, %0 + STORE_INT R7, %0 + %17 = BITNOT_UINT %0 + STORE_INT R8, %17 + %19 = BITNOT_UINT %0 + STORE_INT R9, %19 + STORE_INT R10, 254i + STORE_INT R11, %0 + STORE_INT R12, %0 + STORE_INT R13, -1i + STORE_INT R14, -1i + STORE_INT R15, -15i + STORE_INT R16, 3840i + STORE_INT R17, %0 + STORE_INT R18, 14609920i + STORE_INT R19, %0 + STORE_INT R20, -2167296i + STORE_INT R21, %0 + STORE_INT R22, -301989666i + STORE_INT R23, %0 + STORE_INT R24, 14609920i + STORE_INT R25, %0 + STORE_INT R26, 16i + STORE_INT R27, 32i + STORE_INT R28, 8i + STORE_INT R29, 32i + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "Bit32Blocked") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xf), build.constInt(140))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITRSHIFT_UINT, build.constInt(0xf), build.constInt(140))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xf), build.constInt(-10))); + build.inst(IrCmd::STORE_INT, build.vmReg(10), build.inst(IrCmd::BITARSHIFT_UINT, build.constInt(0xf), build.constInt(140))); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = BITLSHIFT_UINT 15i, -10i + STORE_INT R10, %0 + %2 = BITLSHIFT_UINT 15i, 140i + STORE_INT R10, %2 + %4 = BITRSHIFT_UINT 15i, -10i + STORE_INT R10, %4 + %6 = BITRSHIFT_UINT 15i, 140i + STORE_INT R10, %6 + %8 = BITARSHIFT_UINT 15i, -10i + STORE_INT R10, %8 + %10 = BITARSHIFT_UINT 15i, 140i + STORE_INT R10, %10 + RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericNan") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + IrOp nan = build.inst(IrCmd::DIV_NUM, 
build.constDouble(0.0), build.constDouble(0.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, nan, build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MIN_NUM, build.constDouble(1), nan)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, nan, build.constDouble(2))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.inst(IrCmd::MAX_NUM, build.constDouble(1), nan)); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constantFold(); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_DOUBLE R0, 2 + STORE_DOUBLE R0, nan + STORE_DOUBLE R0, 2 + STORE_DOUBLE R0, nan RETURN 0u )"); @@ -571,7 +784,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -589,10 +802,8 @@ bb_0: STORE_DOUBLE R2, %16 %18 = LOAD_TAG R0 STORE_TAG R9, %18 - %20 = LOAD_INT R1 - STORE_INT R10, %20 - %22 = LOAD_DOUBLE R2 - STORE_DOUBLE R11, %22 + STORE_INT R10, %14 + STORE_DOUBLE R11, %16 RETURN 0u )"); @@ -617,7 +828,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -647,7 +858,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -674,7 +885,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -713,7 +924,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -745,7 +956,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -776,7 +987,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -825,7 +1036,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -858,7 +1069,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, 
"RedundantStoreCheckConstantType") build.inst(IrCmd::RETURN, build.constUint(0)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -888,7 +1099,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -920,7 +1131,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting") build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -956,7 +1167,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -995,7 +1206,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") build.inst(IrCmd::RETURN, build.constUint(3)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1030,7 +1241,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1062,7 +1273,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1093,7 +1304,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") build.inst(IrCmd::RETURN, build.constUint(2)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1121,7 +1332,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor build.inst(IrCmd::RETURN, build.constUint(1)); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1154,7 +1365,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique build.inst(IrCmd::JUMP, block2); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1190,7 +1401,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") build.inst(IrCmd::JUMP, entry); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1225,7 +1436,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") build.inst(IrCmd::JUMP, block); updateUseCounts(build.function); - 
constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1267,7 +1478,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") build.inst(IrCmd::JUMP, block); updateUseCounts(build.function); - constPropInBlockChains(build); + constPropInBlockChains(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1325,8 +1536,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); - constPropInBlockChains(build); - createLinearBlocks(build); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1401,8 +1612,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); updateUseCounts(build.function); - constPropInBlockChains(build); - createLinearBlocks(build); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1453,8 +1664,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") build.inst(IrCmd::JUMP, block2); updateUseCounts(build.function); - constPropInBlockChains(build); - createLinearBlocks(build); + constPropInBlockChains(build, true); + createLinearBlocks(build, true); CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: @@ -1468,6 +1679,38 @@ bb_1: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "PartialStoreInvalidation") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); // Should be reloaded + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tboolean)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); // Should be reloaded + + build.inst(IrCmd::RETURN, build.constUint(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_TVALUE R0 + STORE_TVALUE R1, %0 + STORE_DOUBLE R0, 0.5 + %3 = LOAD_TVALUE R0 + STORE_TVALUE R1, %3 + STORE_TAG R0, tboolean + %6 = LOAD_TVALUE R0 + STORE_TVALUE R1, %6 + RETURN 0u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); @@ -1777,7 +2020,6 @@ bb_1: )"); } - TEST_CASE_FIXTURE(IrBuilderFixture, "SetTable") { IrOp entry = build.block(IrBlockKind::Internal); @@ -1799,3 +2041,192 @@ bb_0: } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ValueNumbering"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "RemoveDuplicateCalculation") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::UNM_NUM, op1); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), op2); + IrOp op3 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); // Load propagation is tested here + IrOp op4 = build.inst(IrCmd::UNM_NUM, op3); // And allows value numbering to trigger here + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), op4); + 
build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %1 = UNM_NUM %0 + STORE_DOUBLE R1, %1 + STORE_DOUBLE R2, %1 + RETURN R1, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "LateTableStateLink") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp tmp = build.inst(IrCmd::DUP_TABLE, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), tmp); // Late tmp -> R0 link is tested here + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); // Store to load propagation test + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = DUP_TABLE R0 + STORE_POINTER R0, %0 + CHECK_NO_METATABLE %0, bb_fallback_1 + CHECK_READONLY %0, bb_fallback_1 + RETURN 0u + +bb_fallback_1: + RETURN 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "RegisterVersioning") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::UNM_NUM, op1); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), op2); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); // Doesn't prevent previous store propagation + IrOp op3 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); // No longer 'op1' + IrOp op4 = build.inst(IrCmd::UNM_NUM, op3); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), op4); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %1 = UNM_NUM %0 + STORE_DOUBLE R0, %1 + STORE_TAG R0, tnumber + %5 = UNM_NUM %1 + STORE_DOUBLE R1, %5 + RETURN R0, 2i + +)"); +} + +// This can be relaxed in the future when SETLIST becomes aware of register allocator +TEST_CASE_FIXTURE(IrBuilderFixture, "SetListIsABlocker") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + build.inst(IrCmd::SETLIST); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), sum); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + SETLIST + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %0, %2 + STORE_DOUBLE R0, %3 + RETURN R0, 1i + +)"); +} + +// Luau call will reuse the same stack and spills will be lost +// However, in the future we might propagate values that can be rematerialized +TEST_CASE_FIXTURE(IrBuilderFixture, "CallIsABlocker") +{ + IrOp entry = build.block(IrBlockKind::Internal); + 
+ build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + build.inst(IrCmd::CALL, build.vmReg(1), build.constInt(1), build.vmReg(2), build.constInt(1)); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), sum); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + CALL R1, 1i, R2, 1i + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %0, %2 + STORE_DOUBLE R1, %3 + RETURN R1, 2i + +)"); +} + +// While constant propagation correctly versions captured registers, IrValueLocationTracking doesn't (yet) +TEST_CASE_FIXTURE(IrBuilderFixture, "NoPropagationOfCapturedRegs") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::CAPTURE, build.vmReg(0), build.constBool(true)); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp sum = build.inst(IrCmd::ADD_NUM, op1, op2); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), sum); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +; captured regs: R0 + +bb_0: +; in regs: R0 + CAPTURE R0, true + %1 = LOAD_DOUBLE R0 + %2 = LOAD_DOUBLE R0 + %3 = ADD_NUM %1, %2 + STORE_DOUBLE R1, %3 + RETURN R1, 1i + +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 00cf5ca..22530a2 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -137,7 +137,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") CHECK_EQ(std::optional{"Cyclic"}, ttv->syntheticName); - TypeId methodType = ttv->props["get"].type; + TypeId methodType = ttv->props["get"].type(); REQUIRE(methodType != nullptr); const FunctionType* ftv = get(methodType); @@ -161,7 +161,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table_2") TypeId methodTy = src.addType(FunctionType{src.addTypePack({}), src.addTypePack({tableTy})}); - tt->props["get"].type = methodTy; + tt->props["get"].setType(methodTy); TypeArena dest; @@ -170,7 +170,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table_2") TableType* ctt = getMutable(cloneTy); REQUIRE(ctt); - TypeId clonedMethodType = ctt->props["get"].type; + TypeId clonedMethodType = ctt->props["get"].type(); REQUIRE(clonedMethodType); const FunctionType* cmf = get(clonedMethodType); @@ -199,7 +199,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") TableType* exportsTable = getMutable(*exports); REQUIRE(exportsTable != nullptr); - TypeId signType = exportsTable->props["sign"].type; + TypeId signType = exportsTable->props["sign"].type(); REQUIRE(signType != nullptr); CHECK(!isInArena(signType, module->interfaceTypes)); @@ -340,8 +340,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") { TableType* ttv = getMutable(nested); - ttv->props["a"].type = src.addType(TableType{}); - nested = ttv->props["a"].type; + ttv->props["a"].setType(src.addType(TableType{})); + nested = ttv->props["a"].type(); } TypeArena dest; @@ -411,7 +411,7 @@ return {} TypeId typeB = modBiter->second.type; TableType* tableB = getMutable(typeB); REQUIRE(tableB); - CHECK(typeA == 
tableB->props["q"].type); + CHECK(typeA == tableB->props["q"].type()); } TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") @@ -447,7 +447,7 @@ return exports REQUIRE(typeB); TableType* tableA = getMutable(*typeA); TableType* tableB = getMutable(*typeB); - CHECK(tableA->props["a"].type == tableB->props["b"].type); + CHECK(tableA->props["a"].type() == tableB->props["b"].type()); } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index fddab80..7130a71 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -170,7 +170,7 @@ TEST_CASE_FIXTURE(Fixture, "table_props_are_any") REQUIRE(ttv != nullptr); REQUIRE(ttv->props.count("foo")); - TypeId fooProp = ttv->props["foo"].type; + TypeId fooProp = ttv->props["foo"].type(); REQUIRE(fooProp != nullptr); CHECK_EQ(*fooProp, *builtinTypes->anyType); @@ -192,9 +192,9 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") TableType* ttv = getMutable(requireType("T")); REQUIRE_MESSAGE(ttv, "Should be a table: " << toString(requireType("T"))); - CHECK_EQ(*builtinTypes->anyType, *ttv->props["one"].type); - CHECK_EQ(*builtinTypes->anyType, *ttv->props["two"].type); - CHECK_MESSAGE(get(follow(ttv->props["three"].type)), "Should be a function: " << *ttv->props["three"].type); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["one"].type()); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["two"].type()); + CHECK_MESSAGE(get(follow(ttv->props["three"].type())), "Should be a function: " << *ttv->props["three"].type()); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_iterator_variables_are_any") diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index fd24539..160757e 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -514,14 +514,14 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") REQUIRE(tMeta2); REQUIRE(tMeta2->props.count("__index")); - const MetatableType* tMeta3 = get(tMeta2->props["__index"].type); + const MetatableType* tMeta3 = get(tMeta2->props["__index"].type()); REQUIRE(tMeta3); TableType* tMeta4 = getMutable(tMeta3->metatable); REQUIRE(tMeta4); REQUIRE(tMeta4->props.count("__index")); - TableType* tMeta5 = getMutable(tMeta4->props["__index"].type); + TableType* tMeta5 = getMutable(tMeta4->props["__index"].type()); REQUIRE(tMeta5); REQUIRE(tMeta5->props.count("one") > 0); @@ -529,9 +529,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") REQUIRE(tMeta6); REQUIRE(tMeta6->props.count("two") > 0); - ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); + ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type(), opts); - std::string twoResult = toString(tMeta6->props["two"].type, opts); + std::string twoResult = toString(tMeta6->props["two"].type(), opts); CHECK_EQ("(a) -> number", oneResult.name); CHECK_EQ("(b) -> number", twoResult); @@ -786,7 +786,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") TypeId parentTy = requireType("foo"); auto ttv = get(follow(parentTy)); - auto ftv = get(follow(ttv->props.at("method").type)); + auto ftv = get(follow(ttv->props.at("method").type())); CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); } @@ -809,7 +809,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") TypeId parentTy = requireType("foo"); auto ttv = get(follow(parentTy)); REQUIRE_MESSAGE(ttv, "Expected a table but got " << toString(parentTy, opts)); - TypeId methodTy = 
follow(ttv->props.at("method").type); + TypeId methodTy = follow(ttv->props.at("method").type()); auto ftv = get(methodTy); REQUIRE_MESSAGE(ftv, "Expected a function but got " << toString(methodTy, opts)); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 3de5299..84b057d 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -330,7 +330,7 @@ TEST_CASE_FIXTURE(Fixture, "self_referential_type_alias") std::optional incr = get(oTable->props, "incr"); REQUIRE(incr); - const FunctionType* incrFunc = get(incr->type); + const FunctionType* incrFunc = get(incr->type()); REQUIRE(incrFunc); std::optional firstArg = first(incrFunc->argTypes); @@ -493,7 +493,7 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") TableType* exportsTable = getMutable(*exportsType); REQUIRE(exportsTable != nullptr); - TypeId n = exportsTable->props["n"].type; + TypeId n = exportsTable->props["n"].type(); REQUIRE(n != nullptr); CHECK(isInArena(n, mod.interfaceTypes)); @@ -548,10 +548,10 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti TableType* exportsTable = getMutable(*exportsType); REQUIRE(exportsTable != nullptr); - TypeId aType = exportsTable->props["a"].type; + TypeId aType = exportsTable->props["a"].type(); REQUIRE(aType); - TypeId bType = exportsTable->props["b"].type; + TypeId bType = exportsTable->props["b"].type(); REQUIRE(bType); CHECK(isInArena(recordType, mod.interfaceTypes)); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index c6766ca..687bc76 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -195,7 +195,7 @@ TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") REQUIRE(ttv); REQUIRE(ttv->props.count("prop")); - REQUIRE_EQ("any", toString(ttv->props["prop"].type)); + REQUIRE_EQ("any", toString(ttv->props["prop"].type())); } TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 79d9108..07cf539 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1031,7 +1031,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") REQUIRE(mathTy); TableType* ttv = getMutable(mathTy); REQUIRE(ttv); - const FunctionType* ftv = get(ttv->props["frexp"].type); + const FunctionType* ftv = get(ttv->props["frexp"].type()); REQUIRE(ftv); auto original = ftv->level; diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 942ce19..9086a60 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -111,7 +111,7 @@ TEST_CASE_FIXTURE(Fixture, "generalize_table_property") const TableType* tt = get(follow(t)); REQUIRE(tt); - TypeId fooTy = tt->props.at("foo").type; + TypeId fooTy = tt->props.at("foo").type(); CHECK("(a) -> a" == toString(fooTy)); } @@ -156,7 +156,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") REQUIRE(ttv); REQUIRE(ttv->props.count("f")); - TypeId k = ttv->props["f"].type; + TypeId k = ttv->props["f"].type(); REQUIRE(k); } diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index b978481..5ab27f6 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -865,7 +865,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") REQUIRE(tTable != nullptr); 
REQUIRE(tTable->props.count("bar")); - TypeId barType = tTable->props["bar"].type; + TypeId barType = tTable->props["bar"].type(); REQUIRE(barType != nullptr); const FunctionType* ftv = get(follow(barType)); @@ -900,7 +900,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") std::optional fooProp = get(t->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* foo = get(follow(fooProp->type)); + const FunctionType* foo = get(follow(fooProp->type())); REQUIRE(bool(foo)); std::optional ret_ = first(foo->retTypes); @@ -947,7 +947,7 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") std::optional methodProp = get(argTable->props, "method"); REQUIRE(bool(methodProp)); - const FunctionType* methodFunction = get(methodProp->type); + const FunctionType* methodFunction = get(methodProp->type()); REQUIRE(methodFunction != nullptr); std::optional methodArg = first(methodFunction->argTypes); diff --git a/tests/TypeInfer.negations.test.cpp b/tests/TypeInfer.negations.test.cpp index adf0365..af73607 100644 --- a/tests/TypeInfer.negations.test.cpp +++ b/tests/TypeInfer.negations.test.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" +#include "Luau/ToString.h" #include "doctest.h" #include "Luau/Common.h" #include "ScopedFlags.h" @@ -47,4 +48,32 @@ TEST_CASE_FIXTURE(NegationFixture, "string_is_not_a_subtype_of_negated_string") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "cofinite_strings_can_be_compared_for_equality") +{ + CheckResult result = check(R"( + function f(e) + if e == 'strictEqual' then + e = 'strictEqualObject' + end + if e == 'deepStrictEqual' or e == 'strictEqual' then + elseif e == 'notDeepStrictEqual' or e == 'notStrictEqual' then + end + return e + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK("(string) -> string" == toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(NegationFixture, "compare_cofinite_strings") +{ + CheckResult result = check(R"( +local u : Not<"a"> +local v : "b" +if u == v then +end +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 890e9b6..06cbe0c 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1749,4 +1749,36 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_annotations_arent_relevant_when_doing_d CHECK_EQ("nil", toString(requireTypeAtPosition({9, 28}))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "function_call_with_colon_after_refining_not_to_be_nil") +{ + ScopedFastFlag sff{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + --!strict + export type Observer = { + complete: ((self: Observer) -> ())?, + } + + local function _f(handler: Observer) + assert(handler.complete ~= nil) + handler:complete() -- incorrectly gives Value of type '((Observer) -> ())?' 
could be nil + handler.complete(handler) -- works fine, both forms should avoid the error + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "refinements_should_not_affect_assignment") +{ + CheckResult result = check(R"( + local a: unknown = true + if a == true then + a = 'not even remotely similar to a boolean' + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 468adc2..fcf2c8a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -31,15 +31,15 @@ TEST_CASE_FIXTURE(Fixture, "basic") std::optional fooProp = get(tType->props, "foo"); REQUIRE(bool(fooProp)); - CHECK_EQ(PrimitiveType::String, getPrimitiveType(fooProp->type)); + CHECK_EQ(PrimitiveType::String, getPrimitiveType(fooProp->type())); std::optional bazProp = get(tType->props, "baz"); REQUIRE(bool(bazProp)); - CHECK_EQ(PrimitiveType::Number, getPrimitiveType(bazProp->type)); + CHECK_EQ(PrimitiveType::Number, getPrimitiveType(bazProp->type())); std::optional quuxProp = get(tType->props, "quux"); REQUIRE(bool(quuxProp)); - CHECK_EQ(PrimitiveType::NilType, getPrimitiveType(quuxProp->type)); + CHECK_EQ(PrimitiveType::NilType, getPrimitiveType(quuxProp->type())); } TEST_CASE_FIXTURE(Fixture, "augment_table") @@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") REQUIRE(tType != nullptr); REQUIRE(tType->props.find("p") != tType->props.end()); - const TableType* pType = get(tType->props["p"].type); + const TableType* pType = get(tType->props["p"].type()); REQUIRE(pType != nullptr); CHECK("{ p: { foo: string } }" == toString(requireType("t"), {true})); @@ -159,7 +159,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function") std::optional fooProp = get(tableType->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* methodType = get(follow(fooProp->type)); + const FunctionType* methodType = get(follow(fooProp->type())); REQUIRE(methodType != nullptr); } @@ -173,7 +173,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") std::optional uProp = get(tableType->props, "U"); REQUIRE(bool(uProp)); - TypeId uType = uProp->type; + TypeId uType = uProp->type(); const TableType* uTable = get(uType); REQUIRE(uTable != nullptr); @@ -181,7 +181,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") std::optional fooProp = get(uTable->props, "foo"); REQUIRE(bool(fooProp)); - const FunctionType* methodType = get(follow(fooProp->type)); + const FunctionType* methodType = get(follow(fooProp->type())); REQUIRE(methodType != nullptr); std::vector methodArgs = flatten(methodType->argTypes).first; @@ -935,7 +935,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s REQUIRE(tableType->indexer == std::nullopt); REQUIRE(0 != tableType->props.count("a")); - TypeId propertyA = tableType->props["a"].type; + TypeId propertyA = tableType->props["a"].type(); REQUIRE(propertyA != nullptr); CHECK_EQ(*builtinTypes->stringType, *propertyA); } @@ -1925,8 +1925,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") REQUIRE(ttv); REQUIRE(ttv->props.count("new")); Property& prop = ttv->props["new"]; - REQUIRE(prop.type); - const FunctionType* ftv = get(follow(prop.type)); + REQUIRE(prop.type()); + const FunctionType* ftv = get(follow(prop.type())); REQUIRE(ftv); const TypePack* res = get(follow(ftv->retTypes)); REQUIRE(res); @@ -2647,7 +2647,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dont_quantify_table_that_belongs_to_outer_sc 
REQUIRE(counterType); REQUIRE(counterType->props.count("new")); - const FunctionType* newType = get(follow(counterType->props["new"].type)); + const FunctionType* newType = get(follow(counterType->props["new"].type())); REQUIRE(newType); std::optional newRetType = *first(newType->retTypes); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 5a9c77d..fa52a74 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -101,7 +101,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); state.tryUnify(&tableTwo, &tableOne); @@ -110,7 +110,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") state.log.commit(); - CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); } TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") @@ -129,14 +129,14 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") TableState::Unsealed}, }}; - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); state.tryUnify(&tableTwo, &tableOne); CHECK(state.failure); CHECK_EQ(1, state.errors.size()); - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); } TEST_CASE_FIXTURE(TryUnifyFixture, "uninhabited_intersection_sub_never") diff --git a/tools/faillist.txt b/tools/faillist.txt index 18bc0c7..38fa7f5 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -12,8 +12,6 @@ BuiltinTests.assert_returns_false_and_string_iff_it_knows_the_first_argument_can BuiltinTests.bad_select_should_not_crash BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.gmatch_definition -BuiltinTests.match_capture_types -BuiltinTests.match_capture_types2 BuiltinTests.math_max_checks_for_numbers BuiltinTests.select_slightly_out_of_range BuiltinTests.select_way_out_of_range @@ -33,6 +31,7 @@ GenericsTests.bound_tables_do_not_clone_original_fields GenericsTests.check_mutual_generic_functions GenericsTests.correctly_instantiate_polymorphic_member_functions GenericsTests.do_not_infer_generic_functions +GenericsTests.dont_unify_bound_types GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_many GenericsTests.generic_functions_should_be_memory_safe @@ -86,7 +85,6 @@ TableTests.give_up_after_one_metatable_index_look_up TableTests.indexer_on_sealed_table_must_unify_with_free_table TableTests.indexing_from_a_table_should_prefer_properties_when_possible TableTests.inequality_operators_imply_exactly_matching_types -TableTests.infer_array_2 TableTests.inferred_return_type_of_free_table TableTests.instantiate_table_cloning_3 TableTests.leaking_bad_metatable_errors @@ -138,7 +136,6 @@ TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.no_stack_overflow_from_isoptional 
TypeInfer.no_stack_overflow_from_isoptional2 -TypeInfer.should_be_able_to_infer_this_without_stack_overflowing TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer @@ -207,11 +204,9 @@ TypePackTests.unify_variadic_tails_in_arguments TypePackTests.variadic_packs TypeSingletons.function_call_with_singletons TypeSingletons.function_call_with_singletons_mismatch -TypeSingletons.indexing_on_union_of_string_singletons TypeSingletons.no_widening_from_callsites TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_type_error_escapes -TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere UnionTypes.generic_function_with_optional_arg diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 817d083..208096f 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -108,10 +108,10 @@ def main(): help="Write a new faillist.txt after running tests.", ) parser.add_argument( - "--lti", - dest="lti", + "--rwp", + dest="rwp", action="store_true", - help="Run the tests with local type inference enabled.", + help="Run the tests with read-write properties enabled.", ) parser.add_argument("--randomize", action="store_true", help="Pick a random seed") @@ -126,17 +126,17 @@ def main(): args = parser.parse_args() - if args.write and args.lti: + if args.write and args.rwp: print_stderr( -            "Cannot run test_dcr.py with --write *and* --lti. You don't want to commit local type inference faillist.txt yet." + "Cannot run test_dcr.py with --write *and* --rwp. You don't want to commit a read-write properties faillist.txt yet." ) sys.exit(1) failList = loadFailList() flags = ["true", "DebugLuauDeferredConstraintResolution"] - if args.lti: - flags.append("DebugLuauLocalTypeInference") + if args.rwp: + flags.append("DebugLuauReadWriteProperties") commandLine = [args.path, "--reporters=xml", "--fflags=" + ",".join(flags)]
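For reference, a minimal illustrative sketch (not part of the patch) of what the renamed --rwp option ends up passing to the test executable; the fast-flag names are the ones visible in the test_dcr.py hunk above, while the standalone scaffolding around them is assumed for the example.

# Hedged sketch: mirrors the flag wiring from the test_dcr.py hunk above.
flags = ["true", "DebugLuauDeferredConstraintResolution"]
use_rwp = True  # stands in for args.rwp, i.e. passing --rwp on the command line
if use_rwp:
    flags.append("DebugLuauReadWriteProperties")
print("--fflags=" + ",".join(flags))
# prints: --fflags=true,DebugLuauDeferredConstraintResolution,DebugLuauReadWriteProperties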