diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 8159b76..18ff309 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -106,6 +106,7 @@ struct FunctionCallConstraint TypePackId argsPack; TypePackId result; class AstExprCall* callSite; + std::vector> discriminantTypes; }; // result ~ prim ExpectedType SomeSingletonType MultitonType @@ -180,7 +181,7 @@ struct Constraint Constraint& operator=(const Constraint&) = delete; NotNull scope; - Location location; // TODO: Extract this out into only the constraints that needs a location. Not all constraints needs locations. + Location location; ConstraintV c; std::vector> dependencies; diff --git a/Analysis/include/Luau/DcrLogger.h b/Analysis/include/Luau/DcrLogger.h index 30d2e15..45c84c6 100644 --- a/Analysis/include/Luau/DcrLogger.h +++ b/Analysis/include/Luau/DcrLogger.h @@ -65,6 +65,7 @@ struct ConstraintBlock struct ConstraintSnapshot { std::string stringification; + Location location; std::vector blocks; }; diff --git a/Analysis/include/Luau/Refinement.h b/Analysis/include/Luau/Refinement.h index 3e1f234..e7d3cf2 100644 --- a/Analysis/include/Luau/Refinement.h +++ b/Analysis/include/Luau/Refinement.h @@ -11,14 +11,20 @@ namespace Luau struct Type; using TypeId = const Type*; +struct Variadic; struct Negation; struct Conjunction; struct Disjunction; struct Equivalence; struct Proposition; -using Refinement = Variant; +using Refinement = Variant; using RefinementId = Refinement*; // Can and most likely is nullptr. +struct Variadic +{ + std::vector refinements; +}; + struct Negation { RefinementId refinement; @@ -56,13 +62,15 @@ const T* get(RefinementId refinement) struct RefinementArena { - TypedAllocator allocator; - + RefinementId variadic(const std::vector& refis); RefinementId negation(RefinementId refinement); RefinementId conjunction(RefinementId lhs, RefinementId rhs); RefinementId disjunction(RefinementId lhs, RefinementId rhs); RefinementId equivalence(RefinementId lhs, RefinementId rhs); RefinementId proposition(DefId def, TypeId discriminantTy); + +private: + TypedAllocator allocator; }; } // namespace Luau diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 6c8e1bc..00e6d6c 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -263,15 +263,12 @@ using DcrMagicFunction = bool (*)(MagicFunctionCallContext); struct MagicRefinementContext { - ScopePtr scope; - NotNull cgb; - NotNull dfg; - NotNull refinementArena; - std::vector argumentRefinements; + NotNull scope; const class AstExprCall* callSite; + std::vector> discriminantTypes; }; -using DcrMagicRefinement = std::vector (*)(const MagicRefinementContext&); +using DcrMagicRefinement = void (*)(const MagicRefinementContext&); struct FunctionType { @@ -304,8 +301,8 @@ struct FunctionType TypePackId argTypes; TypePackId retTypes; MagicFunction magicFunction = nullptr; - DcrMagicFunction dcrMagicFunction = nullptr; // Fired only while solving constraints - DcrMagicRefinement dcrMagicRefinement = nullptr; // Fired only while generating constraints + DcrMagicFunction dcrMagicFunction = nullptr; + DcrMagicRefinement dcrMagicRefinement = nullptr; bool hasSelf; bool hasNoGenerics = false; }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 4c2d38a..d748a1f 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -57,6 +57,12 @@ public: } }; +enum class ValueContext +{ + LValue, + RValue +}; + // All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -119,14 +125,14 @@ struct TypeChecker std::optional expectedType); // Returns the type of the lvalue. - TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr); + TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx); // Returns the type of the lvalue. - TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); - TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); - TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr, ValueContext ctx); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 006df6e..c17169f 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -42,8 +42,6 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); static bool dcrMagicFunctionPack(MagicFunctionCallContext context); -static std::vector dcrMagicRefinementAssert(const MagicRefinementContext& context); - TypeId makeUnion(TypeArena& arena, std::vector&& types) { return arena.addType(UnionType{std::move(types)}); @@ -422,7 +420,6 @@ void registerBuiltinGlobals(Frontend& frontend) } attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); - attachDcrMagicRefinement(getGlobalBinding(frontend, "assert"), dcrMagicRefinementAssert); attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); @@ -624,15 +621,6 @@ static std::optional> magicFunctionAssert( return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::vector dcrMagicRefinementAssert(const MagicRefinementContext& ctx) -{ - if (ctx.argumentRefinements.empty()) - return {}; - - ctx.cgb->applyRefinements(ctx.scope, ctx.callSite->location, ctx.argumentRefinements[0]); - return {}; -} - static std::optional> magicFunctionPack( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 3dd8df8..ff8e0c3 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -438,6 +438,7 @@ TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool alwaysCl clone.genericPacks = ftv->genericPacks; clone.magicFunction = ftv->magicFunction; clone.dcrMagicFunction = ftv->dcrMagicFunction; + clone.dcrMagicRefinement = ftv->dcrMagicRefinement; clone.tags = ftv->tags; clone.argNames = ftv->argNames; result = dest.addType(std::move(clone)); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 09182f5..f773863 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -15,7 +15,6 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauNegatedClassTypes); -LUAU_FASTFLAG(LuauScopelessModule); LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration); namespace Luau @@ -96,6 +95,18 @@ static std::optional matchTypeGuard(const AstExprBinary* binary) }; } +static bool matchAssert(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != "assert") + return false; + + return true; +} + namespace { @@ -198,6 +209,11 @@ static void computeRefinement(const ScopePtr& scope, RefinementId refinement, st if (!refinement) return; + else if (auto variadic = get(refinement)) + { + for (RefinementId refi : variadic->refinements) + computeRefinement(scope, refi, refis, sense, arena, eq, constraints); + } else if (auto negation = get(refinement)) return computeRefinement(scope, negation->refinement, refis, !sense, arena, eq, constraints); else if (auto conjunction = get(refinement)) @@ -546,8 +562,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) { - scope->importedTypeBindings[name] = - FFlag::LuauScopelessModule ? module->exportedTypeBindings : module->getModuleScope()->exportedTypeBindings; + scope->importedTypeBindings[name] = module->exportedTypeBindings; if (FFlag::SupportTypeAliasGoToDeclaration) scope->importedModules[name] = moduleName; } @@ -697,18 +712,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct } else if (AstExprIndexName* indexName = function->name->as()) { - TypeId containingTableType = check(scope, indexName->expr).ty; - - // TODO look into stack utilization. This is probably ok because it scales with AST depth. - TypeId prospectiveTableType = arena->addType(TableType{TableState::Unsealed, TypeLevel{}, scope.get()}); - - NotNull prospectiveTable{getMutable(prospectiveTableType)}; - - Property& prop = prospectiveTable->props[indexName->index.value]; - prop.type = generalizedType; - prop.location = function->name->location; - - addConstraint(scope, indexName->location, SubtypeConstraint{containingTableType, prospectiveTableType}); + TypeId lvalueType = checkLValue(scope, indexName); + // TODO figure out how to populate the location field of the table Property. + addConstraint(scope, indexName->location, SubtypeConstraint{lvalueType, generalizedType}); } else if (AstExprError* err = function->name->as()) { @@ -783,13 +789,13 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement auto [_, refinement] = check(condScope, ifStatement->condition, std::nullopt); ScopePtr thenScope = childScope(ifStatement->thenbody, scope); - applyRefinements(thenScope, Location{}, refinement); + applyRefinements(thenScope, ifStatement->condition->location, refinement); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { ScopePtr elseScope = childScope(ifStatement->elsebody, scope); - applyRefinements(elseScope, Location{}, refinementArena.negation(refinement)); + applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); visit(elseScope, ifStatement->elsebody); } } @@ -1059,6 +1065,10 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) { std::vector exprArgs; + + std::vector returnRefinements; + std::vector> discriminantTypes; + if (call->self) { AstExprIndexName* indexExpr = call->func->as(); @@ -1066,13 +1076,37 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa ice->ice("method call expression has no 'self'"); exprArgs.push_back(indexExpr->expr); + + if (auto def = dfg->getDef(indexExpr->expr)) + { + TypeId discriminantTy = arena->addType(BlockedType{}); + returnRefinements.push_back(refinementArena.proposition(*def, discriminantTy)); + discriminantTypes.push_back(discriminantTy); + } + else + discriminantTypes.push_back(std::nullopt); + } + + for (AstExpr* arg : call->args) + { + exprArgs.push_back(arg); + + if (auto def = dfg->getDef(arg)) + { + TypeId discriminantTy = arena->addType(BlockedType{}); + returnRefinements.push_back(refinementArena.proposition(*def, discriminantTy)); + discriminantTypes.push_back(discriminantTy); + } + else + discriminantTypes.push_back(std::nullopt); } - exprArgs.insert(exprArgs.end(), call->args.begin(), call->args.end()); Checkpoint startCheckpoint = checkpoint(this); TypeId fnType = check(scope, call->func).ty; Checkpoint fnEndCheckpoint = checkpoint(this); + module->astOriginalCallTypes[call->func] = fnType; + TypePackId expectedArgPack = arena->freshTypePack(scope.get()); TypePackId expectedRetPack = arena->freshTypePack(scope.get()); TypeId expectedFunctionType = arena->addType(FunctionType{expectedArgPack, expectedRetPack}); @@ -1129,7 +1163,11 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa argumentRefinements.push_back(refinement); } else - argTail = checkPack(scope, arg, {}).tp; // FIXME? not sure about expectedTypes here + { + auto [tp, refis] = checkPack(scope, arg, {}); // FIXME? not sure about expectedTypes here + argTail = tp; + argumentRefinements.insert(argumentRefinements.end(), refis.begin(), refis.end()); + } } Checkpoint argEndCheckpoint = checkpoint(this); @@ -1140,13 +1178,6 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa constraint->dependencies.push_back(extractArgsConstraint); }); - std::vector returnRefinements; - if (auto ftv = get(follow(fnType)); ftv && ftv->dcrMagicRefinement) - { - MagicRefinementContext ctx{scope, NotNull{this}, dfg, NotNull{&refinementArena}, std::move(argumentRefinements), call}; - returnRefinements = ftv->dcrMagicRefinement(ctx); - } - if (matchSetmetatable(*call)) { TypePack argTailPack; @@ -1171,12 +1202,12 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa scope->dcrRefinements[*def] = resultTy; // TODO: typestates: track this as an assignment } - - return InferencePack{arena->addTypePack({resultTy}), std::move(returnRefinements)}; + return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; } else { - module->astOriginalCallTypes[call->func] = fnType; + if (matchAssert(*call) && !argumentRefinements.empty()) + applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); TypeId instantiatedType = arena->addType(BlockedType{}); // TODO: How do expectedTypes play into this? Do they? @@ -1200,6 +1231,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa argPack, rets, call, + std::move(discriminantTypes), }); // We force constraints produced by checking function arguments to wait @@ -1211,7 +1243,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa fcc->dependencies.emplace_back(constraint.get()); }); - return InferencePack{rets, std::move(returnRefinements)}; + return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; } } @@ -1386,74 +1418,10 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprGlobal* gl return Inference{builtinTypes->errorRecoveryType()}; } -static std::optional lookupProp(TypeId ty, const std::string& propName, NotNull arena) -{ - ty = follow(ty); - - if (auto ctv = get(ty)) - { - if (auto prop = lookupClassProp(ctv, propName)) - return prop->type; - } - else if (auto ttv = get(ty)) - { - if (auto it = ttv->props.find(propName); it != ttv->props.end()) - return it->second.type; - } - else if (auto utv = get(ty)) - { - std::vector types; - - for (TypeId ty : utv) - { - if (auto prop = lookupProp(ty, propName, arena)) - { - if (std::find(begin(types), end(types), *prop) == end(types)) - types.push_back(*prop); - } - else - return std::nullopt; - } - - if (types.size() == 1) - return types[0]; - else - return arena->addType(IntersectionType{std::move(types)}); - } - else if (auto utv = get(ty)) - { - std::vector types; - - for (TypeId ty : utv) - { - if (auto prop = lookupProp(ty, propName, arena)) - { - if (std::find(begin(types), end(types), *prop) == end(types)) - types.push_back(*prop); - } - else - return std::nullopt; - } - - if (types.size() == 1) - return types[0]; - else - return arena->addType(UnionType{std::move(types)}); - } - - return std::nullopt; -} - Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr).ty; - - // HACK: We need to return the actual type for type refinements so that it can invoke the dcrMagicRefinement function. - TypeId result; - if (auto prop = lookupProp(obj, indexName->index.value, arena)) - result = *prop; - else - result = freshType(scope); + TypeId result = freshType(scope); std::optional def = dfg->getDef(indexName); if (def) @@ -1723,11 +1691,6 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) TypeId updatedType = arena->addType(BlockedType{}); addConstraint(scope, expr->location, SetPropConstraint{updatedType, subjectType, std::move(segmentStrings), propTy}); - std::optional def = dfg->getDef(sym); - LUAU_ASSERT(def); - symbolScope->bindings[sym].typeId = updatedType; - symbolScope->dcrRefinements[*def] = updatedType; - TypeId prevSegmentTy = updatedType; for (size_t i = 0; i < segments.size(); ++i) { @@ -1739,7 +1702,16 @@ TypeId ConstraintGraphBuilder::checkLValue(const ScopePtr& scope, AstExpr* expr) module->astTypes[expr] = prevSegmentTy; module->astTypes[e] = updatedType; - // astTypes[expr] = propTy; + + symbolScope->bindings[sym].typeId = updatedType; + + std::optional def = dfg->getDef(sym); + if (def) + { + // This can fail if the user is erroneously trying to augment a builtin + // table like os or string. + symbolScope->dcrRefinements[*def] = updatedType; + } return propTy; } @@ -1765,9 +1737,30 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp addConstraint(scope, location, SubtypeConstraint{ttv->indexer->indexResultType, currentResultType}); }; + std::optional annotatedKeyType; + std::optional annotatedIndexResultType; + + if (expectedType) + { + if (const TableType* ttv = get(follow(*expectedType))) + { + if (ttv->indexer) + { + annotatedKeyType.emplace(follow(ttv->indexer->indexType)); + annotatedIndexResultType.emplace(ttv->indexer->indexResultType); + } + } + } + + bool isIndexedResultType = false; + std::optional pinnedIndexResultType; + + for (const AstExprTable::Item& item : expr->items) { std::optional expectedValueType; + if (item.kind == AstExprTable::Item::Kind::General || item.kind == AstExprTable::Item::Kind::List) + isIndexedResultType = true; if (item.key && expectedType) { @@ -1786,14 +1779,39 @@ Inference ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTable* exp } } - TypeId itemTy = check(scope, item.value, expectedValueType).ty; + + // We'll resolve the expected index result type here with the following priority: + // 1. Record table types - in which key, value pairs must be handled on a k,v pair basis. + // In this case, the above if-statement will populate expectedValueType + // 2. Someone places an annotation on a General or List table + // Trust the annotation and have the solver inform them if they get it wrong + // 3. Someone omits the annotation on a general or List table + // Use the type of the first indexResultType as the expected type + std::optional checkExpectedIndexResultType; + if (expectedValueType) + { + checkExpectedIndexResultType = expectedValueType; + } + else if (annotatedIndexResultType) + { + checkExpectedIndexResultType = annotatedIndexResultType; + } + else if (pinnedIndexResultType) + { + checkExpectedIndexResultType = pinnedIndexResultType; + } + + TypeId itemTy = check(scope, item.value, checkExpectedIndexResultType).ty; + + if (isIndexedResultType && !pinnedIndexResultType) + pinnedIndexResultType = itemTy; if (item.key) { // Even though we don't need to use the type of the item's key if // it's a string constant, we still want to check it to populate // astTypes. - TypeId keyTy = check(scope, item.key).ty; + TypeId keyTy = check(scope, item.key, annotatedKeyType).ty; if (AstExprConstantString* key = item.key->as()) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 96d16c4..76fd0bc 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); -LUAU_FASTFLAG(LuauScopelessModule); namespace Luau { @@ -424,9 +423,7 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo LUAU_ASSERT(false); if (success) - { unblock(constraint); - } return success; } @@ -1129,6 +1126,28 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull std::optional { + auto it = begin(t); + auto endIt = end(t); + + LUAU_ASSERT(it != endIt); + TypeId fst = follow(*it); + while (it != endIt) + { + if (follow(*it) != fst) + return std::nullopt; + ++it; + } + + return fst; + }; + + // Sometimes the `fn` type is a union/intersection, but whose constituents are all the same pointer. + if (auto ut = get(fn)) + fn = collapse(ut).value_or(fn); + else if (auto it = get(fn)) + fn = collapse(it).value_or(fn); + // We don't support magic __call metamethods. if (std::optional callMm = findMetatableEntry(builtinTypes, errors, fn, "__call", constraint->location)) { @@ -1140,69 +1159,73 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulladdType(BlockedType{}); TypeId inferredFnType = arena->addType(FunctionType(TypeLevel{}, constraint->scope.get(), arena->addTypePack(TypePack{args, {}}), c.result)); - // Alter the inner constraints. - LUAU_ASSERT(c.innerConstraints.size() == 2); - - // Anything that is blocked on this constraint must also be blocked on our inner constraints - auto blockedIt = blocked.find(constraint.get()); - if (blockedIt != blocked.end()) - { - for (const auto& ic : c.innerConstraints) - { - for (const auto& blockedConstraint : blockedIt->second) - block(ic, blockedConstraint); - } - } - asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; - unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); - asMutable(c.result)->ty.emplace(constraint->scope); - unblock(c.result); - return true; - } - - const FunctionType* ftv = get(fn); - bool usedMagic = false; - - if (ftv && ftv->dcrMagicFunction != nullptr) - { - usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull(this), c.callSite, c.argsPack, result}); - } - - if (usedMagic) - { - // There are constraints that are blocked on these constraints. If we - // are never going to even examine them, then we should not block - // anything else on them. - // - // TODO CLI-58842 -#if 0 - for (auto& c: c.innerConstraints) - unblock(c); -#endif } else { - // Anything that is blocked on this constraint must also be blocked on our inner constraints - auto blockedIt = blocked.find(constraint.get()); - if (blockedIt != blocked.end()) + const FunctionType* ftv = get(fn); + bool usedMagic = false; + + if (ftv) { - for (const auto& ic : c.innerConstraints) - { - for (const auto& blockedConstraint : blockedIt->second) - block(ic, blockedConstraint); - } + if (ftv->dcrMagicFunction) + usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull(this), c.callSite, c.argsPack, result}); + + if (ftv->dcrMagicRefinement) + ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); } - unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); - asMutable(c.result)->ty.emplace(constraint->scope); + if (usedMagic) + { + // There are constraints that are blocked on these constraints. If we + // are never going to even examine them, then we should not block + // anything else on them. + // + // TODO CLI-58842 +#if 0 + for (auto& c: c.innerConstraints) + unblock(c); +#endif + } + else + asMutable(c.result)->ty.emplace(constraint->scope); } - unblock(c.result); + for (std::optional ty : c.discriminantTypes) + { + if (!ty || !isBlocked(*ty)) + continue; + // We use `any` here because the discriminant type may be pointed at by both branches, + // where the discriminant type is not negated, and the other where it is negated, i.e. + // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` + // v.s. + // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` + // + // In practice, users cannot negate `any`, so this is an implementation detail we can always change. + *asMutable(follow(*ty)) = BoundType{builtinTypes->anyType}; + } + + // Alter the inner constraints. + LUAU_ASSERT(c.innerConstraints.size() == 2); + + // Anything that is blocked on this constraint must also be blocked on our inner constraints + auto blockedIt = blocked.find(constraint.get()); + if (blockedIt != blocked.end()) + { + for (const auto& ic : c.innerConstraints) + { + for (const auto& blockedConstraint : blockedIt->second) + block(ic, blockedConstraint); + } + } + + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); + + unblock(c.result); return true; } @@ -1930,7 +1953,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l return errorRecoveryType(); } - TypePackId modulePack = FFlag::LuauScopelessModule ? module->returnType : module->getModuleScope()->returnType; + TypePackId modulePack = module->returnType; if (get(modulePack)) return errorRecoveryType(); diff --git a/Analysis/src/DcrLogger.cpp b/Analysis/src/DcrLogger.cpp index ef33aa6..a1ef650 100644 --- a/Analysis/src/DcrLogger.cpp +++ b/Analysis/src/DcrLogger.cpp @@ -105,6 +105,7 @@ void write(JsonEmitter& emitter, const ConstraintSnapshot& snapshot) { ObjectEmitter o = emitter.writeObject(); o.writePair("stringification", snapshot.stringification); + o.writePair("location", snapshot.location); o.writePair("blocks", snapshot.blocks); o.finish(); } @@ -293,6 +294,7 @@ void DcrLogger::captureInitialSolverState(const Scope* rootScope, const std::vec std::string id = toPointerId(c); solveLog.initialState.constraints[id] = { toString(*c.get(), opts), + c->location, snapshotBlocks(c), }; } @@ -310,6 +312,7 @@ StepSnapshot DcrLogger::prepareStepSnapshot( std::string id = toPointerId(c); constraints[id] = { toString(*c.get(), opts), + c->location, snapshotBlocks(c), }; } @@ -337,6 +340,7 @@ void DcrLogger::captureFinalSolverState(const Scope* rootScope, const std::vecto std::string id = toPointerId(c); solveLog.finalState.constraints[id] = { toString(*c.get(), opts), + c->location, snapshotBlocks(c), }; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 94342cc..a70d6dd 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -31,7 +31,6 @@ LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAG(DebugLuauLogSolverToJson); -LUAU_FASTFLAG(LuauScopelessModule); namespace Luau { @@ -113,9 +112,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c CloneState cloneState; std::vector typesToPersist; - typesToPersist.reserve( - checkedModule->declaredGlobals.size() + - (FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings.size() : checkedModule->getModuleScope()->exportedTypeBindings.size())); + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); for (const auto& [name, ty] : checkedModule->declaredGlobals) { @@ -127,8 +124,7 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c typesToPersist.push_back(globalTy); } - for (const auto& [name, ty] : - FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings : checkedModule->getModuleScope()->exportedTypeBindings) + for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { TypeFun globalTy = clone(ty, globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; @@ -173,9 +169,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t CloneState cloneState; std::vector typesToPersist; - typesToPersist.reserve( - checkedModule->declaredGlobals.size() + - (FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings.size() : checkedModule->getModuleScope()->exportedTypeBindings.size())); + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); for (const auto& [name, ty] : checkedModule->declaredGlobals) { @@ -187,8 +181,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t typesToPersist.push_back(globalTy); } - for (const auto& [name, ty] : - FFlag::LuauScopelessModule ? checkedModule->exportedTypeBindings : checkedModule->getModuleScope()->exportedTypeBindings) + for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; @@ -571,30 +564,17 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalinternalTypes.clear(); - if (FFlag::LuauScopelessModule) - { - module->astTypes.clear(); - module->astTypePacks.clear(); - module->astExpectedTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astOverloadResolvedTypes.clear(); - module->astResolvedTypes.clear(); - module->astOriginalResolvedTypes.clear(); - module->astResolvedTypePacks.clear(); - module->astScopes.clear(); + module->astTypes.clear(); + module->astTypePacks.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astOverloadResolvedTypes.clear(); + module->astResolvedTypes.clear(); + module->astOriginalResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); + module->astScopes.clear(); - module->scopes.clear(); - } - else - { - module->astTypes.clear(); - module->astExpectedTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astResolvedTypes.clear(); - module->astResolvedTypePacks.clear(); - module->astOriginalResolvedTypes.clear(); - module->scopes.resize(1); - } + module->scopes.clear(); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 912c415..9c3ae07 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -47,6 +47,7 @@ TypeId Instantiation::clean(TypeId ty) FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; clone.magicFunction = ftv->magicFunction; clone.dcrMagicFunction = ftv->dcrMagicFunction; + clone.dcrMagicRefinement = ftv->dcrMagicRefinement; clone.tags = ftv->tags; clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index a9faded..c0f4405 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); -LUAU_FASTFLAG(LuauScopelessModule); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); LUAU_FASTFLAG(LuauSubstitutionFixMissingFields); @@ -227,11 +226,8 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr } // Copy external stuff over to Module itself - if (FFlag::LuauScopelessModule) - { - this->returnType = moduleScope->returnType; - this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); - } + this->returnType = moduleScope->returnType; + this->exportedTypeBindings = std::move(moduleScope->exportedTypeBindings); } bool Module::hasModuleScope() const diff --git a/Analysis/src/Refinement.cpp b/Analysis/src/Refinement.cpp index fb019f1..459379a 100644 --- a/Analysis/src/Refinement.cpp +++ b/Analysis/src/Refinement.cpp @@ -29,4 +29,9 @@ RefinementId RefinementArena::proposition(DefId def, TypeId discriminantTy) return NotNull{allocator.allocate(Proposition{def, discriminantTy})}; } +RefinementId RefinementArena::variadic(const std::vector& refis) +{ + return NotNull{allocator.allocate(Variadic{refis})}; +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 07bdbd4..e59c7e0 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,6 +26,7 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) +LUAU_FASTFLAGVARIABLE(LuauDontExtendUnsealedRValueTables, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) @@ -35,7 +36,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauScopelessModule, false) LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) @@ -43,6 +43,7 @@ LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration) +LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) namespace Luau { @@ -913,7 +914,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } else { - expectedTypes.push_back(checkLValue(scope, *dest)); + expectedTypes.push_back(checkLValue(scope, *dest, ValueContext::LValue)); } } @@ -930,7 +931,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) TypeId left = nullptr; if (dest->is() || dest->is()) - left = checkLValue(scope, *dest); + left = checkLValue(scope, *dest, ValueContext::LValue); else left = *expectedTypes[i]; @@ -1119,8 +1120,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) if (ModulePtr module = resolver->getModule(moduleInfo->name)) { - scope->importedTypeBindings[name] = - FFlag::LuauScopelessModule ? module->exportedTypeBindings : module->getModuleScope()->exportedTypeBindings; + scope->importedTypeBindings[name] = module->exportedTypeBindings; if (FFlag::SupportTypeAliasGoToDeclaration) scope->importedModules[name] = moduleInfo->name; } @@ -2132,7 +2132,7 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { - TypeId ty = checkLValue(scope, expr); + TypeId ty = checkLValue(scope, expr, ValueContext::RValue); if (std::optional lvalue = tryGetLValue(expr)) if (std::optional refiTy = resolveLValue(scope, *lvalue)) @@ -2977,14 +2977,23 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; + if (!FFlag::LuauTypecheckTypeguards) + { + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; + } // For these, passing expectedType is worse than simply forcing them, because their implementation // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); + if (FFlag::LuauTypecheckTypeguards) + { + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; + } + PredicateVec predicates; if (auto lvalue = tryGetLValue(*expr.left)) @@ -3068,21 +3077,21 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {stringType}; } -TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) +TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx) { - return checkLValueBinding(scope, expr); + return checkLValueBinding(scope, expr, ctx); } -TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr, ValueContext ctx) { if (auto a = expr.as()) return checkLValueBinding(scope, *a); else if (auto a = expr.as()) return checkLValueBinding(scope, *a); else if (auto a = expr.as()) - return checkLValueBinding(scope, *a); + return checkLValueBinding(scope, *a, ctx); else if (auto a = expr.as()) - return checkLValueBinding(scope, *a); + return checkLValueBinding(scope, *a, ctx); else if (auto a = expr.as()) { for (AstExpr* expr : a->expressions) @@ -3130,7 +3139,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGloba return result; } -TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr, ValueContext ctx) { TypeId lhs = checkExpr(scope, *expr.expr).type; @@ -3153,7 +3162,15 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { return it->second.type; } - else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + else if (!FFlag::LuauDontExtendUnsealedRValueTables && (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free)) + { + TypeId theType = freshType(scope); + Property& property = lhsTable->props[name]; + property.type = theType; + property.location = expr.indexLocation; + return theType; + } + else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && lhsTable->state == TableState::Unsealed) || lhsTable->state == TableState::Free)) { TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; @@ -3216,7 +3233,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex return errorRecoveryType(scope); } -TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr, ValueContext ctx) { TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); @@ -3274,7 +3291,15 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex { return it->second.type; } - else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) + else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) + { + TypeId resultType = freshType(scope); + Property& property = exprTable->props[value->value.data]; + property.type = resultType; + property.location = expr.index->location; + return resultType; + } + else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) { TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; @@ -3290,20 +3315,35 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex unify(indexType, indexer.indexType, scope, expr.index->location); return indexer.indexResultType; } - else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) + else if (!FFlag::LuauDontExtendUnsealedRValueTables && (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)) { TypeId resultType = freshType(exprTable->level); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return resultType; } + else if (FFlag::LuauDontExtendUnsealedRValueTables && ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free)) + { + TypeId indexerType = freshType(exprTable->level); + unify(indexType, indexerType, scope, expr.location); + TypeId indexResultType = freshType(exprTable->level); + + exprTable->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)}; + return indexResultType; + } else { /* - * If we use [] indexing to fetch a property from a sealed table that has no indexer, we have no idea if it will - * work, so we just mint a fresh type, return that, and hope for the best. + * If we use [] indexing to fetch a property from a sealed table that + * has no indexer, we have no idea if it will work so we just return any + * and hope for the best. */ - TypeId resultType = freshType(scope); - return resultType; + if (FFlag::LuauDontExtendUnsealedRValueTables) + return anyType; + else + { + TypeId resultType = freshType(scope); + return resultType; + } } } @@ -4508,7 +4548,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } - TypePackId modulePack = FFlag::LuauScopelessModule ? module->returnType : module->getModuleScope()->returnType; + TypePackId modulePack = module->returnType; if (get(modulePack)) return errorRecoveryType(scope); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d48c72f..bda062a 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -1038,6 +1038,24 @@ void Unifier::tryUnifyNormalizedTypes( } } + if (FFlag::DebugLuauDeferredConstraintResolution) + { + for (TypeId superTable : superNorm.tables) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(subClass, superTable); + + if (innerState.errors.empty()) + { + found = true; + log.concat(std::move(innerState.log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + return reportError(*e); + } + } + if (!found) { return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); diff --git a/CMakeLists.txt b/CMakeLists.txt index 05d701e..4255c7c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -172,6 +172,7 @@ endif() if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Reduce.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Ast.CLI PRIVATE ${LUAU_OPTIONS}) diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 7482b0a..0941d47 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -8,7 +8,9 @@ namespace CodeGen struct IrFunction; -void updateUseInfo(IrFunction& function); +void updateUseCounts(IrFunction& function); + +void updateLastUseLocations(IrFunction& function); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 5b51e0a..ebbba68 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -50,6 +50,8 @@ struct IrBuilder IrOp vmConst(uint32_t index); IrOp vmUpvalue(uint8_t index); + bool inTerminatedBlock = false; + bool activeFastcallFallback = false; IrOp fastcallFallbackReturn; diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 1c70c80..28f5b29 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -16,31 +16,90 @@ namespace Luau namespace CodeGen { +// IR instruction command. +// In the command description, following abbreviations are used: +// * Rn - VM stack register slot, n in 0..254 +// * Kn - VM proto constant slot, n in 0..2^23-1 +// * UPn - VM function upvalue slot, n in 0..254 +// * A, B, C, D, E are instruction arguments enum class IrCmd : uint8_t { NOP, + // Load a tag from TValue + // A: Rn or Kn LOAD_TAG, + + // Load a pointer (*) from TValue + // A: Rn or Kn LOAD_POINTER, + + // Load a double number from TValue + // A: Rn or Kn LOAD_DOUBLE, + + // Load an int from TValue + // A: Rn LOAD_INT, + + // Load a TValue from memory + // A: Rn or Kn or pointer (TValue) LOAD_TVALUE, + + // Load a TValue from table node value + // A: pointer (LuaNode) LOAD_NODE_VALUE_TV, // TODO: we should find a way to generalize LOAD_TVALUE + + // Load current environment table LOAD_ENV, + // Get pointer (TValue) to table array at index + // A: pointer (Table) + // B: unsigned int GET_ARR_ADDR, + + // Get pointer (LuaNode) to table node element at the active cached slot index + // A: pointer (Table) GET_SLOT_NODE_ADDR, + // Store a tag into TValue + // A: Rn + // B: tag STORE_TAG, + + // Store a pointer (*) into TValue + // A: Rn + // B: pointer STORE_POINTER, + + // Store a double number into TValue + // A: Rn + // B: double STORE_DOUBLE, + + // Store an int into TValue + // A: Rn + // B: int STORE_INT, + + // Store a TValue into memory + // A: Rn or pointer (TValue) + // B: TValue STORE_TVALUE, + + // Store a TValue into table node value + // A: pointer (LuaNode) + // B: TValue STORE_NODE_VALUE_TV, // TODO: we should find a way to generalize STORE_TVALUE + // Add/Sub two integers together + // A, B: int ADD_INT, SUB_INT, + // Add/Sub/Mul/Div/Mod/Pow two double numbers + // A, B: double + // In final x64 lowering, B can also be Rn or Kn ADD_NUM, SUB_NUM, MUL_NUM, @@ -48,91 +107,351 @@ enum class IrCmd : uint8_t MOD_NUM, POW_NUM, + // Negate a double number + // A: double UNM_NUM, + // Compute Luau 'not' operation on destructured TValue + // A: tag + // B: double NOT_ANY, // TODO: boolean specialization will be useful + // Unconditional jump + // A: block JUMP, + + // Jump if TValue is truthy + // A: Rn + // B: block (if true) + // C: block (if false) JUMP_IF_TRUTHY, + + // Jump if TValue is falsy + // A: Rn + // B: block (if true) + // C: block (if false) JUMP_IF_FALSY, + + // Jump if tags are equal + // A, B: tag + // C: block (if true) + // D: block (if false) JUMP_EQ_TAG, - JUMP_EQ_BOOLEAN, + + // Jump if two int numbers are equal + // A, B: int + // C: block (if true) + // D: block (if false) + JUMP_EQ_INT, + + // Jump if pointers are equal + // A, B: pointer (*) + // C: block (if true) + // D: block (if false) JUMP_EQ_POINTER, + // Perform a conditional jump based on the result of double comparison + // A, B: double + // C: condition + // D: block (if true) + // E: block (if false) JUMP_CMP_NUM, - JUMP_CMP_STR, + + // Perform a conditional jump based on the result of TValue comparison + // A, B: Rn + // C: condition + // D: block (if true) + // E: block (if false) JUMP_CMP_ANY, + // Get table length + // A: pointer (Table) TABLE_LEN, + + // Allocate new table + // A: int (array element count) + // B: int (node element count) NEW_TABLE, + + // Duplicate a table + // A: pointer (Table) DUP_TABLE, + // Try to convert a double number into a table index or jump if it's not an integer + // A: double + // B: block NUM_TO_INDEX, + // Convert integer into a double number + // A: int + INT_TO_NUM, + // Fallback functions + + // Perform an arithmetic operation on TValues of any type + // A: Rn (where to store the result) + // B: Rn (lhs) + // C: Rn or Kn (rhs) DO_ARITH, + + // Get length of a TValue of any type + // A: Rn (where to store the result) + // B: Rn DO_LEN, + + // Lookup a value in TValue of any type using a key of any type + // A: Rn (where to store the result) + // B: Rn + // C: Rn or unsigned int (key) GET_TABLE, + + // Store a value into TValue of any type using a key of any type + // A: Rn (value to store) + // B: Rn + // C: Rn or unsigned int (key) SET_TABLE, + + // Lookup a value in the environment + // A: Rn (where to store the result) + // B: unsigned int (import path) GET_IMPORT, + + // Concatenate multiple TValues + // A: Rn (where to store the result) + // B: unsigned int (index of the first VM stack slot) + // C: unsigned int (number of stack slots to go over) CONCAT, + + // Load function upvalue into stack slot + // A: Rn + // B: UPn GET_UPVALUE, + + // Store TValue from stack slot into a function upvalue + // A: UPn + // B: Rn SET_UPVALUE, - // Guards and checks + // Convert TValues into numbers for a numerical for loop + // A: Rn (start) + // B: Rn (end) + // C: Rn (step) + PREPARE_FORN, + + // Guards and checks (these instructions are not block terminators even though they jump to fallback) + + // Guard against tag mismatch + // A, B: tag + // C: block + // In final x64 lowering, A can also be Rn CHECK_TAG, + + // Guard against readonly table + // A: pointer (Table) + // B: block CHECK_READONLY, + + // Guard against table having a metatable + // A: pointer (Table) + // B: block CHECK_NO_METATABLE, + + // Guard against executing in unsafe environment + // A: block CHECK_SAFE_ENV, + + // Guard against index overflowing the table array size + // A: pointer (Table) + // B: block CHECK_ARRAY_SIZE, + + // Guard against cached table node slot not matching the actual table node slot for a key + // A: pointer (LuaNode) + // B: Kn + // C: block CHECK_SLOT_MATCH, // Special operations + + // Check interrupt handler + // A: unsigned int (pcpos) INTERRUPT, + + // Check and run GC assist if necessary CHECK_GC, + + // Handle GC write barrier (forward) + // A: pointer (GCObject) + // B: Rn (TValue that was written to the object) BARRIER_OBJ, + + // Handle GC write barrier (backwards) for a write into a table + // A: pointer (Table) BARRIER_TABLE_BACK, + + // Handle GC write barrier (forward) for a write into a table + // A: pointer (Table) + // B: Rn (TValue that was written to the object) BARRIER_TABLE_FORWARD, + + // Update savedpc value + // A: unsigned int (pcpos) SET_SAVEDPC, + + // Close open upvalues for registers at specified index or higher + // A: Rn (starting register index) CLOSE_UPVALS, // While capture is a no-op right now, it might be useful to track register/upvalue lifetimes + // A: Rn or UPn + // B: boolean (true for reference capture, false for value capture) CAPTURE, // Operations that don't have an IR representation yet + + // Set a list of values to table in target register + // A: unsigned int (bytecode instruction index) + // B: Rn (target) + // C: Rn (source start) + // D: int (count or -1 to assign values up to stack top) + // E: unsigned int (table index to start from) LOP_SETLIST, + + // Load function from source register using name into target register and copying source register into target register + 1 + // A: unsigned int (bytecode instruction index) + // B: Rn (target) + // C: Rn (source) + // D: block (next) + // E: block (fallback) LOP_NAMECALL, + + // Call specified function + // A: unsigned int (bytecode instruction index) + // B: Rn (function, followed by arguments) + // C: int (argument count or -1 to preserve all arguments up to stack top) + // D: int (result count or -1 to preserve all results and adjust stack top) + // Note: return values are placed starting from Rn specified in 'B' LOP_CALL, + + // Return specified values from the function + // A: unsigned int (bytecode instruction index) + // B: Rn (value start) + // B: int (result count or -1 to return all values up to stack top) LOP_RETURN, + + // Perform a fast call of a built-in function + // A: unsigned int (bytecode instruction index) + // B: Rn (argument start) + // C: int (argument count or -1 preserve all arguments up to stack top) + // D: block (fallback) + // Note: return values are placed starting from Rn specified in 'B' LOP_FASTCALL, + + // Perform a fast call of a built-in function using 1 register argument + // A: unsigned int (bytecode instruction index) + // B: Rn (result start) + // C: Rn (arg1) + // D: block (fallback) LOP_FASTCALL1, + + // Perform a fast call of a built-in function using 2 register arguments + // A: unsigned int (bytecode instruction index) + // B: Rn (result start) + // C: Rn (arg1) + // D: Rn (arg2) + // E: block (fallback) LOP_FASTCALL2, + + // Perform a fast call of a built-in function using 1 register argument and 1 constant argument + // A: unsigned int (bytecode instruction index) + // B: Rn (result start) + // C: Rn (arg1) + // D: Kn (arg2) + // E: block (fallback) LOP_FASTCALL2K, - LOP_FORNPREP, - LOP_FORNLOOP, + LOP_FORGLOOP, LOP_FORGLOOP_FALLBACK, - LOP_FORGPREP_NEXT, - LOP_FORGPREP_INEXT, LOP_FORGPREP_XNEXT_FALLBACK, + + // Perform `and` or `or` operation (selecting lhs or rhs based on whether the lhs is truthy) and put the result into target register + // A: unsigned int (bytecode instruction index) + // B: Rn (target) + // C: Rn (lhs) + // D: Rn or Kn (rhs) LOP_AND, LOP_ANDK, LOP_OR, LOP_ORK, + + // Increment coverage data (saturating 24 bit add) + // A: unsigned int (bytecode instruction index) LOP_COVERAGE, // Operations that have a translation, but use a full instruction fallback + + // Load a value from global table at specified key + // A: unsigned int (bytecode instruction index) + // B: Rn (dest) + // C: Kn (key) FALLBACK_GETGLOBAL, + + // Store a value into global table at specified key + // A: unsigned int (bytecode instruction index) + // B: Rn (value) + // C: Kn (key) FALLBACK_SETGLOBAL, + + // Load a value from table at specified key + // A: unsigned int (bytecode instruction index) + // B: Rn (dest) + // C: Rn (table) + // D: Kn (key) FALLBACK_GETTABLEKS, + + // Store a value into a table at specified key + // A: unsigned int (bytecode instruction index) + // B: Rn (value) + // C: Rn (table) + // D: Kn (key) FALLBACK_SETTABLEKS, + + // Load function from source register using name into target register and copying source register into target register + 1 + // A: unsigned int (bytecode instruction index) + // B: Rn (target) + // C: Rn (source) + // D: Kn (name) FALLBACK_NAMECALL, // Operations that don't have assembly lowering at all + + // Prepare stack for variadic functions so that GETVARARGS works correctly + // A: unsigned int (bytecode instruction index) + // B: int (numparams) FALLBACK_PREPVARARGS, + + // Copy variables into the target registers from vararg storage for current function + // A: unsigned int (bytecode instruction index) + // B: Rn (dest start) + // C: int (count) FALLBACK_GETVARARGS, + + // Create closure from a child proto + // A: unsigned int (bytecode instruction index) + // B: Rn (dest) + // C: unsigned int (protoid) FALLBACK_NEWCLOSURE, + + // Create closure from a pre-created function object (reusing it unless environments diverge) + // A: unsigned int (bytecode instruction index) + // B: Rn (dest) + // C: Kn (prototype) FALLBACK_DUPCLOSURE, + + // Prepare loop variables for a generic for loop, jump to the loop backedge unconditionally + // A: unsigned int (bytecode instruction index) + // B: Rn (loop state, updates Rn Rn+1 Rn+2) + // B: block FALLBACK_FORGPREP, }; @@ -251,15 +570,18 @@ enum class IrBlockKind : uint8_t Bytecode, Fallback, Internal, + Dead, }; struct IrBlock { IrBlockKind kind; + uint16_t useCount = 0; + // Start points to an instruction index in a stream // End is implicit - uint32_t start; + uint32_t start = ~0u; Label label; }; @@ -279,6 +601,64 @@ struct IrFunction std::vector bcMapping; Proto* proto = nullptr; + + IrBlock& blockOp(IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::Block); + return blocks[op.index]; + } + + IrInst& instOp(IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::Inst); + return instructions[op.index]; + } + + IrConst& constOp(IrOp op) + { + LUAU_ASSERT(op.kind == IrOpKind::Constant); + return constants[op.index]; + } + + uint8_t tagOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Tag); + return value.valueTag; + } + + bool boolOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Bool); + return value.valueBool; + } + + int intOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Int); + return value.valueInt; + } + + unsigned uintOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Uint); + return value.valueUint; + } + + double doubleOp(IrOp op) + { + IrConst& value = constOp(op); + + LUAU_ASSERT(value.kind == IrConstKind::Double); + return value.valueDouble; + } }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 2f44ea8..47a5f9e 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -21,12 +21,16 @@ struct IrToStringContext std::vector& constants; }; -void toString(IrToStringContext& ctx, IrInst inst, uint32_t index); +void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index); +void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index); // Block title void toString(IrToStringContext& ctx, IrOp op); void toString(std::string& result, IrConst constant); -void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index); +void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index); +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index); // Block title + +std::string toString(IrFunction& function, bool includeDetails); std::string dump(IrFunction& function); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 8438205..1aef9a3 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -92,19 +92,14 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_EQ_TAG: - case IrCmd::JUMP_EQ_BOOLEAN: + case IrCmd::JUMP_EQ_INT: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: - case IrCmd::JUMP_CMP_STR: case IrCmd::JUMP_CMP_ANY: case IrCmd::LOP_NAMECALL: case IrCmd::LOP_RETURN: - case IrCmd::LOP_FORNPREP: - case IrCmd::LOP_FORNLOOP: case IrCmd::LOP_FORGLOOP: case IrCmd::LOP_FORGLOOP_FALLBACK: - case IrCmd::LOP_FORGPREP_NEXT: - case IrCmd::LOP_FORGPREP_INEXT: case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: case IrCmd::FALLBACK_FORGPREP: return true; @@ -142,6 +137,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: case IrCmd::NUM_TO_INDEX: + case IrCmd::INT_TO_NUM: return true; default: break; @@ -157,5 +153,24 @@ inline bool hasSideEffects(IrCmd cmd) return !hasResult(cmd); } +// Remove a single instruction +void kill(IrFunction& function, IrInst& inst); + +// Remove a range of instructions +void kill(IrFunction& function, uint32_t start, uint32_t end); + +// Remove a block, including all instructions inside +void kill(IrFunction& function, IrBlock& block); + +void removeUse(IrFunction& function, IrInst& inst); +void removeUse(IrFunction& function, IrBlock& block); + +// Replace a single operand and update use counts (can cause chain removal of dead code) +void replace(IrFunction& function, IrOp& original, IrOp replacement); + +// Replace a single instruction +// Target instruction index instead of reference is used to handle introduction of a new block terminator +void replace(IrFunction& function, uint32_t instIdx, IrInst replacement); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/OptimizeFinalX64.h b/CodeGen/include/Luau/OptimizeFinalX64.h new file mode 100644 index 0000000..bc50dd7 --- /dev/null +++ b/CodeGen/include/Luau/OptimizeFinalX64.h @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/IrData.h" + +namespace Luau +{ +namespace CodeGen +{ + +void optimizeMemoryOperandsX64(IrFunction& function); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 72b2cbb..78f001f 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -7,6 +7,7 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/IrAnalysis.h" #include "Luau/IrBuilder.h" +#include "Luau/OptimizeFinalX64.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" #include "Luau/UnwindBuilderWin.h" @@ -431,7 +432,7 @@ static NativeProto* assembleFunction(AssemblyBuilderX64& build, NativeState& dat IrBuilder builder; builder.buildFunctionIr(proto); - updateUseInfo(builder.function); + optimizeMemoryOperandsX64(builder.function); IrLoweringX64 lowering(build, helpers, data, proto, builder.function); diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index e0dae66..7d36e17 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -195,13 +195,20 @@ static void callBarrierImpl(AssemblyBuilderX64& build, RegisterX64 tmp, Register if (object == rArg3) { LUAU_ASSERT(tmp != rArg2); - build.mov(rArg2, object); - build.mov(rArg3, tmp); + + if (rArg2 != object) + build.mov(rArg2, object); + + if (rArg3 != tmp) + build.mov(rArg3, tmp); } else { - build.mov(rArg3, tmp); - build.mov(rArg2, object); + if (rArg3 != tmp) + build.mov(rArg3, tmp); + + if (rArg2 != object) + build.mov(rArg2, object); } build.mov(rArg1, rState); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index b9f3953..a27d78a 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -11,31 +11,56 @@ namespace Luau namespace CodeGen { -static void recordUse(IrInst& inst, size_t index) +void updateUseCounts(IrFunction& function) { - LUAU_ASSERT(inst.useCount < 0xffff); + std::vector& blocks = function.blocks; + std::vector& instructions = function.instructions; - inst.useCount++; - inst.lastUse = uint32_t(index); + for (IrBlock& block : blocks) + block.useCount = 0; + + for (IrInst& inst : instructions) + inst.useCount = 0; + + auto checkOp = [&](IrOp op) { + if (op.kind == IrOpKind::Inst) + { + IrInst& target = instructions[op.index]; + LUAU_ASSERT(target.useCount < 0xffff); + target.useCount++; + } + else if (op.kind == IrOpKind::Block) + { + IrBlock& target = blocks[op.index]; + LUAU_ASSERT(target.useCount < 0xffff); + target.useCount++; + } + }; + + for (IrInst& inst : instructions) + { + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + } } -void updateUseInfo(IrFunction& function) +void updateLastUseLocations(IrFunction& function) { std::vector& instructions = function.instructions; for (IrInst& inst : instructions) - { - inst.useCount = 0; inst.lastUse = 0; - } - for (size_t i = 0; i < instructions.size(); ++i) + for (size_t instIdx = 0; instIdx < instructions.size(); ++instIdx) { - IrInst& inst = instructions[i]; + IrInst& inst = instructions[instIdx]; - auto checkOp = [&instructions, i](IrOp op) { + auto checkOp = [&](IrOp op) { if (op.kind == IrOpKind::Inst) - recordUse(instructions[op.index], i); + instructions[op.index].lastUse = uint32_t(instIdx); }; checkOp(inst.a); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 25f6d45..9c57310 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -2,6 +2,7 @@ #include "Luau/IrBuilder.h" #include "Luau/Common.h" +#include "Luau/IrAnalysis.h" #include "Luau/IrUtils.h" #include "CustomExecUtils.h" @@ -40,7 +41,9 @@ void IrBuilder::buildFunctionIr(Proto* proto) if (instIndexToBlock[i] != kNoAssociatedBlockIndex) beginBlock(blockAtInst(i)); - translateInst(op, pc, i); + // We skip dead bytecode instructions when they appear after block was already terminated + if (!inTerminatedBlock) + translateInst(op, pc, i); i = nexti; LUAU_ASSERT(i <= proto->sizecode); @@ -52,6 +55,9 @@ void IrBuilder::buildFunctionIr(Proto* proto) inst(IrCmd::JUMP, blockAtInst(i)); } } + + // Now that all has been generated, compute use counts + updateUseCounts(function); } void IrBuilder::rebuildBytecodeBasicBlocks(Proto* proto) @@ -120,7 +126,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstSetGlobal(*this, pc, i); break; case LOP_CALL: - inst(IrCmd::LOP_CALL, constUint(i)); + inst(IrCmd::LOP_CALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1), constInt(LUAU_INSN_C(*pc) - 1)); if (activeFastcallFallback) { @@ -132,7 +138,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } break; case LOP_RETURN: - inst(IrCmd::LOP_RETURN, constUint(i)); + inst(IrCmd::LOP_RETURN, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); break; case LOP_GETTABLE: translateInstGetTable(*this, pc, i); @@ -249,7 +255,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::LOP_SETLIST, constUint(i)); + inst(IrCmd::LOP_SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); @@ -262,10 +268,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) break; case LOP_FASTCALL: { - IrOp fallback = block(IrBlockKind::Fallback); - IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + int skip = LUAU_INSN_C(*pc); - inst(IrCmd::LOP_FASTCALL, constUint(i), fallback); + IrOp fallback = block(IrBlockKind::Fallback); + IrOp next = blockAtInst(i + skip + 2); + + Instruction call = pc[skip + 1]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + inst(IrCmd::LOP_FASTCALL, constUint(i), vmReg(LUAU_INSN_A(call)), constInt(LUAU_INSN_B(call) - 1), fallback); inst(IrCmd::JUMP, next); beginBlock(fallback); @@ -276,10 +287,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } case LOP_FASTCALL1: { - IrOp fallback = block(IrBlockKind::Fallback); - IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + int skip = LUAU_INSN_C(*pc); - inst(IrCmd::LOP_FASTCALL1, constUint(i), fallback); + IrOp fallback = block(IrBlockKind::Fallback); + IrOp next = blockAtInst(i + skip + 2); + + Instruction call = pc[skip + 1]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + inst(IrCmd::LOP_FASTCALL1, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), fallback); inst(IrCmd::JUMP, next); beginBlock(fallback); @@ -290,10 +306,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } case LOP_FASTCALL2: { - IrOp fallback = block(IrBlockKind::Fallback); - IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + int skip = LUAU_INSN_C(*pc); - inst(IrCmd::LOP_FASTCALL2, constUint(i), fallback); + IrOp fallback = block(IrBlockKind::Fallback); + IrOp next = blockAtInst(i + skip + 2); + + Instruction call = pc[skip + 1]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + inst(IrCmd::LOP_FASTCALL2, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), vmReg(pc[1]), fallback); inst(IrCmd::JUMP, next); beginBlock(fallback); @@ -304,10 +325,15 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) } case LOP_FASTCALL2K: { - IrOp fallback = block(IrBlockKind::Fallback); - IrOp next = blockAtInst(i + LUAU_INSN_C(*pc) + 2); + int skip = LUAU_INSN_C(*pc); - inst(IrCmd::LOP_FASTCALL2K, constUint(i), fallback); + IrOp fallback = block(IrBlockKind::Fallback); + IrOp next = blockAtInst(i + skip + 2); + + Instruction call = pc[skip + 1]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + inst(IrCmd::LOP_FASTCALL2K, constUint(i), vmReg(LUAU_INSN_A(call)), vmReg(LUAU_INSN_B(*pc)), vmConst(pc[1]), fallback); inst(IrCmd::JUMP, next); beginBlock(fallback); @@ -317,72 +343,50 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) break; } case LOP_FORNPREP: - { - IrOp loopStart = blockAtInst(i + getOpLength(LOP_FORNPREP)); - IrOp loopExit = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - - inst(IrCmd::LOP_FORNPREP, constUint(i), loopStart, loopExit); - - beginBlock(loopStart); + translateInstForNPrep(*this, pc, i); break; - } case LOP_FORNLOOP: - { - IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORNLOOP)); - - inst(IrCmd::LOP_FORNLOOP, constUint(i), loopRepeat, loopExit); - - beginBlock(loopExit); + translateInstForNLoop(*this, pc, i); break; - } case LOP_FORGLOOP: { - IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORGLOOP)); - IrOp fallback = block(IrBlockKind::Fallback); + // We have a translation for ipairs-style traversal, general loop iteration is still too complex + if (int(pc[1]) < 0) + { + translateInstForGLoopIpairs(*this, pc, i); + } + else + { + IrOp loopRepeat = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + IrOp loopExit = blockAtInst(i + getOpLength(LOP_FORGLOOP)); + IrOp fallback = block(IrBlockKind::Fallback); - inst(IrCmd::LOP_FORGLOOP, constUint(i), loopRepeat, loopExit, fallback); + inst(IrCmd::LOP_FORGLOOP, constUint(i), loopRepeat, loopExit, fallback); - beginBlock(fallback); - inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), loopRepeat, loopExit); + beginBlock(fallback); + inst(IrCmd::LOP_FORGLOOP_FALLBACK, constUint(i), loopRepeat, loopExit); - beginBlock(loopExit); + beginBlock(loopExit); + } break; } case LOP_FORGPREP_NEXT: - { - IrOp target = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - IrOp fallback = block(IrBlockKind::Fallback); - - inst(IrCmd::LOP_FORGPREP_NEXT, constUint(i), target, fallback); - - beginBlock(fallback); - inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, constUint(i), target); + translateInstForGPrepNext(*this, pc, i); break; - } case LOP_FORGPREP_INEXT: - { - IrOp target = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); - IrOp fallback = block(IrBlockKind::Fallback); - - inst(IrCmd::LOP_FORGPREP_INEXT, constUint(i), target, fallback); - - beginBlock(fallback); - inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, constUint(i), target); + translateInstForGPrepInext(*this, pc, i); break; - } case LOP_AND: - inst(IrCmd::LOP_AND, constUint(i)); + inst(IrCmd::LOP_AND, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); break; case LOP_ANDK: - inst(IrCmd::LOP_ANDK, constUint(i)); + inst(IrCmd::LOP_ANDK, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); break; case LOP_OR: - inst(IrCmd::LOP_OR, constUint(i)); + inst(IrCmd::LOP_OR, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmReg(LUAU_INSN_C(*pc))); break; case LOP_ORK: - inst(IrCmd::LOP_ORK, constUint(i)); + inst(IrCmd::LOP_ORK, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(LUAU_INSN_C(*pc))); break; case LOP_COVERAGE: inst(IrCmd::LOP_COVERAGE, constUint(i)); @@ -401,30 +405,34 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) IrOp next = blockAtInst(i + getOpLength(LOP_NAMECALL)); IrOp fallback = block(IrBlockKind::Fallback); - inst(IrCmd::LOP_NAMECALL, constUint(i), next, fallback); + inst(IrCmd::LOP_NAMECALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), next, fallback); beginBlock(fallback); - inst(IrCmd::FALLBACK_NAMECALL, constUint(i)); + inst(IrCmd::FALLBACK_NAMECALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(pc[1])); inst(IrCmd::JUMP, next); beginBlock(next); break; } case LOP_PREPVARARGS: - inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i)); + inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc))); break; case LOP_GETVARARGS: - inst(IrCmd::FALLBACK_GETVARARGS, constUint(i)); + inst(IrCmd::FALLBACK_GETVARARGS, constUint(i), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_B(*pc) - 1)); break; case LOP_NEWCLOSURE: - inst(IrCmd::FALLBACK_NEWCLOSURE, constUint(i)); + inst(IrCmd::FALLBACK_NEWCLOSURE, constUint(i), vmReg(LUAU_INSN_A(*pc)), constUint(LUAU_INSN_D(*pc))); break; case LOP_DUPCLOSURE: - inst(IrCmd::FALLBACK_DUPCLOSURE, constUint(i)); + inst(IrCmd::FALLBACK_DUPCLOSURE, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmConst(LUAU_INSN_D(*pc))); break; case LOP_FORGPREP: - inst(IrCmd::FALLBACK_FORGPREP, constUint(i)); + { + IrOp loopStart = blockAtInst(i + 1 + LUAU_INSN_D(*pc)); + + inst(IrCmd::FALLBACK_FORGPREP, constUint(i), vmReg(LUAU_INSN_A(*pc)), loopStart); break; + } default: LUAU_ASSERT(!"unknown instruction"); break; @@ -445,6 +453,8 @@ void IrBuilder::beginBlock(IrOp block) LUAU_ASSERT(target.start == ~0u || target.start == uint32_t(function.instructions.size())); target.start = uint32_t(function.instructions.size()); + + inTerminatedBlock = false; } IrOp IrBuilder::constBool(bool value) @@ -528,6 +538,10 @@ IrOp IrBuilder::inst(IrCmd cmd, IrOp a, IrOp b, IrOp c, IrOp d, IrOp e) { uint32_t index = uint32_t(function.instructions.size()); function.instructions.push_back({cmd, a, b, c, d, e}); + + if (isBlockTerminator(cmd)) + inTerminatedBlock = true; + return {IrOpKind::Inst, index}; } @@ -537,7 +551,7 @@ IrOp IrBuilder::block(IrBlockKind kind) kind = IrBlockKind::Fallback; uint32_t index = uint32_t(function.blocks.size()); - function.blocks.push_back(IrBlock{kind, ~0u}); + function.blocks.push_back(IrBlock{kind}); return IrOp{IrOpKind::Block, index}; } diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index eb5a074..5a23861 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -29,6 +29,14 @@ static void append(std::string& result, const char* fmt, ...) result.append(buf); } +static void padToDetailColumn(std::string& result, size_t lineStart) +{ + int pad = kDetailsAlignColumn - int(result.size() - lineStart); + + if (pad > 0) + result.append(pad, ' '); +} + static const char* getTagName(uint8_t tag) { switch (tag) @@ -122,14 +130,12 @@ const char* getCmdName(IrCmd cmd) return "JUMP_IF_FALSY"; case IrCmd::JUMP_EQ_TAG: return "JUMP_EQ_TAG"; - case IrCmd::JUMP_EQ_BOOLEAN: - return "JUMP_EQ_BOOLEAN"; + case IrCmd::JUMP_EQ_INT: + return "JUMP_EQ_INT"; case IrCmd::JUMP_EQ_POINTER: return "JUMP_EQ_POINTER"; case IrCmd::JUMP_CMP_NUM: return "JUMP_CMP_NUM"; - case IrCmd::JUMP_CMP_STR: - return "JUMP_CMP_STR"; case IrCmd::JUMP_CMP_ANY: return "JUMP_CMP_ANY"; case IrCmd::TABLE_LEN: @@ -140,6 +146,8 @@ const char* getCmdName(IrCmd cmd) return "DUP_TABLE"; case IrCmd::NUM_TO_INDEX: return "NUM_TO_INDEX"; + case IrCmd::INT_TO_NUM: + return "INT_TO_NUM"; case IrCmd::DO_ARITH: return "DO_ARITH"; case IrCmd::DO_LEN: @@ -156,6 +164,8 @@ const char* getCmdName(IrCmd cmd) return "GET_UPVALUE"; case IrCmd::SET_UPVALUE: return "SET_UPVALUE"; + case IrCmd::PREPARE_FORN: + return "PREPARE_FORN"; case IrCmd::CHECK_TAG: return "CHECK_TAG"; case IrCmd::CHECK_READONLY: @@ -200,18 +210,10 @@ const char* getCmdName(IrCmd cmd) return "LOP_FASTCALL2"; case IrCmd::LOP_FASTCALL2K: return "LOP_FASTCALL2K"; - case IrCmd::LOP_FORNPREP: - return "LOP_FORNPREP"; - case IrCmd::LOP_FORNLOOP: - return "LOP_FORNLOOP"; case IrCmd::LOP_FORGLOOP: return "LOP_FORGLOOP"; case IrCmd::LOP_FORGLOOP_FALLBACK: return "LOP_FORGLOOP_FALLBACK"; - case IrCmd::LOP_FORGPREP_NEXT: - return "LOP_FORGPREP_NEXT"; - case IrCmd::LOP_FORGPREP_INEXT: - return "LOP_FORGPREP_INEXT"; case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: return "LOP_FORGPREP_XNEXT_FALLBACK"; case IrCmd::LOP_AND: @@ -259,12 +261,14 @@ const char* getBlockKindName(IrBlockKind kind) return "bb_fallback"; case IrBlockKind::Internal: return "bb"; + case IrBlockKind::Dead: + return "dead"; } LUAU_UNREACHABLE(); } -void toString(IrToStringContext& ctx, IrInst inst, uint32_t index) +void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index) { append(ctx.result, " "); @@ -305,6 +309,11 @@ void toString(IrToStringContext& ctx, IrInst inst, uint32_t index) } } +void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) +{ + append(ctx.result, "%s_%u:", getBlockKindName(block.kind), index); +} + void toString(IrToStringContext& ctx, IrOp op) { switch (op.kind) @@ -358,18 +367,12 @@ void toString(std::string& result, IrConst constant) } } -void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index) +void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index) { size_t start = ctx.result.size(); toString(ctx, inst, index); - - int pad = kDetailsAlignColumn - int(ctx.result.size() - start); - - if (pad > 0) - ctx.result.append(pad, ' '); - - LUAU_ASSERT(inst.useCount == 0 || inst.lastUse != 0); + padToDetailColumn(ctx.result, start); if (inst.useCount == 0 && hasSideEffects(inst.cmd)) append(ctx.result, "; %%%u, has side-effects\n", index); @@ -377,7 +380,17 @@ void toStringDetailed(IrToStringContext& ctx, IrInst inst, uint32_t index) append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); } -std::string dump(IrFunction& function) +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index) +{ + size_t start = ctx.result.size(); + + toString(ctx, block, index); + padToDetailColumn(ctx.result, start); + + append(ctx.result, "; useCount: %d\n", block.useCount); +} + +std::string toString(IrFunction& function, bool includeDetails) { std::string result; IrToStringContext ctx{result, function.blocks, function.constants}; @@ -386,7 +399,18 @@ std::string dump(IrFunction& function) { IrBlock& block = function.blocks[i]; - append(ctx.result, "%s_%u:\n", getBlockKindName(block.kind), unsigned(i)); + if (block.kind == IrBlockKind::Dead) + continue; + + if (includeDetails) + { + toStringDetailed(ctx, block, uint32_t(i)); + } + else + { + toString(ctx, block, uint32_t(i)); + ctx.result.append("\n"); + } if (block.start == ~0u) { @@ -394,10 +418,9 @@ std::string dump(IrFunction& function) continue; } - for (uint32_t index = block.start; true; index++) + // To allow dumping blocks that are still being constructed, we can't rely on terminator and need a bounds check + for (uint32_t index = block.start; index < uint32_t(function.instructions.size()); index++) { - LUAU_ASSERT(index < function.instructions.size()); - IrInst& inst = function.instructions[index]; // Nop is used to replace dead instructions in-place, so it's not that useful to see them @@ -405,7 +428,16 @@ std::string dump(IrFunction& function) continue; append(ctx.result, " "); - toStringDetailed(ctx, inst, index); + + if (includeDetails) + { + toStringDetailed(ctx, inst, index); + } + else + { + toString(ctx, inst, index); + ctx.result.append("\n"); + } if (isBlockTerminator(inst.cmd)) { @@ -415,6 +447,13 @@ std::string dump(IrFunction& function) } } + return result; +} + +std::string dump(IrFunction& function) +{ + std::string result = toString(function, /* includeDetails */ true); + printf("%s\n", result.c_str()); return result; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 62b5dea..03bb181 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -3,6 +3,7 @@ #include "Luau/CodeGen.h" #include "Luau/DenseHash.h" +#include "Luau/IrAnalysis.h" #include "Luau/IrDump.h" #include "Luau/IrUtils.h" @@ -30,6 +31,9 @@ IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, { freeGprMap.fill(true); freeXmmMap.fill(true); + + // In order to allocate registers during lowering, we need to know where instruction results are last used + updateLastUseLocations(function); } void IrLoweringX64::lower(AssemblyOptions options) @@ -93,6 +97,9 @@ void IrLoweringX64::lower(AssemblyOptions options) IrBlock& block = function.blocks[blockIndex]; LUAU_ASSERT(block.start != ~0u); + if (block.kind == IrBlockKind::Dead) + continue; + // If we want to skip fallback code IR/asm, we'll record when those blocks start once we see them if (block.kind == IrBlockKind::Fallback && !seenFallback) { @@ -102,7 +109,10 @@ void IrLoweringX64::lower(AssemblyOptions options) } if (options.includeIr) - build.logAppend("# %s_%u:\n", getBlockKindName(block.kind), blockIndex); + { + build.logAppend("# "); + toStringDetailed(ctx, block, uint32_t(i)); + } build.setLabel(block.label); @@ -179,6 +189,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(inst.regX64, luauRegTag(inst.a.index)); else if (inst.a.kind == IrOpKind::VmConst) build.mov(inst.regX64, luauConstantTag(inst.a.index)); + // If we have a register, we assume it's a pointer to TValue + // We might introduce explicit operand types in the future to make this more robust + else if (inst.a.kind == IrOpKind::Inst) + build.mov(inst.regX64, dword[regOp(inst.a) + offsetof(TValue, tt)]); else LUAU_ASSERT(!"Unsupported instruction form"); break; @@ -237,7 +251,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regX64 = allocGprRegOrReuse(SizeX64::qword, index, {inst.b}); - build.mov(dwordReg(inst.regX64), regOp(inst.b)); + if (dwordReg(inst.regX64) != regOp(inst.b)) + build.mov(dwordReg(inst.regX64), regOp(inst.b)); + build.shl(dwordReg(inst.regX64), kTValueSizeLog2); build.add(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]); } @@ -442,7 +458,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } else { - LUAU_ASSERT(!"Unsupported instruction form"); + if (lhs != xmm0) + build.vmovsd(xmm0, lhs, lhs); + + build.vmovsd(xmm1, memRegDoubleOp(inst.b)); + build.call(qword[rNativeContext + offsetof(NativeContext, libm_pow)]); + + if (inst.regX64 != xmm0) + build.vmovsd(inst.regX64, xmm0, xmm0); } break; @@ -525,8 +548,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; } - case IrCmd::JUMP_EQ_BOOLEAN: - build.cmp(regOp(inst.a), boolOp(inst.b) ? 1 : 0); + case IrCmd::JUMP_EQ_INT: + build.cmp(regOp(inst.a), intOp(inst.b)); build.jcc(ConditionX64::Equal, labelOp(inst.c)); jumpOrFallthrough(blockOp(inst.d), next); @@ -576,7 +599,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(dwordReg(rArg2), uintOp(inst.a)); build.mov(dwordReg(rArg3), uintOp(inst.b)); build.call(qword[rNativeContext + offsetof(NativeContext, luaH_new)]); - build.mov(inst.regX64, rax); + + if (inst.regX64 != rax) + build.mov(inst.regX64, rax); break; case IrCmd::DUP_TABLE: inst.regX64 = allocGprReg(SizeX64::qword); @@ -585,7 +610,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(rArg2, regOp(inst.a)); build.mov(rArg1, rState); build.call(qword[rNativeContext + offsetof(NativeContext, luaH_clone)]); - build.mov(inst.regX64, rax); + + if (inst.regX64 != rax) + build.mov(inst.regX64, rax); break; case IrCmd::NUM_TO_INDEX: { @@ -596,6 +623,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) convertNumberToIndexOrJump(build, tmp.reg, regOp(inst.a), inst.regX64, labelOp(inst.b)); break; } + case IrCmd::INT_TO_NUM: + inst.regX64 = allocXmmReg(); + + build.vcvtsi2sd(inst.regX64, inst.regX64, regOp(inst.a)); + break; case IrCmd::DO_ARITH: LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); @@ -711,6 +743,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(next); break; } + case IrCmd::PREPARE_FORN: + callPrepareForN(build, inst.a.index, inst.b.index, inst.c.index); + break; case IrCmd::CHECK_TAG: if (inst.a.kind == IrOpKind::Inst) { @@ -828,7 +863,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(tmp2.reg, qword[tmp1.reg + offsetof(UpVal, v)]); build.jcc(ConditionX64::Above, next); - build.mov(rArg2, tmp2.reg); + if (rArg2 != tmp2.reg) + build.mov(rArg2, tmp2.reg); + build.mov(rArg1, rState); build.call(qword[rNativeContext + offsetof(NativeContext, luaF_close)]); @@ -843,6 +880,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::LOP_SETLIST: { const Instruction* pc = proto->code + uintOp(inst.a); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::Constant); + LUAU_ASSERT(inst.e.kind == IrOpKind::Constant); Label next; emitInstSetList(build, pc, next); @@ -852,13 +893,18 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::LOP_NAMECALL: { const Instruction* pc = proto->code + uintOp(inst.a); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - emitInstNameCall(build, pc, uintOp(inst.a), proto->k, blockOp(inst.b).label, blockOp(inst.c).label); + emitInstNameCall(build, pc, uintOp(inst.a), proto->k, blockOp(inst.d).label, blockOp(inst.e).label); break; } case IrCmd::LOP_CALL: { const Instruction* pc = proto->code + uintOp(inst.a); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + LUAU_ASSERT(inst.d.kind == IrOpKind::Constant); emitInstCall(build, helpers, pc, uintOp(inst.a)); break; @@ -866,27 +912,37 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::LOP_RETURN: { const Instruction* pc = proto->code + uintOp(inst.a); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); emitInstReturn(build, helpers, pc, uintOp(inst.a)); break; } case IrCmd::LOP_FASTCALL: - emitInstFastCall(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + + emitInstFastCall(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.d)); break; case IrCmd::LOP_FASTCALL1: - emitInstFastCall1(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + + emitInstFastCall1(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.d)); break; case IrCmd::LOP_FASTCALL2: - emitInstFastCall2(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg); + + emitInstFastCall2(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.e)); break; case IrCmd::LOP_FASTCALL2K: - emitInstFastCall2K(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); - break; - case IrCmd::LOP_FORNPREP: - emitInstForNPrep(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b), labelOp(inst.c)); - break; - case IrCmd::LOP_FORNLOOP: - emitInstForNLoop(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b), labelOp(inst.c)); + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + + emitInstFastCall2K(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.e)); break; case IrCmd::LOP_FORGLOOP: emitinstForGLoop(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b), labelOp(inst.c), labelOp(inst.d)); @@ -895,12 +951,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitinstForGLoopFallback(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); build.jmp(labelOp(inst.c)); break; - case IrCmd::LOP_FORGPREP_NEXT: - emitInstForGPrepNext(build, proto->code + uintOp(inst.a), labelOp(inst.b), labelOp(inst.c)); - break; - case IrCmd::LOP_FORGPREP_INEXT: - emitInstForGPrepInext(build, proto->code + uintOp(inst.a), labelOp(inst.b), labelOp(inst.c)); - break; case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: emitInstForGPrepXnextFallback(build, proto->code + uintOp(inst.a), uintOp(inst.a), labelOp(inst.b)); break; @@ -922,30 +972,59 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // Full instruction fallbacks case IrCmd::FALLBACK_GETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_GETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETGLOBAL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_SETGLOBAL, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_GETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_SETTABLEKS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_SETTABLEKS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NAMECALL: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_NAMECALL, uintOp(inst.a)); break; case IrCmd::FALLBACK_PREPVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::Constant); + emitFallback(build, data, LOP_PREPVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_GETVARARGS: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + emitFallback(build, data, LOP_GETVARARGS, uintOp(inst.a)); break; case IrCmd::FALLBACK_NEWCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::Constant); + emitFallback(build, data, LOP_NEWCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_DUPCLOSURE: + LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind == IrOpKind::VmConst); + emitFallback(build, data, LOP_DUPCLOSURE, uintOp(inst.a)); break; case IrCmd::FALLBACK_FORGPREP: @@ -1006,60 +1085,42 @@ OperandX64 IrLoweringX64::memRegTagOp(IrOp op) const RegisterX64 IrLoweringX64::regOp(IrOp op) const { - LUAU_ASSERT(op.kind == IrOpKind::Inst); - return function.instructions[op.index].regX64; + return function.instOp(op).regX64; } IrConst IrLoweringX64::constOp(IrOp op) const { - LUAU_ASSERT(op.kind == IrOpKind::Constant); - return function.constants[op.index]; + return function.constOp(op); } uint8_t IrLoweringX64::tagOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Tag); - return value.valueTag; + return function.tagOp(op); } bool IrLoweringX64::boolOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Bool); - return value.valueBool; + return function.boolOp(op); } int IrLoweringX64::intOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Int); - return value.valueInt; + return function.intOp(op); } unsigned IrLoweringX64::uintOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Uint); - return value.valueUint; + return function.uintOp(op); } double IrLoweringX64::doubleOp(IrOp op) const { - IrConst value = constOp(op); - - LUAU_ASSERT(value.kind == IrConstKind::Double); - return value.valueDouble; + return function.doubleOp(op); } IrBlock& IrLoweringX64::blockOp(IrOp op) const { - LUAU_ASSERT(op.kind == IrOpKind::Block); - return function.blocks[op.index]; + return function.blockOp(op); } Label& IrLoweringX64::labelOp(IrOp op) const @@ -1162,7 +1223,9 @@ void IrLoweringX64::freeLastUseReg(IrInst& target, uint32_t index) { if (target.lastUse == index && !target.reusedReg) { - LUAU_ASSERT(target.regX64 != noreg); + // Register might have already been freed if it had multiple uses inside a single instruction + if (target.regX64 == noreg) + return; freeReg(target.regX64); target.regX64 = noreg; diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 32958b5..fdbdf66 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -3,6 +3,9 @@ #include "Luau/Bytecode.h" #include "Luau/IrBuilder.h" +#include "Luau/IrUtils.h" + +#include "CustomExecUtils.h" #include "lobject.h" #include "ltm.h" @@ -215,7 +218,7 @@ void translateInstJumpxEqB(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(checkValue); IrOp va = build.inst(IrCmd::LOAD_INT, build.vmReg(ra)); - build.inst(IrCmd::JUMP_EQ_BOOLEAN, va, build.constBool(aux & 0x1), not_ ? next : target, not_ ? target : next); + build.inst(IrCmd::JUMP_EQ_INT, va, build.constInt(aux & 0x1), not_ ? next : target, not_ ? target : next); // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(next)) @@ -238,7 +241,12 @@ void translateInstJumpxEqN(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(checkValue); IrOp va = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra)); - IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmConst(aux & 0xffffff)); + + LUAU_ASSERT(build.function.proto); + TValue protok = build.function.proto->k[aux & 0xffffff]; + + LUAU_ASSERT(protok.tt == LUA_TNUMBER); + IrOp vb = build.constDouble(protok.value.n); build.inst(IrCmd::JUMP_CMP_NUM, va, vb, build.cond(IrCondition::NotEqual), not_ ? target : next, not_ ? next : target); @@ -286,7 +294,20 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, } IrOp vb = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rb)); - IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, opc); + IrOp vc; + + if (opc.kind == IrOpKind::VmConst) + { + LUAU_ASSERT(build.function.proto); + TValue protok = build.function.proto->k[opc.index]; + + LUAU_ASSERT(protok.tt == LUA_TNUMBER); + vc = build.constDouble(protok.value.n); + } + else + { + vc = build.inst(IrCmd::LOAD_DOUBLE, opc); + } IrOp va; @@ -458,6 +479,209 @@ void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) build.inst(IrCmd::CLOSE_UPVALS, build.vmReg(ra)); } +void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); + IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos)); + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp nextStep = build.block(IrBlockKind::Internal); + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); + build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), fallback); + IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), fallback); + IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), fallback); + build.inst(IrCmd::JUMP, nextStep); + + // After successful conversion of arguments to number in a fallback, we return here + build.beginBlock(nextStep); + + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + + // step <= 0 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + + // TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation + + // step <= 0 is false, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopStart, loopExit); + + // step <= 0 is true, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); + + // Fallback will try to convert loop variables to numbers or throw an error + build.beginBlock(fallback); + build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); + build.inst(IrCmd::PREPARE_FORN, build.vmReg(ra + 0), build.vmReg(ra + 1), build.vmReg(ra + 2)); + build.inst(IrCmd::JUMP, nextStep); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(loopStart)) + build.beginBlock(loopStart); +} + +void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos)); + IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); + + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + idx = build.inst(IrCmd::ADD_NUM, idx, step); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); + + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + // step <= 0 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + + // step <= 0 is false, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // step <= 0 is true, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(loopExit)) + build.beginBlock(loopExit); +} + +void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp fallback = build.block(IrBlockKind::Fallback); + + // fast-path: pairs/next + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); + IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagC, build.constTag(LUA_TNIL), fallback); + + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNIL)); + + // setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(ra + 2), build.constInt(0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 2), build.constTag(LUA_TLIGHTUSERDATA)); + + build.inst(IrCmd::JUMP, target); + + // FallbackStreamScope not used here because this instruction doesn't fallthrough to next instruction + build.beginBlock(fallback); + build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target); +} + +void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + + IrOp target = build.blockAtInst(pcpos + 1 + LUAU_INSN_D(*pc)); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp finish = build.block(IrBlockKind::Internal); + + // fast-path: ipairs/inext + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + IrOp tagB = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagB, build.constTag(LUA_TTABLE), fallback); + IrOp tagC = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagC, build.constTag(LUA_TNUMBER), fallback); + + IrOp numC = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + build.inst(IrCmd::JUMP_CMP_NUM, numC, build.constDouble(0.0), build.cond(IrCondition::NotEqual), fallback, finish); + + build.beginBlock(finish); + + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNIL)); + + // setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + build.inst(IrCmd::STORE_INT, build.vmReg(ra + 2), build.constInt(0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 2), build.constTag(LUA_TLIGHTUSERDATA)); + + build.inst(IrCmd::JUMP, target); + + // FallbackStreamScope not used here because this instruction doesn't fallthrough to next instruction + build.beginBlock(fallback); + build.inst(IrCmd::LOP_FORGPREP_XNEXT_FALLBACK, build.constUint(pcpos), target); +} + +void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + LUAU_ASSERT(int(pc[1]) < 0); + + IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos)); + IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); + IrOp fallback = build.block(IrBlockKind::Fallback); + + IrOp hasElem = build.block(IrBlockKind::Internal); + + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + + // fast-path: builtin table iteration + IrOp tagA = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra)); + build.inst(IrCmd::CHECK_TAG, tagA, build.constTag(LUA_TNIL), fallback); + + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(ra + 1)); + IrOp index = build.inst(IrCmd::LOAD_INT, build.vmReg(ra + 2)); + + IrOp elemPtr = build.inst(IrCmd::GET_ARR_ADDR, table, index); + + // Terminate if array has ended + build.inst(IrCmd::CHECK_ARRAY_SIZE, table, index, loopExit); + + // Terminate if element is nil + IrOp elemTag = build.inst(IrCmd::LOAD_TAG, elemPtr); + build.inst(IrCmd::JUMP_EQ_TAG, elemTag, build.constTag(LUA_TNIL), loopExit, hasElem); + build.beginBlock(hasElem); + + IrOp nextIndex = build.inst(IrCmd::ADD_INT, index, build.constInt(1)); + + // We update only a dword part of the userdata pointer that's reused in loop iteration as an index + // Upper bits start and remain to be 0 + build.inst(IrCmd::STORE_INT, build.vmReg(ra + 2), nextIndex); + // Tag should already be set to lightuserdata + + // setnvalue(ra + 3, double(index + 1)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 3), build.inst(IrCmd::INT_TO_NUM, nextIndex)); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 3), build.constTag(LUA_TNUMBER)); + + // setobj2s(L, ra + 4, e); + IrOp elemTV = build.inst(IrCmd::LOAD_TVALUE, elemPtr); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra + 4), elemTV); + + build.inst(IrCmd::JUMP, loopRepeat); + + build.beginBlock(fallback); + build.inst(IrCmd::LOP_FORGLOOP_FALLBACK, build.constUint(pcpos), loopRepeat, loopExit); + + // Fallthrough in original bytecode is implicit, so we start next internal block here + if (build.isInternalBlock(loopExit)) + build.beginBlock(loopExit); +} + void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); @@ -654,7 +878,7 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); - build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos)); + build.inst(IrCmd::FALLBACK_GETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); build.inst(IrCmd::JUMP, next); } @@ -685,7 +909,7 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); - build.inst(IrCmd::FALLBACK_SETTABLEKS, build.constUint(pcpos)); + build.inst(IrCmd::FALLBACK_SETTABLEKS, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); build.inst(IrCmd::JUMP, next); } @@ -708,7 +932,7 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); - build.inst(IrCmd::FALLBACK_GETGLOBAL, build.constUint(pcpos)); + build.inst(IrCmd::FALLBACK_GETGLOBAL, build.constUint(pcpos), build.vmReg(ra), build.vmConst(aux)); build.inst(IrCmd::JUMP, next); } @@ -734,7 +958,7 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) IrOp next = build.blockAtInst(pcpos + 2); FallbackStreamScope scope(build, fallback, next); - build.inst(IrCmd::FALLBACK_SETGLOBAL, build.constUint(pcpos)); + build.inst(IrCmd::FALLBACK_SETGLOBAL, build.constUint(pcpos), build.vmReg(ra), build.vmConst(aux)); build.inst(IrCmd::JUMP, next); } diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 53030a2..6ffc911 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -42,6 +42,11 @@ void translateInstDupTable(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc); +void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstForGPrepNext(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstForGPrepInext(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstForGLoopIpairs(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetTableN(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetTableN(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp new file mode 100644 index 0000000..0c1a896 --- /dev/null +++ b/CodeGen/src/IrUtils.cpp @@ -0,0 +1,133 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrUtils.h" + +namespace Luau +{ +namespace CodeGen +{ + +static uint32_t getBlockEnd(IrFunction& function, uint32_t start) +{ + uint32_t end = start; + + // Find previous block terminator + while (!isBlockTerminator(function.instructions[end].cmd)) + end++; + + return end; +} + +static void addUse(IrFunction& function, IrOp op) +{ + if (op.kind == IrOpKind::Inst) + function.instructions[op.index].useCount++; + else if (op.kind == IrOpKind::Block) + function.blocks[op.index].useCount++; +} + +static void removeUse(IrFunction& function, IrOp op) +{ + if (op.kind == IrOpKind::Inst) + removeUse(function, function.instructions[op.index]); + else if (op.kind == IrOpKind::Block) + removeUse(function, function.blocks[op.index]); +} + +void kill(IrFunction& function, IrInst& inst) +{ + LUAU_ASSERT(inst.useCount == 0); + + inst.cmd = IrCmd::NOP; + + removeUse(function, inst.a); + removeUse(function, inst.b); + removeUse(function, inst.c); + removeUse(function, inst.d); + removeUse(function, inst.e); +} + +void kill(IrFunction& function, uint32_t start, uint32_t end) +{ + // Kill instructions in reverse order to avoid killing instructions that are still marked as used + for (int i = int(end); i >= int(start); i--) + { + IrInst& curr = function.instructions[i]; + + if (curr.cmd == IrCmd::NOP) + continue; + + kill(function, curr); + } +} + +void kill(IrFunction& function, IrBlock& block) +{ + LUAU_ASSERT(block.useCount == 0); + + block.kind = IrBlockKind::Dead; + + uint32_t start = block.start; + uint32_t end = getBlockEnd(function, start); + + kill(function, start, end); +} + +void removeUse(IrFunction& function, IrInst& inst) +{ + LUAU_ASSERT(inst.useCount); + inst.useCount--; + + if (inst.useCount == 0) + kill(function, inst); +} + +void removeUse(IrFunction& function, IrBlock& block) +{ + LUAU_ASSERT(block.useCount); + block.useCount--; + + if (block.useCount == 0) + kill(function, block); +} + +void replace(IrFunction& function, IrOp& original, IrOp replacement) +{ + // Add use before removing new one if that's the last one keeping target operand alive + addUse(function, replacement); + removeUse(function, original); + + original = replacement; +} + +void replace(IrFunction& function, uint32_t instIdx, IrInst replacement) +{ + IrInst& inst = function.instructions[instIdx]; + IrCmd prevCmd = inst.cmd; + + // Add uses before removing new ones if those are the last ones keeping target operand alive + addUse(function, replacement.a); + addUse(function, replacement.b); + addUse(function, replacement.c); + addUse(function, replacement.d); + addUse(function, replacement.e); + + removeUse(function, inst.a); + removeUse(function, inst.b); + removeUse(function, inst.c); + removeUse(function, inst.d); + removeUse(function, inst.e); + + inst = replacement; + + // If we introduced an earlier terminating instruction, all following instructions become dead + if (!isBlockTerminator(prevCmd) && isBlockTerminator(inst.cmd)) + { + uint32_t start = instIdx + 1; + uint32_t end = getBlockEnd(function, start); + + kill(function, start, end); + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp new file mode 100644 index 0000000..57f9a5c --- /dev/null +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -0,0 +1,111 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/OptimizeFinalX64.h" + +#include "Luau/IrUtils.h" + +#include + +namespace Luau +{ +namespace CodeGen +{ + +// x64 assembly allows memory operands, but IR separates loads from uses +// To improve final x64 lowering, we try to 'inline' single-use register/constant loads into some of our instructions +// This pass might not be useful on different architectures +static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) +{ + LUAU_ASSERT(block.kind != IrBlockKind::Dead); + + for (uint32_t index = block.start; true; index++) + { + LUAU_ASSERT(index < function.instructions.size()); + IrInst& inst = function.instructions[index]; + + switch (inst.cmd) + { + case IrCmd::CHECK_TAG: + { + if (inst.a.kind == IrOpKind::Inst) + { + IrInst& tag = function.instOp(inst.a); + + if (tag.useCount == 1 && tag.cmd == IrCmd::LOAD_TAG && (tag.a.kind == IrOpKind::VmReg || tag.a.kind == IrOpKind::VmConst)) + replace(function, inst.a, tag.a); + } + break; + } + case IrCmd::ADD_NUM: + case IrCmd::SUB_NUM: + case IrCmd::MUL_NUM: + case IrCmd::DIV_NUM: + case IrCmd::MOD_NUM: + case IrCmd::POW_NUM: + { + if (inst.b.kind == IrOpKind::Inst) + { + IrInst& rhs = function.instOp(inst.b); + + if (rhs.useCount == 1 && rhs.cmd == IrCmd::LOAD_DOUBLE && (rhs.a.kind == IrOpKind::VmReg || rhs.a.kind == IrOpKind::VmConst)) + replace(function, inst.b, rhs.a); + } + break; + } + case IrCmd::JUMP_EQ_TAG: + { + if (inst.a.kind == IrOpKind::Inst) + { + IrInst& tagA = function.instOp(inst.a); + + if (tagA.useCount == 1 && tagA.cmd == IrCmd::LOAD_TAG && (tagA.a.kind == IrOpKind::VmReg || tagA.a.kind == IrOpKind::VmConst)) + { + replace(function, inst.a, tagA.a); + break; + } + } + + if (inst.b.kind == IrOpKind::Inst) + { + IrInst& tagB = function.instOp(inst.b); + + if (tagB.useCount == 1 && tagB.cmd == IrCmd::LOAD_TAG && (tagB.a.kind == IrOpKind::VmReg || tagB.a.kind == IrOpKind::VmConst)) + { + std::swap(inst.a, inst.b); + replace(function, inst.a, tagB.a); + } + } + break; + } + case IrCmd::JUMP_CMP_NUM: + { + if (inst.a.kind == IrOpKind::Inst) + { + IrInst& num = function.instOp(inst.a); + + if (num.useCount == 1 && num.cmd == IrCmd::LOAD_DOUBLE) + replace(function, inst.a, num.a); + } + break; + } + default: + break; + } + + if (isBlockTerminator(inst.cmd)) + break; + } +} + +void optimizeMemoryOperandsX64(IrFunction& function) +{ + for (IrBlock& block : function.blocks) + { + if (block.kind == IrBlockKind::Dead) + continue; + + optimizeMemoryOperandsX64(function, block); + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 94dce41..a14cc1e 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,6 +13,7 @@ inline bool isFlagExperimental(const char* flag) static const char* kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauTryhardAnd", // waiting for a fix in graphql-lua -> apollo-client-lia -> lua-apps + "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) // makes sure we always have at least one entry nullptr, }; diff --git a/Sources.cmake b/Sources.cmake index 815301b..aef55e6 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -70,6 +70,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/IrUtils.h CodeGen/include/Luau/Label.h CodeGen/include/Luau/OperandX64.h + CodeGen/include/Luau/OptimizeFinalX64.h CodeGen/include/Luau/RegisterA64.h CodeGen/include/Luau/RegisterX64.h CodeGen/include/Luau/UnwindBuilder.h @@ -92,7 +93,9 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/src/IrDump.cpp CodeGen/src/IrLoweringX64.cpp CodeGen/src/IrTranslation.cpp + CodeGen/src/IrUtils.cpp CodeGen/src/NativeState.cpp + CodeGen/src/OptimizeFinalX64.cpp CodeGen/src/UnwindBuilderDwarf2.cpp CodeGen/src/UnwindBuilderWin.cpp @@ -337,6 +340,7 @@ if(TARGET Luau.UnitTest) tests/DenseHash.test.cpp tests/Error.test.cpp tests/Frontend.test.cpp + tests/IrBuilder.test.cpp tests/JsonEmitter.test.cpp tests/Lexer.test.cpp tests/Linter.test.cpp diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index a69965e..1d31b28 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -10,8 +10,6 @@ #include -LUAU_FASTFLAG(LuauScopelessModule) - using namespace Luau; namespace @@ -145,8 +143,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "real_source") TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; fileResolver.source["game/Gui/Modules/B"] = R"( local Modules = game:GetService('Gui').Modules @@ -224,8 +220,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "any_annotation_breaks_cycle") TEST_CASE_FIXTURE(FrontendFixture, "nocheck_modules_are_typed") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/Gui/Modules/A"] = R"( --!nocheck export type Foo = number @@ -281,8 +275,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_detection_between_check_and_nocheck") TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/Gui/Modules/A"] = R"( --!nocheck local Modules = game:GetService('Gui').Modules @@ -501,8 +493,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "dont_recheck_script_that_hasnt_been_marked_d TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; fileResolver.source["game/Gui/Modules/B"] = R"( local Modules = game:GetService('Gui').Modules diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp new file mode 100644 index 0000000..4ed8728 --- /dev/null +++ b/tests/IrBuilder.test.cpp @@ -0,0 +1,223 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IrBuilder.h" +#include "Luau/IrAnalysis.h" +#include "Luau/IrDump.h" +#include "Luau/OptimizeFinalX64.h" + +#include "doctest.h" + +using namespace Luau::CodeGen; + +class IrBuilderFixture +{ +public: + IrBuilder build; +}; + +TEST_SUITE_BEGIN("Optimization"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag1, build.constTag(0), fallback); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmConst(5)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(0), fallback); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into CHECK_TAG + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + CHECK_TAG R2, tnil, bb_fallback_1 + CHECK_TAG K5, tnil, bb_fallback_1 + LOP_RETURN 0u + +bb_fallback_1: + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); + IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + build.inst(IrCmd::ADD_NUM, opA, opB); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into second argument + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R1 + %2 = ADD_NUM %0, R2 + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp opA = build.inst(IrCmd::LOAD_TAG, build.vmReg(1)); + IrOp opB = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(trueBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into first argument + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %1 = LOAD_TAG R2 + JUMP_EQ_TAG R1, %1, bb_1, bb_2 + +bb_1: + LOP_RETURN 0u + +bb_2: + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp opA = build.inst(IrCmd::LOAD_TAG, build.vmReg(1)); + IrOp opB = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::STORE_TAG, build.vmReg(6), opA); + build.inst(IrCmd::JUMP_EQ_TAG, opA, opB, trueBlock, falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(trueBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into second argument is it can't be done for the first one + // We also swap first and second argument to generate memory access on the LHS + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_TAG R1 + STORE_TAG R6, %0 + JUMP_EQ_TAG R2, %0, bb_1, bb_2 + +bb_1: + LOP_RETURN 0u + +bb_2: + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + IrOp arrElem = build.inst(IrCmd::GET_ARR_ADDR, table, build.constUint(0)); + IrOp opA = build.inst(IrCmd::LOAD_TAG, arrElem); + build.inst(IrCmd::JUMP_EQ_TAG, opA, build.constTag(0), trueBlock, falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(trueBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into first argument + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %0 = LOAD_POINTER R1 + %1 = GET_ARR_ADDR %0, 0u + %2 = LOAD_TAG %1 + JUMP_EQ_TAG %2, tnil, bb_1, bb_2 + +bb_1: + LOP_RETURN 0u + +bb_2: + LOP_RETURN 0u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp opA = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); + IrOp opB = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + build.inst(IrCmd::JUMP_CMP_NUM, opA, opB, trueBlock, falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(trueBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::LOP_RETURN, build.constUint(0)); + + updateUseCounts(build.function); + optimizeMemoryOperandsX64(build.function); + + // Load from memory is 'inlined' into first argument + CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( +bb_0: + %1 = LOAD_DOUBLE R2 + JUMP_CMP_NUM R1, %1, bb_1, bb_2 + +bb_1: + LOP_RETURN 0u + +bb_2: + LOP_RETURN 0u + +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 34c2e8f..8557913 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -112,8 +112,6 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( return {sign=math.sign} )"); @@ -285,8 +283,6 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TEST_CASE_FIXTURE(Fixture, "any_persistance_does_not_leak") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["Module/A"] = R"( export type A = B type B = A @@ -310,7 +306,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_reexports") {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, - {"LuauScopelessModule", true}, }; fileResolver.source["Module/A"] = R"( @@ -349,7 +344,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") {"LuauSubstitutionReentrant", true}, {"LuauClassTypeVarsInSubstitution", true}, {"LuauSubstitutionFixMissingFields", true}, - {"LuauScopelessModule", true}, }; fileResolver.source["Module/A"] = R"( diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 5deeb35..28c5bba 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -253,8 +253,6 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( --!nonstrict diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 3e98367..d5f9537 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -465,8 +465,6 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( export type A = {field: number} @@ -498,8 +496,6 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( export type Array = { [number]: T } )"); @@ -527,8 +523,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( export type Record = { name: string, location: string } local a: Record = { name="Waldo", location="?????" } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 71586f9..683469a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -109,8 +109,6 @@ TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and TEST_CASE_FIXTURE(BuiltinsFixture, "vararg_function_is_quantified") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( local T = {} function T.f(...) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3861a8b..d7b0bdb 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -282,8 +282,14 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") function x:f(): string return self:id("hello") end function x:g(): number return self:id(37) end )"); - // TODO: Quantification should be doing the conversion, not normalization. - LUAU_REQUIRE_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + // TODO: Quantification should be doing the conversion, not normalization. + LUAU_REQUIRE_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") @@ -296,8 +302,14 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") local y: number = self:id(37) end )"); - // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + // TODO: Should typecheck but currently errors CLI-39916 + LUAU_REQUIRE_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 7d629f7..c389f32 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -514,8 +514,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") // Ideally, we would not try to export a function type with generic types from incorrect scope TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/A"] = R"( local wrapStrictTable @@ -555,8 +553,6 @@ return wrapStrictTable(Constants, "Constants") TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - fileResolver.source["game/A"] = R"( local wrapStrictTable diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 43c0b38..fb44ec4 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -36,25 +36,26 @@ std::optional> magicFunctionInstanceIsA( return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } -std::vector dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) +void dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) { - if (ctx.callSite->args.size != 1) - return {}; + if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) + return; auto index = ctx.callSite->func->as(); auto str = ctx.callSite->args.data[0]->as(); if (!index || !str) - return {}; + return; - std::optional def = ctx.dfg->getDef(index->expr); - if (!def) - return {}; + std::optional discriminantTy = ctx.discriminantTypes[0]; + if (!discriminantTy) + return; std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); if (!tfun) - return {}; + return; - return {ctx.refinementArena->proposition(*def, tfun->type)}; + LUAU_ASSERT(get(*discriminantTy)); + asMutable(*discriminantTy)->ty.emplace(tfun->type); } struct RefinementClassFixture : BuiltinsFixture @@ -1491,4 +1492,45 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknown_to_table_then_take_the_length CHECK_EQ("table", toString(requireTypeAtPosition({3, 29}))); } +TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_during_constraint_solving_stage") +{ + CheckResult result = check(R"( + type Id = T + + local function f(x: Id | Id>) + if typeof(x) ~= "string" and x:IsA("Part") then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("Folder | string", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_during_constraint_solving_stage_2") +{ + CheckResult result = check(R"( + local function hof(f: (Instance) -> ()) end + + hof(function(inst) + if inst:IsA("Part") then + local foo = inst + else + local foo = inst + end + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("Instance & ~Part", toString(requireTypeAtPosition({7, 28}))); + else + CHECK_EQ("Instance", toString(requireTypeAtPosition({7, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index e3c1ab1..fcebd1f 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -18,6 +18,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(LuauDontExtendUnsealedRValueTables) TEST_SUITE_BEGIN("TableTests"); @@ -628,7 +629,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") const TableIndexer& indexer = *ttv->indexer; - REQUIRE_EQ(indexer.indexType, typeChecker.numberType); + REQUIRE("number" == toString(indexer.indexType)); REQUIRE(nullptr != get(follow(indexer.indexResultType))); } @@ -869,6 +870,51 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } +TEST_CASE_FIXTURE(Fixture, "any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode") +{ + CheckResult result = check(R"( + --!nonstrict + + local constants = { + key1 = "value1", + key2 = "value2" + } + + local function getKey() + return "key1" + end + + local k1 = constants[getKey()] + )"); + + CHECK("any" == toString(requireType("k1"))); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode") +{ + CheckResult result = check(R"( + local constants = { + key1 = "value1", + key2 = "value2" + } + + function getConstant(key) + return constants[key] + end + + local k1 = getConstant("key1") + )"); + + if (FFlag::LuauDontExtendUnsealedRValueTables) + CHECK("any" == toString(requireType("k1"))); + else + CHECK("a" == toString(requireType("k1"))); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer") { CheckResult result = check(R"( @@ -2967,8 +3013,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the // The real bug here was that we weren't always uncondionally typechecking a trailing return statement last. TEST_CASE_FIXTURE(BuiltinsFixture, "dont_leak_free_table_props") { - ScopedFastFlag luauScopelessModule{"LuauScopelessModule", true}; - CheckResult result = check(R"( local function a(state) print(state.blah) @@ -3493,4 +3537,59 @@ _ = {_,} LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type") +{ + ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true}; + + CheckResult result = check(R"( + local events = {} + local mockObserveEvent = function(_, key, callback) + events[key] = callback + end + + events['FriendshipNotifications']({ + EventArgs = { + UserId2 = '2' + }, + Type = 'FriendshipDeclined' + }) + )"); + + TypeId ty = follow(requireType("events")); + const TableType* tt = get(ty); + REQUIRE_MESSAGE(tt, "Expected table but got " << toString(ty, {true})); + + CHECK(tt->props.empty()); + REQUIRE(tt->indexer); + + CHECK("string" == toString(tt->indexer->indexType)); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_extend_unsealed_tables_in_rvalue_position") +{ + ScopedFastFlag sff{"LuauDontExtendUnsealedRValueTables", true}; + + CheckResult result = check(R"( + local testDictionary = { + FruitName = "Lemon", + FruitColor = "Yellow", + Sour = true + } + + local print: any + + print(testDictionary[""]) + )"); + + TypeId ty = follow(requireType("testDictionary")); + const TableType* ttv = get(ty); + REQUIRE(ttv); + + CHECK(0 == ttv->props.count("")); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 4f0afc3..16797ee 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1158,4 +1158,18 @@ end LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "typechecking_in_type_guards") +{ + ScopedFastFlag sff{"LuauTypecheckTypeguards", true}; + + CheckResult result = check(R"( +local a = type(foo) == 'nil' +local b = typeof(foo) ~= 'nil' + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == "Unknown global 'foo'"); + CHECK(toString(result.errors[1]) == "Unknown global 'foo'"); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index 3766687..565982c 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -11,13 +11,9 @@ AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.table_overloaded_function_prop AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_oop_implicit_self +AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.do_compatible_self_calls AutocompleteTest.do_wrong_compatible_self_calls -AutocompleteTest.keyword_methods -AutocompleteTest.no_incompatible_self_calls -AutocompleteTest.no_wrong_compatible_self_calls_with_generics -AutocompleteTest.string_singleton_as_table_key -AutocompleteTest.suggest_table_keys AutocompleteTest.type_correct_expected_argument_type_pack_suggestion AutocompleteTest.type_correct_expected_argument_type_suggestion_self AutocompleteTest.type_correct_expected_return_type_pack_suggestion @@ -60,7 +56,6 @@ DefinitionTests.single_class_type_identity_in_global_types FrontendTest.environments FrontendTest.nocheck_cycle_used_by_checked FrontendTest.reexport_cyclic_type -FrontendTest.trace_requires_in_nonstrict_mode GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages @@ -126,6 +121,7 @@ ProvisionalTests.table_insert_with_a_singleton_argument ProvisionalTests.typeguard_inference_incomplete ProvisionalTests.weirditer_should_not_loop_forever RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string +RefinementTest.discriminate_tag RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.narrow_property_of_a_bounded_variable @@ -136,23 +132,19 @@ RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table -RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_can_turn_into_a_scalar_directly TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.accidentally_checked_prop_in_opposite_branch -TableTests.builtin_table_names +TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode TableTests.call_method -TableTests.call_method_with_explicit_self_argument TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early -TableTests.defining_a_method_for_a_local_unsealed_table_is_ok -TableTests.defining_a_self_method_for_a_local_unsealed_table_is_ok +TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index TableTests.dont_quantify_table_that_belongs_to_outer_scope -TableTests.dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop TableTests.expected_indexer_from_table_union @@ -175,7 +167,6 @@ TableTests.infer_array_2 TableTests.inferred_return_type_of_free_table TableTests.inferring_crazy_table_should_also_be_quick TableTests.instantiate_table_cloning_3 -TableTests.instantiate_tables_at_scope_level TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors @@ -191,7 +182,6 @@ TableTests.oop_polymorphic TableTests.open_table_unification_2 TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table -TableTests.quantifying_a_bound_var_works TableTests.reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table TableTests.result_is_always_any_if_lhs_is_any TableTests.result_is_bool_for_equality_operators_if_lhs_is_any @@ -200,7 +190,6 @@ TableTests.shared_selfs TableTests.shared_selfs_from_free_param TableTests.shared_selfs_through_metatables TableTests.table_call_metamethod_basic -TableTests.table_function_check_use_after_free TableTests.table_indexing_error_location TableTests.table_insert_should_cope_with_optional_properties_in_nonstrict TableTests.table_insert_should_cope_with_optional_properties_in_strict @@ -209,16 +198,11 @@ TableTests.table_simple_call TableTests.table_subtyping_with_extra_props_dont_report_multiple_errors TableTests.table_subtyping_with_missing_props_dont_report_multiple_errors TableTests.table_unification_4 -TableTests.tc_member_function TableTests.tc_member_function_2 -TableTests.unifying_tables_shouldnt_uaf1 TableTests.unifying_tables_shouldnt_uaf2 -TableTests.used_colon_correctly TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -TableTests.used_dot_instead_of_colon_but_correctly ToString.exhaustive_toString_of_cyclic_table -ToString.function_type_with_argument_names_and_self ToString.function_type_with_argument_names_generic ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 @@ -238,7 +222,6 @@ TryUnifyTests.variadics_should_use_reversed_properly TypeAliases.cannot_create_cyclic_type_with_unknown_module TypeAliases.corecursive_types_generic TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any -TypeAliases.forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2 TypeAliases.generic_param_remap TypeAliases.mismatched_generic_type_param TypeAliases.mutually_recursive_types_errors @@ -256,7 +239,6 @@ TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError -TypeInfer.follow_on_new_types_in_substitution TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.globals TypeInfer.globals2 @@ -281,9 +263,7 @@ TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site -TypeInferFunctions.dont_mutate_the_underlying_head_of_typepack_when_calling_with_self TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict -TypeInferFunctions.first_argument_can_be_optional TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite @@ -309,7 +289,6 @@ TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function -TypeInferFunctions.vararg_function_is_quantified TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next @@ -318,22 +297,15 @@ TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop -TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free -TypeInferModules.bound_free_table_export_is_ok TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types_4 -TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated -TypeInferModules.require_a_variadic_function TypeInferModules.type_error_of_unknown_qualified_type TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_another_overload_works TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory -TypeInferOOP.method_depends_on_table TypeInferOOP.methods_are_topologically_sorted -TypeInferOOP.nonstrict_self_mismatch_tail TypeInferOOP.object_constructor_can_refer_to_method_of_self -TypeInferOOP.table_oop TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.CallOrOfFunctions TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable @@ -345,7 +317,6 @@ TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_ TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown TypeInferOperators.operator_eq_completely_incompatible -TypeInferOperators.produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.UnknownGlobalCompoundAssign @@ -358,7 +329,6 @@ TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 TypePackTests.pack_tail_unification_check -TypePackTests.self_and_varargs_should_work TypePackTests.type_alias_backwards_compatible TypePackTests.type_alias_default_export TypePackTests.type_alias_default_mixed_self diff --git a/tools/flag-bisect.py b/tools/flag-bisect.py new file mode 100644 index 0000000..01f3ef7 --- /dev/null +++ b/tools/flag-bisect.py @@ -0,0 +1,458 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +import argparse +import asyncio +import copy +import json +import math +import os +import platform +import re +import subprocess +import sys +import textwrap +from enum import Enum + +def add_parser(subparsers): + flag_bisect_command = subparsers.add_parser('flag-bisect', + help=help(), + description=help(), + epilog=epilog(), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + add_argument_parsers(flag_bisect_command) + flag_bisect_command.set_defaults(func=flag_bisect_main) + return flag_bisect_command + +def help(): + return 'Search for a set of flags triggering the faulty behavior in unit tests' + +def get_terminal_width(): + try: + return os.get_terminal_size().columns + except: + # Return a reasonable default when a terminal is not available + return 80 +def wrap_text(text, width): + leading_whitespace_re = re.compile('( *)') + + def get_paragraphs_and_indent(string): + lines = string.split('\n') + result = '' + line_count = 0 + initial_indent = '' + subsequent_indent = '' + for line in lines: + if len(line.strip()) == 0: + if line_count > 0: + yield result, initial_indent, subsequent_indent + result = '' + line_count = 0 + else: + line_count += 1 + if line_count == 1: + initial_indent = leading_whitespace_re.match(line).group(1) + subsequent_indent = initial_indent + elif line_count == 2: + subsequent_indent = leading_whitespace_re.match(line).group(1) + result += line.strip() + '\n' + + result = '' + for paragraph, initial_indent, subsequent_indent in get_paragraphs_and_indent(text): + result += textwrap.fill(paragraph, width=width, initial_indent=initial_indent, subsequent_indent=subsequent_indent, break_on_hyphens=False) + '\n\n' + return result + +def wrap_text_for_terminal(text): + right_margin = 2 # This margin matches what argparse uses when formatting argument documentation + min_width = 20 + width = max(min_width, get_terminal_width() - right_margin) + return wrap_text(text, width) + +def epilog(): + return wrap_text_for_terminal(''' + This tool uses the delta debugging algorithm to minimize the set of flags to the ones that are faulty in your unit tests, + and the usage is trivial. Just provide a path to the unit test and you're done, the tool will do the rest. + + There are many use cases with flag-bisect. Included but not limited to: + + 1: If your test is failing when you omit `--fflags=true` but it works when passing `--fflags=true`, then you can + use this tool to find that set of flag requirements to see which flags are missing that will help to fix it. Ditto + for the opposite too, this tool is generalized for that case. + + 2: If you happen to run into a problem on production, and you're not sure which flags is the problem and you can easily + create a unit test, you can run flag-bisect on that unit test to rapidly find the set of flags. + + 3: If you have a flag that causes a performance regression, there's also the `--timeout=N` where `N` is in seconds. + + 4: If you have tests that are demonstrating flakiness behavior, you can also use `--tries=N` where `N` is the number of + attempts to run the same set of flags before moving on to the new set. This will eventually drill down to the flaky flag(s). + Generally 8 tries should be more than enough, but it depends on the rarity. The more rare it is, the higher the attempts count + needs to be. Note that this comes with a performance cost the higher you go, but certainly still faster than manual search. + This argument will disable parallel mode by default. If this is not desired, explicitly write `--parallel=on`. + + 5: By default flag-bisect runs in parallel mode which uses a slightly modified version of delta debugging algorithm to support + trying multiple sets of flags concurrently. This means that the number of sets the algorithm will try at once is equal to the + number of concurrent jobs. There is currently no upper bound to that, so heed this warning that your machine may slow down + significantly. In this mode, we display the number of jobs it is running in parallel. Use `--parallel=off` to disable parallel + mode. + + Be aware that this introduces some level of *non-determinism*, and it is fundamental due to the interaction with flag dependencies + and the fact one job may finish faster than another job that got ran in the same cycle. However, it generally shouldn't matter + if your test is deterministic and has no implicit flag dependencies in the codebase. + + The tool will try to automatically figure out which of `--pass` or `--fail` to use if you omit them or use `--auto` by applying + heuristics. For example, if the tests works using `--fflags=true` and crashes if omitting `--fflags=true`, then it knows + to use `--pass` to give you set of flags that will cause that crash. As usual, vice versa is also true. Since this is a + heuristic, if it gets that guess wrong, you can override with `--pass` or `--fail`. + + You can speed this process up by scoping it to as few tests as possible, for example if you're using doctest then you'd + pass `--tc=my_test` as an argument after `--`, so `flag-bisect ./path/to/binary -- --tc=my_test`. + ''') + +class InterestnessMode(Enum): + AUTO = 0, + FAIL = 1, + PASS = 2, + +def add_argument_parsers(parser): + parser.add_argument('binary_path', help='Path to the unit test binary that will be bisected for a set of flags') + + parser.add_argument('--tries', dest='attempts', type=int, default=1, metavar='N', + help='If the tests are flaky, flag-bisect will try again with the same set by N amount of times before moving on') + + parser.add_argument('--parallel', dest='parallel', choices=['on', 'off'], default='default', + help='Test multiple sets of flags in parallel, useful when the test takes a while to run.') + + parser.add_argument('--explicit', dest='explicit', action='store_true', default=False, help='Explicitly set flags to false') + + parser.add_argument('--filter', dest='filter', default=None, help='Regular expression to filter for a subset of flags to test') + + parser.add_argument('--verbose', dest='verbose', action='store_true', default=False, help='Show stdout and stderr of the program being run') + + interestness_parser = parser.add_mutually_exclusive_group() + interestness_parser.add_argument('--auto', dest='mode', action='store_const', const=InterestnessMode.AUTO, + default=InterestnessMode.AUTO, help='Automatically figure out which one of --pass or --fail should be used') + interestness_parser.add_argument('--fail', dest='mode', action='store_const', const=InterestnessMode.FAIL, + help='You want this if omitting --fflags=true causes tests to fail') + interestness_parser.add_argument('--pass', dest='mode', action='store_const', const=InterestnessMode.PASS, + help='You want this if passing --fflags=true causes tests to pass') + interestness_parser.add_argument('--timeout', dest='timeout', type=int, default=0, metavar='SECONDS', + help='Find the flag(s) causing performance regression if time to run exceeds the timeout in seconds') + +class Options: + def __init__(self, args, other_args, sense): + self.path = args.binary_path + self.explicit = args.explicit + self.sense = sense + self.timeout = args.timeout + self.interested_in_timeouts = args.timeout != 0 + self.attempts = args.attempts + self.parallel = (args.parallel == 'on' or args.parallel == 'default') if args.attempts == 1 else args.parallel == 'on' + self.filter = re.compile(".*" + args.filter + ".*") if args.filter else None + self.verbose = args.verbose + self.other_args = [arg for arg in other_args if arg != '--'] # Useless to have -- here, discard. + + def copy_with_sense(self, sense): + new_copy = copy.copy(self) + new_copy.sense = sense + return new_copy + +class InterestnessResult(Enum): + FAIL = 0, + PASS = 1, + TIMED_OUT = 2, + +class Progress: + def __init__(self, count, n_of_jobs=None): + self.count = count + self.steps = 0 + self.n_of_jobs = n_of_jobs + self.buffer = None + + def show(self): + # remaining is actually the height of the current search tree. + remain = int(math.log2(self.count)) + flag_plural = 'flag' if self.count == 1 else 'flags' + node_plural = 'node' if remain == 1 else 'nodes' + jobs_info = f', running {self.n_of_jobs} jobs' if self.n_of_jobs is not None else '' + return f'flag bisection: testing {self.count} {flag_plural} (step {self.steps}, {remain} {node_plural} remain{jobs_info})' + + def hide(self): + if self.buffer: + sys.stdout.write('\b \b' * len(self.buffer)) + + def update(self, len, n_of_jobs=None): + self.hide() + self.count = len + self.steps += 1 + self.n_of_jobs = n_of_jobs + self.buffer = self.show() + sys.stdout.write(self.buffer) + sys.stdout.flush() + +def list_fflags(options): + try: + out = subprocess.check_output([options.path, '--list-fflags'], encoding='UTF-8') + flag_names = [] + + # It's unlikely that a program we're going to test has no flags. + # So if the output doesn't start with FFlag, assume it doesn't support --list-fflags and therefore cannot be bisected. + if not out.startswith('FFlag') and not out.startswith('DFFlag') and not out.startswith('SFFlag'): + return None + + flag_names = out.split('\n')[:-1] + + subset = [flag for flag in flag_names if options.filter.match(flag) is not None] if options.filter else flag_names + return subset if subset else None + except: + return None + +def mk_flags_argument(options, flags, initial_flags): + lst = [flag + '=true' for flag in flags] + + # When --explicit is provided, we'd like to find the set of flags from initial_flags that's not in active flags. + # This is so that we can provide a =false value instead of leaving them out to be the default value. + if options.explicit: + for flag in initial_flags: + if flag not in flags: + lst.append(flag + '=false') + + return '--fflags=' + ','.join(lst) + +def mk_command_line(options, flags_argument): + arguments = [options.path, *options.other_args] + if flags_argument is not None: + arguments.append(flags_argument) + return arguments + +async def get_interestness(options, flags_argument): + try: + timeout = options.timeout if options.interested_in_timeouts else None + cmd = mk_command_line(options, flags_argument) + stdout = subprocess.PIPE if not options.verbose else None + stderr = subprocess.PIPE if not options.verbose else None + process = subprocess.run(cmd, stdout=stdout, stderr=stderr, timeout=timeout) + return InterestnessResult.PASS if process.returncode == 0 else InterestnessResult.FAIL + except subprocess.TimeoutExpired: + return InterestnessResult.TIMED_OUT + +async def is_hot(options, flags_argument, pred=any): + results = await asyncio.gather(*[get_interestness(options, flags_argument) for _ in range(options.attempts)]) + + if options.interested_in_timeouts: + return pred([InterestnessResult.TIMED_OUT == x for x in results]) + else: + return pred([(InterestnessResult.PASS if options.sense else InterestnessResult.FAIL) == x for x in results]) + +def pairwise_disjoints(flags, granularity): + offset = 0 + per_slice_len = len(flags) // granularity + while offset < len(flags): + yield flags[offset:offset + per_slice_len] + offset += per_slice_len + +def subsets_and_complements(flags, granularity): + for disjoint_set in pairwise_disjoints(flags, granularity): + yield disjoint_set, [flag for flag in flags if flag not in disjoint_set] + +# https://www.cs.purdue.edu/homes/xyzhang/fall07/Papers/delta-debugging.pdf +async def ddmin(options, initial_flags): + current = initial_flags + granularity = 2 + + progress = Progress(len(current)) + progress.update(len(current)) + + while len(current) >= 2: + changed = False + + for (subset, complement) in subsets_and_complements(current, granularity): + progress.update(len(current)) + if await is_hot(options, mk_flags_argument(options, complement, initial_flags)): + current = complement + granularity = max(granularity - 1, 2) + changed = True + break + elif await is_hot(options, mk_flags_argument(options, subset, initial_flags)): + current = subset + granularity = 2 + changed = True + break + + if not changed: + if granularity == len(current): + break + granularity = min(granularity * 2, len(current)) + + progress.hide() + return current + +async def ddmin_parallel(options, initial_flags): + current = initial_flags + granularity = 2 + + progress = Progress(len(current)) + progress.update(len(current), granularity) + + while len(current) >= 2: + changed = False + + subset_jobs = [] + complement_jobs = [] + + def advance(task): + nonlocal current + nonlocal granularity + nonlocal changed + # task.cancel() calls the callback passed to add_done_callback... + if task.cancelled(): + return + hot, new_delta, new_granularity = task.result() + if hot and not changed: + current = new_delta + granularity = new_granularity + changed = True + for job in subset_jobs: + job.cancel() + for job in complement_jobs: + job.cancel() + + for (subset, complement) in subsets_and_complements(current, granularity): + async def work(flags, new_granularity): + hot = await is_hot(options, mk_flags_argument(options, flags, initial_flags)) + return (hot, flags, new_granularity) + + # We want to run subset jobs in parallel first. + subset_job = asyncio.create_task(work(subset, 2)) + subset_job.add_done_callback(advance) + subset_jobs.append(subset_job) + + # Then the complements afterwards, but only if we didn't find a new subset. + complement_job = asyncio.create_task(work(complement, max(granularity - 1, 2))) + complement_job.add_done_callback(advance) + complement_jobs.append(complement_job) + + # When we cancel jobs, the asyncio.gather will be waiting pointlessly. + # In that case, we'd like to return the control to this routine. + await asyncio.gather(*subset_jobs, return_exceptions=True) + if not changed: + await asyncio.gather(*complement_jobs, return_exceptions=True) + progress.update(len(current), granularity) + + if not changed: + if granularity == len(current): + break + granularity = min(granularity * 2, len(current)) + + progress.hide() + return current + +def search(options, initial_flags): + if options.parallel: + return ddmin_parallel(options, initial_flags) + else: + return ddmin(options, initial_flags) + +async def do_work(args, other_args): + sense = None + + # If --timeout isn't used, try to apply a heuristic to figure out which of --pass or --fail we want. + if args.timeout == 0 and args.mode == InterestnessMode.AUTO: + inner_options = Options(args, other_args, sense) + + # We aren't interested in timeout for this heuristic. It just makes no sense to assume timeouts. + # This actually cannot happen by this point, but if we make timeout a non-exclusive switch to --auto, this will go wrong. + inner_options.timeout = 0 + inner_options.interested_in_timeouts = False + + all_tasks = asyncio.gather( + is_hot(inner_options.copy_with_sense(True), '--fflags=true', all), + is_hot(inner_options.copy_with_sense(False), '--fflags=false' if inner_options.explicit else None, all), + ) + + # If it times out, we can print a message saying that this is still working. We intentionally want to continue doing work. + done, pending = await asyncio.wait([all_tasks], timeout=1.5) + if all_tasks not in done: + print('Hang on! I\'m running your program to try and figure out which of --pass or --fail to use!') + print('Need to find out faster? Cancel the work and explicitly write --pass or --fail') + + is_pass_hot, is_fail_hot = await all_tasks + + # This is a bit counter-intuitive, but the following table tells us which of the sense we want. + # Because when you omit --fflags=true argument and it fails, then is_fail_hot is True. + # Consequently, you need to use --pass to find out what that set of flags is. And vice versa. + # + # Also, when is_pass_hot is True and is_fail_hot is False, then that program is working as expected. + # There should be no reason to run flag bisection. + # However, this can be ambiguous in the opposite of the aforementioned outcome! + # + # is_pass_hot | is_fail_hot | is ambiguous? + #-------------|-------------|--------------- + # True | True | No! Pick --pass. + # False | False | No! Pick --fail. + # True | False | No! But this is the exact situation where you shouldn't need to flag-bisect. Raise an error. + # False | True | Yes! But we'll pragmatically pick --fail here in the hope it gives the correct set of flags. + + if is_pass_hot and not is_fail_hot: + print('The tests seems to be working fine for me. If you really need to flag-bisect, please try again with an explicit --pass or --fail', file=sys.stderr) + return 1 + + if not is_pass_hot and is_fail_hot: + print('I couldn\'t quite figure out which of --pass or --fail to use, but I\'ll carry on anyway') + + sense = is_pass_hot + argument = '--pass' if sense else '--fail' + print(f'I\'m bisecting flags as if {argument} was used') + else: + sense = True if args.mode == InterestnessMode.PASS else False + + options = Options(args, other_args, sense) + + initial_flags = list_fflags(options) + if initial_flags is None: + print('I cannot bisect flags with ' + options.path, file=sys.stderr) + print('These are required for me to be able to cooperate:', file=sys.stderr) + print('\t--list-fflags must print a list of flags separated by newlines, including FFlag prefix', file=sys.stderr) + print('\t--fflags=... to accept a comma-separated pair of flag names and their value in the form FFlagFoo=true', file=sys.stderr) + return 1 + + # On Windows, there is an upper bound on the numbers of characters for a command line incantation. + # If we don't handle this ourselves, the runtime error is going to look nothing like the actual problem. + # It'd say "file name way too long" or something to that effect. We can teed up a better error message and + # tell the user how to work around it by using --filter. + if platform.system() == 'Windows': + cmd_line = ' '.join(mk_command_line(options, mk_flags_argument(options, initial_flags, []))) + if len(cmd_line) >= 8191: + print(f'Never mind! The command line is too long because we have {len(initial_flags)} flags to test', file=sys.stderr) + print('Consider using `--filter=` to narrow it down upfront, or use any version of WSL instead', file=sys.stderr) + return 1 + + hot_flags = await search(options, initial_flags) + if hot_flags: + print('I narrowed down to these flags:') + print(textwrap.indent('\n'.join(hot_flags), prefix='\t')) + + # If we showed the command line in explicit mode, all flags would be listed here. + # This would pollute the terminal with 3000 flags. We don't want that. Don't show it. + # Ditto for when the number flags we bisected are equal. + if not options.explicit and len(hot_flags) != len(initial_flags): + print('$ ' + ' '.join(mk_command_line(options, mk_flags_argument(options, hot_flags, initial_flags)))) + + return 0 + + print('I found nothing, sorry', file=sys.stderr) + return 1 + +def flag_bisect_main(args, other_args): + return asyncio.run(do_work(args, other_args)) + +def main(): + parser = argparse.ArgumentParser(description=help(), epilog=epilog(), formatter_class=argparse.RawTextHelpFormatter) + add_argument_parsers(parser) + args, other_args = parser.parse_known_args() + return flag_bisect_main(args, other_args) + +if __name__ == '__main__': + sys.exit(main())