diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index e79c4c9..2044704 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -2,12 +2,13 @@ #pragma once #include "Luau/Ast.h" -#include "Luau/Refinement.h" #include "Luau/Constraint.h" +#include "Luau/ControlFlow.h" #include "Luau/DataFlowGraph.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/NotNull.h" +#include "Luau/Refinement.h" #include "Luau/Symbol.h" #include "Luau/Type.h" #include "Luau/Variant.h" @@ -141,26 +142,26 @@ struct ConstraintGraphBuilder */ void visit(AstStatBlock* block); - void visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); + ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); - void visit(const ScopePtr& scope, AstStat* stat); - void visit(const ScopePtr& scope, AstStatBlock* block); - void visit(const ScopePtr& scope, AstStatLocal* local); - void visit(const ScopePtr& scope, AstStatFor* for_); - void visit(const ScopePtr& scope, AstStatForIn* forIn); - void visit(const ScopePtr& scope, AstStatWhile* while_); - void visit(const ScopePtr& scope, AstStatRepeat* repeat); - void visit(const ScopePtr& scope, AstStatLocalFunction* function); - void visit(const ScopePtr& scope, AstStatFunction* function); - void visit(const ScopePtr& scope, AstStatReturn* ret); - void visit(const ScopePtr& scope, AstStatAssign* assign); - void visit(const ScopePtr& scope, AstStatCompoundAssign* assign); - void visit(const ScopePtr& scope, AstStatIf* ifStatement); - void visit(const ScopePtr& scope, AstStatTypeAlias* alias); - void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); - void visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); - void visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); - void visit(const ScopePtr& scope, AstStatError* error); + ControlFlow visit(const ScopePtr& scope, AstStat* stat); + ControlFlow visit(const ScopePtr& scope, AstStatBlock* block); + ControlFlow visit(const ScopePtr& scope, AstStatLocal* local); + ControlFlow visit(const ScopePtr& scope, AstStatFor* for_); + ControlFlow visit(const ScopePtr& scope, AstStatForIn* forIn); + ControlFlow visit(const ScopePtr& scope, AstStatWhile* while_); + ControlFlow visit(const ScopePtr& scope, AstStatRepeat* repeat); + ControlFlow visit(const ScopePtr& scope, AstStatLocalFunction* function); + ControlFlow visit(const ScopePtr& scope, AstStatFunction* function); + ControlFlow visit(const ScopePtr& scope, AstStatReturn* ret); + ControlFlow visit(const ScopePtr& scope, AstStatAssign* assign); + ControlFlow visit(const ScopePtr& scope, AstStatCompoundAssign* assign); + ControlFlow visit(const ScopePtr& scope, AstStatIf* ifStatement); + ControlFlow visit(const ScopePtr& scope, AstStatTypeAlias* alias); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareClass* declareClass); + ControlFlow visit(const ScopePtr& scope, AstStatDeclareFunction* declareFunction); + ControlFlow visit(const ScopePtr& scope, AstStatError* error); InferencePack checkPack(const ScopePtr& scope, AstArray exprs, const std::vector>& expectedTypes = {}); InferencePack checkPack(const ScopePtr& scope, AstExpr* expr, const std::vector>& expectedTypes = {}); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 4fd7d0d..e9e1e88 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -143,6 +143,14 @@ struct ConstraintSolver bool block(TypeId target, NotNull constraint); bool block(TypePackId target, NotNull constraint); + /** + * For all constraints that are blocked on one constraint, make them block + * on a new constraint. + * @param source the constraint to copy blocks from. + * @param addition the constraint that other constraints should now block on. + */ + void inheritBlocks(NotNull source, NotNull addition); + // Traverse the type. If any blocked or pending types are found, block // the constraint on them. // diff --git a/Analysis/include/Luau/ControlFlow.h b/Analysis/include/Luau/ControlFlow.h new file mode 100644 index 0000000..8272bd5 --- /dev/null +++ b/Analysis/include/Luau/ControlFlow.h @@ -0,0 +1,36 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +struct Scope; +using ScopePtr = std::shared_ptr; + +enum class ControlFlow +{ + None = 0b00001, + Returns = 0b00010, + Throws = 0b00100, + Break = 0b01000, // Currently unused. + Continue = 0b10000, // Currently unused. +}; + +inline ControlFlow operator&(ControlFlow a, ControlFlow b) +{ + return ControlFlow(int(a) & int(b)); +} + +inline ControlFlow operator|(ControlFlow a, ControlFlow b) +{ + return ControlFlow(int(a) | int(b)); +} + +inline bool matches(ControlFlow a, ControlFlow b) +{ + return (a & b) != ControlFlow(0); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 9c0366a..68ba8ff 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -89,14 +89,21 @@ struct FrontendOptions // order to get more precise type information) bool forAutocomplete = false; + bool runLintChecks = false; + // If not empty, randomly shuffle the constraint set before attempting to // solve. Use this value to seed the random number generator. std::optional randomizeConstraintResolutionSeed; + + std::optional enabledLintWarnings; }; struct CheckResult { std::vector errors; + + LintResult lintResult; + std::vector timeoutHits; }; @@ -133,8 +140,9 @@ struct Frontend CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess - LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); - LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); + // Use 'check' with 'runLintChecks' set to true in FrontendOptions (enabledLintWarnings be set there as well) + LintResult lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings = {}); + LintResult lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 2faa029..72f8760 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Error.h" +#include "Luau/Linter.h" #include "Luau/FileResolver.h" #include "Luau/ParseOptions.h" #include "Luau/ParseResult.h" @@ -88,6 +89,7 @@ struct Module std::unordered_map declaredGlobals; ErrorVec errors; + LintResult lintResult; Mode mode; SourceCode::Type type; bool timeout = false; diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 15dc7d4..1540470 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -19,6 +19,8 @@ using ModulePtr = std::shared_ptr; bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); +bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); +bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); class TypeIds { @@ -203,7 +205,7 @@ struct NormalizedFunctionType }; // A normalized generic/free type is a union, where each option is of the form (X & T) where -// * X is either a free type or a generic +// * X is either a free type, a generic or a blocked type. // * T is a normalized type. struct NormalizedType; using NormalizedTyvars = std::unordered_map>; @@ -214,7 +216,7 @@ bool isInhabited_DEPRECATED(const NormalizedType& norm); // * P is a union of primitive types (including singletons, classes and the error type) // * T is a union of table types // * F is a union of an intersection of function types -// * G is a union of generic/free normalized types, intersected with a normalized type +// * G is a union of generic/free/blocked types, intersected with a normalized type struct NormalizedType { // The top part of the type. diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 0d39726..745ea47 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -66,6 +66,7 @@ struct Scope RefinementMap refinements; DenseHashMap dcrRefinements{nullptr}; + void inheritRefinements(const ScopePtr& childScope); // For mutually recursive type aliases, it's important that // they use the same types for the same names. diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index ef2d4c6..dba2a8d 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -102,7 +102,7 @@ struct BlockedType BlockedType(); int index; - static int nextIndex; + static int DEPRECATED_nextIndex; }; struct PrimitiveType diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 21cb263..6816179 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -2,14 +2,15 @@ #pragma once #include "Luau/Anyification.h" -#include "Luau/Predicate.h" +#include "Luau/ControlFlow.h" #include "Luau/Error.h" #include "Luau/Module.h" -#include "Luau/Symbol.h" +#include "Luau/Predicate.h" #include "Luau/Substitution.h" +#include "Luau/Symbol.h" #include "Luau/TxnLog.h" -#include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypePack.h" #include "Luau/Unifier.h" #include "Luau/UnifierSharedState.h" @@ -87,28 +88,28 @@ struct TypeChecker std::vector> getScopes() const; - void check(const ScopePtr& scope, const AstStat& statement); - void check(const ScopePtr& scope, const AstStatBlock& statement); - void check(const ScopePtr& scope, const AstStatIf& statement); - void check(const ScopePtr& scope, const AstStatWhile& statement); - void check(const ScopePtr& scope, const AstStatRepeat& statement); - void check(const ScopePtr& scope, const AstStatReturn& return_); - void check(const ScopePtr& scope, const AstStatAssign& assign); - void check(const ScopePtr& scope, const AstStatCompoundAssign& assign); - void check(const ScopePtr& scope, const AstStatLocal& local); - void check(const ScopePtr& scope, const AstStatFor& local); - void check(const ScopePtr& scope, const AstStatForIn& forin); - void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); - void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias); - void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); - void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); + ControlFlow check(const ScopePtr& scope, const AstStat& statement); + ControlFlow check(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow check(const ScopePtr& scope, const AstStatIf& statement); + ControlFlow check(const ScopePtr& scope, const AstStatWhile& statement); + ControlFlow check(const ScopePtr& scope, const AstStatRepeat& statement); + ControlFlow check(const ScopePtr& scope, const AstStatReturn& return_); + ControlFlow check(const ScopePtr& scope, const AstStatAssign& assign); + ControlFlow check(const ScopePtr& scope, const AstStatCompoundAssign& assign); + ControlFlow check(const ScopePtr& scope, const AstStatLocal& local); + ControlFlow check(const ScopePtr& scope, const AstStatFor& local); + ControlFlow check(const ScopePtr& scope, const AstStatForIn& forin); + ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); + ControlFlow check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); + ControlFlow check(const ScopePtr& scope, const AstStatTypeAlias& typealias); + ControlFlow check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); + ControlFlow check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); - void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); - void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + ControlFlow checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); WithPredicate checkExpr( diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 15e501f..9c4f013 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -81,6 +81,8 @@ namespace Luau::Unifiable using Name = std::string; +int freshIndex(); + struct Free { explicit Free(TypeLevel level); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index fc886ac..e7817e5 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -58,6 +58,7 @@ struct Unifier NotNull scope; // const Scope maybe TxnLog log; + bool failure = false; ErrorVec errors; Location location; Variance variance = Covariant; @@ -93,7 +94,7 @@ private: // Traverse the two types provided and block on any BlockedTypes we find. // Returns true if any types were blocked on. - bool blockOnBlockedTypes(TypeId subTy, TypeId superTy); + bool DEPRECATED_blockOnBlockedTypes(TypeId subTy, TypeId superTy); void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 711d357..e90cb7d 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -5,6 +5,7 @@ #include "Luau/Breadcrumb.h" #include "Luau/Common.h" #include "Luau/Constraint.h" +#include "Luau/ControlFlow.h" #include "Luau/DcrLogger.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" @@ -22,6 +23,7 @@ LUAU_FASTFLAG(LuauNegatedClassTypes); namespace Luau { +bool doesCallError(const AstExprCall* call); // TypeInfer.cpp const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp static std::optional matchRequire(const AstExprCall& call) @@ -344,14 +346,14 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) logger->captureGenerationModule(module); } -void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) +ControlFlow ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) { RecursionCounter counter{&recursionCount}; if (recursionCount >= FInt::LuauCheckRecursionLimit) { reportCodeTooComplex(block->location); - return; + return ControlFlow::None; } std::unordered_map aliasDefinitionLocations; @@ -396,59 +398,77 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope, } } + std::optional firstControlFlow; for (AstStat* stat : block->body) - visit(scope, stat); + { + ControlFlow cf = visit(scope, stat); + if (cf != ControlFlow::None && !firstControlFlow) + firstControlFlow = cf; + } + + return firstControlFlow.value_or(ControlFlow::None); } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat) { RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto i = stat->as()) - visit(scope, i); + return visit(scope, i); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (stat->is() || stat->is()) { // Nothing + return ControlFlow::None; // TODO: ControlFlow::Break/Continue } else if (auto r = stat->as()) - visit(scope, r); + return visit(scope, r); else if (auto e = stat->as()) + { checkPack(scope, e->expr); + + if (auto call = e->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + + return ControlFlow::None; + } else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto a = stat->as()) - visit(scope, a); + return visit(scope, a); else if (auto a = stat->as()) - visit(scope, a); + return visit(scope, a); else if (auto f = stat->as()) - visit(scope, f); + return visit(scope, f); else if (auto f = stat->as()) - visit(scope, f); + return visit(scope, f); else if (auto a = stat->as()) - visit(scope, a); + return visit(scope, a); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else if (auto s = stat->as()) - visit(scope, s); + return visit(scope, s); else + { LUAU_ASSERT(0 && "Internal error: Unknown AstStat type"); + return ControlFlow::None; + } } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) { std::vector varTypes; varTypes.reserve(local->vars.size); @@ -534,7 +554,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } } - if (local->vars.size == 1 && local->values.size == 1 && firstValueType) + if (local->vars.size == 1 && local->values.size == 1 && firstValueType && scope.get() == rootScope) { AstLocal* var = local->vars.data[0]; AstExpr* value = local->values.data[0]; @@ -592,9 +612,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } } } + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) { TypeId annotationTy = builtinTypes->numberType; if (for_->var->annotation) @@ -619,9 +641,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_) forScope->dcrRefinements[bc->def] = annotationTy; visit(forScope, for_->body); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) { ScopePtr loopScope = childScope(forIn, scope); @@ -645,27 +669,33 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) addConstraint(loopScope, getLocation(forIn->values), IterableConstraint{iterator, variablePack}); visit(loopScope, forIn->body); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_) { check(scope, while_->condition); ScopePtr whileScope = childScope(while_, scope); visit(whileScope, while_->body); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) { ScopePtr repeatScope = childScope(repeat, scope); visitBlockWithoutChildScope(repeatScope, repeat->body); check(repeatScope, repeat->condition); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) { // Local // Global @@ -699,9 +729,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* }); addConstraint(scope, std::move(c)); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* function) { // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self @@ -779,9 +811,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct }); addConstraint(scope, std::move(c)); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) { // At this point, the only way scope->returnType should have anything // interesting in it is if the function has an explicit return annotation. @@ -793,13 +827,18 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret) TypePackId exprTypes = checkPack(scope, ret->list, expectedTypes).tp; addConstraint(scope, ret->location, PackSubtypeConstraint{exprTypes, scope->returnType}); + + return ControlFlow::Returns; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) { ScopePtr innerScope = childScope(block, scope); - visitBlockWithoutChildScope(innerScope, block); + ControlFlow flow = visitBlockWithoutChildScope(innerScope, block); + scope->inheritRefinements(innerScope); + + return flow; } static void bindFreeType(TypeId a, TypeId b) @@ -819,7 +858,7 @@ static void bindFreeType(TypeId a, TypeId b) asMutable(b)->ty.emplace(a); } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) { std::vector varTypes = checkLValues(scope, assign->vars); @@ -839,9 +878,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign) TypePackId varPack = arena->addTypePack({varTypes}); addConstraint(scope, assign->location, PackSubtypeConstraint{exprPack, varPack}); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) { // We need to tweak the BinaryConstraint that we emit, so we cannot use the // strategy of falsifying an AST fragment. @@ -852,23 +893,34 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* addConstraint(scope, assign->location, BinaryConstraint{assign->op, varTy, valueTy, resultType, assign, &module->astOriginalCallTypes, &module->astOverloadResolvedTypes}); addConstraint(scope, assign->location, SubtypeConstraint{resultType, varTy}); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) { - ScopePtr condScope = childScope(ifStatement->condition, scope); - RefinementId refinement = check(condScope, ifStatement->condition, std::nullopt).refinement; + RefinementId refinement = check(scope, ifStatement->condition, std::nullopt).refinement; ScopePtr thenScope = childScope(ifStatement->thenbody, scope); applyRefinements(thenScope, ifStatement->condition->location, refinement); - visit(thenScope, ifStatement->thenbody); + ScopePtr elseScope = childScope(ifStatement->elsebody ? ifStatement->elsebody : ifStatement, scope); + applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); + + ControlFlow thencf = visit(thenScope, ifStatement->thenbody); + ControlFlow elsecf = ControlFlow::None; if (ifStatement->elsebody) - { - ScopePtr elseScope = childScope(ifStatement->elsebody, scope); - applyRefinements(elseScope, ifStatement->elseLocation.value_or(ifStatement->condition->location), refinementArena.negation(refinement)); - visit(elseScope, ifStatement->elsebody); - } + elsecf = visit(elseScope, ifStatement->elsebody); + + if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && elsecf == ControlFlow::None) + scope->inheritRefinements(elseScope); + else if (thencf == ControlFlow::None && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + scope->inheritRefinements(thenScope); + + if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; } static bool occursCheck(TypeId needle, TypeId haystack) @@ -890,7 +942,7 @@ static bool occursCheck(TypeId needle, TypeId haystack) return false; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alias) { ScopePtr* defnScope = astTypeAliasDefiningScopes.find(alias); @@ -904,7 +956,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia // case we just skip over it. auto bindingIt = typeBindings->find(alias->name.value); if (bindingIt == typeBindings->end() || defnScope == nullptr) - return; + return ControlFlow::None; TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false); @@ -935,9 +987,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatTypeAlias* alia std::move(typeParams), std::move(typePackParams), }); + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* global) { LUAU_ASSERT(global->type); @@ -949,6 +1003,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareGlobal* BreadcrumbId bc = dfg->getBreadcrumb(global); rootScope->dcrRefinements[bc->def] = globalTy; + + return ControlFlow::None; } static bool isMetamethod(const Name& name) @@ -958,7 +1014,7 @@ static bool isMetamethod(const Name& name) name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) { std::optional superTy = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; if (declaredClass->superName) @@ -969,7 +1025,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d if (!lookupType) { reportError(declaredClass->location, UnknownSymbol{superName, UnknownSymbol::Type}); - return; + return ControlFlow::None; } // We don't have generic classes, so this assertion _should_ never be hit. @@ -981,7 +1037,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d reportError(declaredClass->location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass->name.value)}); - return; + return ControlFlow::None; } } @@ -1056,9 +1112,11 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* d } } } + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction* global) { std::vector> generics = createGenerics(scope, global->generics); std::vector> genericPacks = createGenericPacks(scope, global->genericPacks); @@ -1097,14 +1155,18 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction BreadcrumbId bc = dfg->getBreadcrumb(global); rootScope->dcrRefinements[bc->def] = fnType; + + return ControlFlow::None; } -void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) +ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatError* error) { for (AstStat* stat : error->statements) visit(scope, stat); for (AstExpr* expr : error->expressions) check(scope, expr); + + return ControlFlow::None; } InferencePack ConstraintGraphBuilder::checkPack( diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 3c306b4..5662cf0 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1273,19 +1273,11 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) - { - if (ic) - block(NotNull{ic}, blockedConstraint); - if (sc) - block(NotNull{sc}, blockedConstraint); - } - } + if (ic) + inheritBlocks(constraint, NotNull{ic}); + + if (sc) + inheritBlocks(constraint, NotNull{sc}); unblock(c.result); return true; @@ -1330,7 +1322,7 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNullty.emplace(result.value_or(builtinTypes->errorRecoveryType())); + asMutable(c.resultType)->ty.emplace(result.value_or(builtinTypes->anyType)); unblock(c.resultType); return true; } @@ -1796,13 +1788,23 @@ bool ConstraintSolver::tryDispatchIterableFunction( return false; } - const TypeId firstIndex = isNil(firstIndexTy) ? arena->freshType(constraint->scope) // FIXME: Surely this should be a union (free | nil) - : firstIndexTy; + TypeId firstIndex; + TypeId retIndex; + if (isNil(firstIndexTy) || isOptional(firstIndexTy)) + { + firstIndex = arena->addType(UnionType{{arena->freshType(constraint->scope), builtinTypes->nilType}}); + retIndex = firstIndex; + } + else + { + firstIndex = firstIndexTy; + retIndex = arena->addType(UnionType{{firstIndexTy, builtinTypes->nilType}}); + } // nextTy : (tableTy, indexTy?) -> (indexTy?, valueTailTy...) - const TypePackId nextArgPack = arena->addTypePack({tableTy, arena->addType(UnionType{{firstIndex, builtinTypes->nilType}})}); + const TypePackId nextArgPack = arena->addTypePack({tableTy, firstIndex}); const TypePackId valueTailTy = arena->addTypePack(FreeTypePack{constraint->scope}); - const TypePackId nextRetPack = arena->addTypePack(TypePack{{firstIndex}, valueTailTy}); + const TypePackId nextRetPack = arena->addTypePack(TypePack{{retIndex}, valueTailTy}); const TypeId expectedNextTy = arena->addType(FunctionType{TypeLevel{}, constraint->scope, nextArgPack, nextRetPack}); unify(nextTy, expectedNextTy, constraint->scope); @@ -1825,7 +1827,8 @@ bool ConstraintSolver::tryDispatchIterableFunction( modifiedNextRetHead.push_back(*it); TypePackId modifiedNextRetPack = arena->addTypePack(std::move(modifiedNextRetHead), it.tail()); - pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, modifiedNextRetPack}); + auto psc = pushConstraint(constraint->scope, constraint->location, PackSubtypeConstraint{c.variables, modifiedNextRetPack}); + inheritBlocks(constraint, psc); return true; } @@ -1883,7 +1886,17 @@ std::pair, std::optional> ConstraintSolver::lookupTa TypeId indexType = follow(indexProp->second.type); if (auto ft = get(indexType)) - return {{}, first(ft->retTypes)}; + { + TypePack rets = extendTypePack(*arena, builtinTypes, ft->retTypes, 1); + if (1 == rets.head.size()) + return {{}, rets.head[0]}; + else + { + // This should probably be an error: We need the first result of the MT.__index method, + // but it returns 0 values. See CLI-68672 + return {{}, builtinTypes->nilType}; + } + } else return lookupTableProp(indexType, propName, seen); } @@ -2009,6 +2022,20 @@ bool ConstraintSolver::block(TypePackId target, NotNull constr return false; } +void ConstraintSolver::inheritBlocks(NotNull source, NotNull addition) +{ + // Anything that is blocked on this constraint must also be blocked on our + // synthesized constraints. + auto blockedIt = blocked.find(source.get()); + if (blockedIt != blocked.end()) + { + for (const auto& blockedConstraint : blockedIt->second) + { + block(addition, blockedConstraint); + } + } +} + struct Blocker : TypeOnceVisitor { NotNull solver; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 722f1a2..de79e0b 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -29,6 +29,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) +LUAU_FASTFLAGVARIABLE(LuauLintInTypecheck, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(LuauDefinitionFileSourceModule, false) @@ -330,7 +331,7 @@ std::optional pathExprToModuleName(const ModuleName& currentModuleN namespace { -ErrorVec accumulateErrors( +static ErrorVec accumulateErrors( const std::unordered_map& sourceNodes, const std::unordered_map& modules, const ModuleName& name) { std::unordered_set seen; @@ -375,6 +376,25 @@ ErrorVec accumulateErrors( return result; } +static void filterLintOptions(LintOptions& lintOptions, const std::vector& hotcomments, Mode mode) +{ + LUAU_ASSERT(FFlag::LuauLintInTypecheck); + + uint64_t ignoreLints = LintWarning::parseMask(hotcomments); + + lintOptions.warningMask &= ~ignoreLints; + + if (mode != Mode::NoCheck) + { + lintOptions.disableWarning(Luau::LintWarning::Code_UnknownGlobal); + } + + if (mode == Mode::Strict) + { + lintOptions.disableWarning(Luau::LintWarning::Code_ImplicitReturn); + } +} + // Given a source node (start), find all requires that start a transitive dependency path that ends back at start // For each such path, record the full path and the location of the require in the starting module. // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) @@ -514,8 +534,24 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + + checkResult.errors = accumulateErrors(sourceNodes, modules, name); + + // Get lint result only for top checked module + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; + + return checkResult; + } + else + { + return CheckResult{accumulateErrors( + sourceNodes, frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules, name)}; + } } std::vector buildQueue; @@ -579,7 +615,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalerrors.clear(); + + if (frontendOptions.runLintChecks) + { + LUAU_TIMETRACE_SCOPE("lint", "Frontend"); + + LUAU_ASSERT(FFlag::LuauLintInTypecheck); + + LintOptions lintOptions = frontendOptions.enabledLintWarnings.value_or(config.enabledLint); + filterLintOptions(lintOptions, sourceModule.hotcomments, mode); + + double timestamp = getTimestamp(); + + std::vector warnings = + Luau::lint(sourceModule.root, *sourceModule.names, environmentScope, module.get(), sourceModule.hotcomments, lintOptions); + + stats.timeLint += getTimestamp() - timestamp; + + module->lintResult = classifyLints(warnings, config); + } + if (!frontendOptions.retainFullTypeGraphs) { // copyErrors needs to allocate into interfaceTypes as it copies @@ -665,6 +724,16 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& modules = + frontendOptions.forAutocomplete ? moduleResolverForAutocomplete.modules : moduleResolver.modules; + + if (auto it = modules.find(name); it != modules.end()) + checkResult.lintResult = it->second->lintResult; + } + return checkResult; } @@ -793,8 +862,10 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config return result; } -LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) +LintResult Frontend::lint_DEPRECATED(const ModuleName& name, std::optional enabledLintWarnings) { + LUAU_ASSERT(!FFlag::LuauLintInTypecheck); + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); @@ -803,11 +874,13 @@ LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) +LintResult Frontend::lint_DEPRECATED(const SourceModule& module, std::optional enabledLintWarnings) { + LUAU_ASSERT(!FFlag::LuauLintInTypecheck); + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 0552bec..f8f8b97 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -20,8 +20,10 @@ LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); LUAU_FASTFLAGVARIABLE(LuauNegatedClassTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedFunctionTypes, false); LUAU_FASTFLAGVARIABLE(LuauNegatedTableTypes, false); +LUAU_FASTFLAGVARIABLE(LuauNormalizeBlockedTypes, false); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) +LUAU_FASTFLAG(LuauTransitiveSubtyping) namespace Luau { @@ -325,6 +327,8 @@ static int tyvarIndex(TypeId ty) return gtv->index; else if (const FreeType* ftv = get(ty)) return ftv->index; + else if (const BlockedType* btv = get(ty)) + return btv->index; else return 0; } @@ -529,7 +533,7 @@ static bool areNormalizedClasses(const NormalizedClassType& tys) static bool isPlainTyvar(TypeId ty) { - return (get(ty) || get(ty)); + return (get(ty) || get(ty) || (FFlag::LuauNormalizeBlockedTypes && get(ty))); } static bool isNormalizedTyvar(const NormalizedTyvars& tyvars) @@ -1271,6 +1275,8 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) bool Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { TypeId tops = unionOfTops(here.tops, there.tops); + if (FFlag::LuauTransitiveSubtyping && get(tops) && (get(here.errors) || get(there.errors))) + tops = builtinTypes->anyType; if (!get(tops)) { clearNormal(here); @@ -1341,12 +1347,21 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor if (get(there) || get(there)) { TypeId tops = unionOfTops(here.tops, there); + if (FFlag::LuauTransitiveSubtyping && get(tops) && get(here.errors)) + tops = builtinTypes->anyType; clearNormal(here); here.tops = tops; return true; } - else if (get(there) || !get(here.tops)) + else if (!FFlag::LuauTransitiveSubtyping && (get(there) || !get(here.tops))) return true; + else if (FFlag::LuauTransitiveSubtyping && (get(there) || get(here.tops))) + return true; + else if (FFlag::LuauTransitiveSubtyping && get(there) && get(here.tops)) + { + here.tops = builtinTypes->anyType; + return true; + } else if (const UnionType* utv = get(there)) { for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) @@ -1363,7 +1378,9 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor return false; return unionNormals(here, norm); } - else if (get(there) || get(there)) + else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) + return true; + else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there))) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return true; @@ -1441,7 +1458,7 @@ bool Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, int ignor if (!unionNormals(here, *tn)) return false; } - else if (get(there)) + else if (!FFlag::LuauNormalizeBlockedTypes && get(there)) LUAU_ASSERT(!"Internal error: Trying to normalize a BlockedType"); else LUAU_ASSERT(!"Unreachable"); @@ -2527,7 +2544,7 @@ bool Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there) return false; return true; } - else if (get(there) || get(there)) + else if (get(there) || get(there) || (FFlag::LuauNormalizeBlockedTypes && get(there))) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -2802,6 +2819,32 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) } bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +{ + if (!FFlag::LuauTransitiveSubtyping) + return isConsistentSubtype(subTy, superTy, scope, builtinTypes, ice); + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + + u.tryUnify(subTy, superTy); + return !u.failure; +} + +bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +{ + if (!FFlag::LuauTransitiveSubtyping) + return isConsistentSubtype(subPack, superPack, scope, builtinTypes, ice); + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + + u.tryUnify(subPack, superPack); + return !u.failure; +} + +bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; @@ -2813,7 +2856,7 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isConsistentSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { UnifierSharedState sharedState{&ice}; TypeArena arena; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 9da43ed..0b8f462 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -27,7 +27,6 @@ struct Quantifier final : TypeOnceVisitor explicit Quantifier(TypeLevel level) : level(level) { - LUAU_ASSERT(!FFlag::DebugLuauDeferredConstraintResolution); } /// @return true if outer encloses inner diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index cac7212..f54ebe2 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -149,6 +149,28 @@ std::optional Scope::linearSearchForBinding(const std::string& name, bo return std::nullopt; } +// Updates the `this` scope with the refinements from the `childScope` excluding ones that doesn't exist in `this`. +void Scope::inheritRefinements(const ScopePtr& childScope) +{ + if (FFlag::DebugLuauDeferredConstraintResolution) + { + for (const auto& [k, a] : childScope->dcrRefinements) + { + if (lookup(NotNull{k})) + dcrRefinements[k] = a; + } + } + else + { + for (const auto& [k, a] : childScope->refinements) + { + Symbol symbol = getBaseSymbol(k); + if (lookup(symbol)) + refinements[k] = a; + } + } +} + bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 4bc1223..42fa40a 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -25,6 +25,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAGVARIABLE(LuauMatchReturnsOptionalString, false); namespace Luau @@ -431,11 +432,11 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) } BlockedType::BlockedType() - : index(++nextIndex) + : index(FFlag::LuauNormalizeBlockedTypes ? Unifiable::freshIndex() : ++DEPRECATED_nextIndex) { } -int BlockedType::nextIndex = 0; +int BlockedType::DEPRECATED_nextIndex = 0; PendingExpansionType::PendingExpansionType( std::optional prefix, AstName name, std::vector typeArguments, std::vector packArguments) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 87d5686..abc6528 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -43,6 +43,8 @@ LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) +LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) +LUAU_FASTFLAGVARIABLE(LuauReducingAndOr, false) namespace Luau { @@ -344,42 +346,54 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo return std::move(currentModule); } -void TypeChecker::check(const ScopePtr& scope, const AstStat& program) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) { + if (finishTime && TimeTrace::getClock() > *finishTime) + throw TimeLimitError(iceHandler->moduleName); + if (auto block = program.as()) - check(scope, *block); + return check(scope, *block); else if (auto if_ = program.as()) - check(scope, *if_); + return check(scope, *if_); else if (auto while_ = program.as()) - check(scope, *while_); + return check(scope, *while_); else if (auto repeat = program.as()) - check(scope, *repeat); - else if (program.is()) + return check(scope, *repeat); + else if (program.is() || program.is()) { - } // Nothing to do - else if (program.is()) - { - } // Nothing to do + // Nothing to do + return ControlFlow::None; + } else if (auto return_ = program.as()) - check(scope, *return_); + return check(scope, *return_); else if (auto expr = program.as()) + { checkExprPack(scope, *expr->expr); + + if (FFlag::LuauTinyControlFlowAnalysis) + { + if (auto call = expr->expr->as(); call && doesCallError(call)) + return ControlFlow::Throws; + } + + return ControlFlow::None; + } else if (auto local = program.as()) - check(scope, *local); + return check(scope, *local); else if (auto for_ = program.as()) - check(scope, *for_); + return check(scope, *for_); else if (auto forIn = program.as()) - check(scope, *forIn); + return check(scope, *forIn); else if (auto assign = program.as()) - check(scope, *assign); + return check(scope, *assign); else if (auto assign = program.as()) - check(scope, *assign); + return check(scope, *assign); else if (program.is()) ice("Should not be calling two-argument check() on a function statement", program.location); else if (program.is()) ice("Should not be calling two-argument check() on a function statement", program.location); else if (auto typealias = program.as()) - check(scope, *typealias); + return check(scope, *typealias); else if (auto global = program.as()) { TypeId globalType = resolveType(scope, *global->type); @@ -387,11 +401,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) currentModule->declaredGlobals[globalName] = globalType; currentModule->getModuleScope()->bindings[global->name] = Binding{globalType, global->location}; + + return ControlFlow::None; } else if (auto global = program.as()) - check(scope, *global); + return check(scope, *global); else if (auto global = program.as()) - check(scope, *global); + return check(scope, *global); else if (auto errorStatement = program.as()) { const size_t oldSize = currentModule->errors.size(); @@ -405,37 +421,40 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) // HACK: We want to run typechecking on the contents of the AstStatError, but // we don't think the type errors will be useful most of the time. currentModule->errors.resize(oldSize); + + return ControlFlow::None; } else ice("Unknown AstStat"); - - if (finishTime && TimeTrace::getClock() > *finishTime) - throw TimeLimitError(iceHandler->moduleName); } // This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. -void TypeChecker::check(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatBlock& block) { ScopePtr child = childScope(scope, block.location); - checkBlock(child, block); + + ControlFlow flow = checkBlock(child, block); + scope->inheritRefinements(child); + + return flow; } -void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(block.location); - return; + return ControlFlow::None; } try { - checkBlockWithoutRecursionCheck(scope, block); + return checkBlockWithoutRecursionCheck(scope, block); } catch (const RecursionLimitException&) { reportErrorCodeTooComplex(block.location); - return; + return ControlFlow::None; } } @@ -488,7 +507,7 @@ struct InplaceDemoter : TypeOnceVisitor } }; -void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) +ControlFlow TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& block) { int subLevel = 0; @@ -528,6 +547,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } }; + std::optional firstFlow; while (protoIter != sorted.end()) { // protoIter walks forward @@ -570,7 +590,9 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A // We do check the current element, so advance checkIter beyond it. ++checkIter; - check(scope, **protoIter); + ControlFlow flow = check(scope, **protoIter); + if (flow != ControlFlow::None && !firstFlow) + firstFlow = flow; } else if (auto fun = (*protoIter)->as()) { @@ -631,7 +653,11 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A scope->bindings[fun->name] = {funTy, fun->name->location}; } else - check(scope, **protoIter); + { + ControlFlow flow = check(scope, **protoIter); + if (flow != ControlFlow::None && !firstFlow) + firstFlow = flow; + } ++protoIter; } @@ -643,6 +669,8 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } checkBlockTypeAliases(scope, sorted); + + return firstFlow.value_or(ControlFlow::None); } LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted) @@ -717,19 +745,45 @@ static std::optional tryGetTypeGuardPredicate(const AstExprBinary& ex return predicate; } -void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) { WithPredicate result = checkExpr(scope, *statement.condition); - ScopePtr ifScope = childScope(scope, statement.thenbody->location); - resolve(result.predicates, ifScope, true); - check(ifScope, *statement.thenbody); + ScopePtr thenScope = childScope(scope, statement.thenbody->location); + resolve(result.predicates, thenScope, true); - if (statement.elsebody) + if (FFlag::LuauTinyControlFlowAnalysis) { - ScopePtr elseScope = childScope(scope, statement.elsebody->location); + ScopePtr elseScope = childScope(scope, statement.elsebody ? statement.elsebody->location : statement.location); resolve(result.predicates, elseScope, false); - check(elseScope, *statement.elsebody); + + ControlFlow thencf = check(thenScope, *statement.thenbody); + ControlFlow elsecf = ControlFlow::None; + if (statement.elsebody) + elsecf = check(elseScope, *statement.elsebody); + + if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && elsecf == ControlFlow::None) + scope->inheritRefinements(elseScope); + else if (thencf == ControlFlow::None && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + scope->inheritRefinements(thenScope); + + if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) + return ControlFlow::Returns; + else + return ControlFlow::None; + } + else + { + check(thenScope, *statement.thenbody); + + if (statement.elsebody) + { + ScopePtr elseScope = childScope(scope, statement.elsebody->location); + resolve(result.predicates, elseScope, false); + check(elseScope, *statement.elsebody); + } + + return ControlFlow::None; } } @@ -750,22 +804,26 @@ ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Scope return canUnify_(subTy, superTy, scope, location); } -void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) { WithPredicate result = checkExpr(scope, *statement.condition); ScopePtr whileScope = childScope(scope, statement.body->location); resolve(result.predicates, whileScope, true); check(whileScope, *statement.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) { ScopePtr repScope = childScope(scope, statement.location); checkBlock(repScope, *statement.body); checkExpr(repScope, *statement.condition); + + return ControlFlow::None; } struct Demoter : Substitution @@ -822,7 +880,7 @@ struct Demoter : Substitution } }; -void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; expectedTypes.reserve(return_.list.size); @@ -858,10 +916,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); - return; + return FFlag::LuauTinyControlFlowAnalysis ? ControlFlow::Returns : ControlFlow::None; } unify(retPack, scope->returnType, scope, return_.location, CountMismatch::Context::Return); + + return FFlag::LuauTinyControlFlowAnalysis ? ControlFlow::Returns : ControlFlow::None; } template @@ -893,7 +953,7 @@ ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const Scope return tryUnify_(subTy, superTy, scope, location); } -void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { std::vector> expectedTypes; expectedTypes.reserve(assign.vars.size); @@ -993,9 +1053,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } } } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assign) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assign) { AstExprBinary expr(assign.location, assign.op, assign.var, assign.value); @@ -1005,9 +1067,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi TypeId result = checkBinaryOperation(scope, expr, left, right); unify(result, left, scope, assign.location); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) { // Important subtlety: A local variable is not in scope while its initializer is being evaluated. // For instance, you cannot do this: @@ -1144,9 +1208,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) for (const auto& [local, binding] : varBindings) scope->bindings[local] = binding; + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) { ScopePtr loopScope = childScope(scope, expr.location); @@ -1169,9 +1235,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) unify(checkExpr(loopScope, *expr.step).type, loopVarType, scope, expr.step->location); check(loopScope, *expr.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { ScopePtr loopScope = childScope(scope, forin.location); @@ -1360,9 +1428,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) unify(retPack, varPack, scope, forin.location); check(loopScope, *forin.body); + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function) +ControlFlow TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function) { if (auto exprName = function.name->as()) { @@ -1387,8 +1457,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco globalBindings[name] = oldBinding; else globalBindings[name] = {quantify(funScope, ty, exprName->location), exprName->location}; - - return; } else if (auto name = function.name->as()) { @@ -1397,7 +1465,6 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; - return; } else if (auto name = function.name->as()) { @@ -1444,9 +1511,11 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) +ControlFlow TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) { Name name = function.name->name.value; @@ -1455,15 +1524,17 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias) { Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. if (name == kParseNameError) - return; + return ControlFlow::None; std::optional binding; if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) @@ -1476,7 +1547,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be // interesting. if (duplicateTypeAliases.find({typealias.exported, name})) - return; + return ControlFlow::None; // By now this alias must have been `prototype()`d first. if (!binding) @@ -1557,6 +1628,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (unify(ty, bindingType, aliasScope, typealias.location)) bindingType = ty; + + return ControlFlow::None; } void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel) @@ -1648,13 +1721,13 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { Name className(declaredClass.name.value); // Don't bother checking if the class definition was incorrect if (incorrectClassDefinitions.find(&declaredClass)) - return; + return ControlFlow::None; std::optional binding; if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) @@ -1721,9 +1794,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } } } + + return ControlFlow::None; } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& global) +ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& global) { ScopePtr funScope = childFunctionScope(scope, global.location); @@ -1754,6 +1829,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->declaredGlobals[fnName] = fnType; currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; + + return ControlFlow::None; } WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) @@ -2785,6 +2862,16 @@ TypeId TypeChecker::checkRelationalOperation( if (notNever) { LUAU_ASSERT(oty); + + if (FFlag::LuauReducingAndOr) + { + // Perform a limited form of type reduction for booleans + if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) + return booleanType; + if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) + return booleanType; + } + return unionOfTypes(*oty, rhsType, scope, expr.location, false); } else @@ -2808,6 +2895,16 @@ TypeId TypeChecker::checkRelationalOperation( if (notNever) { LUAU_ASSERT(oty); + + if (FFlag::LuauReducingAndOr) + { + // Perform a limited form of type reduction for booleans + if (isPrim(*oty, PrimitiveType::Boolean) && get(get(follow(rhsType)))) + return booleanType; + if (isPrim(rhsType, PrimitiveType::Boolean) && get(get(follow(*oty)))) + return booleanType; + } + return unionOfTypes(*oty, rhsType, scope, expr.location); } else diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index 9db8f7f..dcb2d36 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -8,6 +8,11 @@ namespace Unifiable static int nextIndex = 0; +int freshIndex() +{ + return ++nextIndex; +} + Free::Free(TypeLevel level) : index(++nextIndex) , level(level) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index b53401d..9f30d11 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -20,9 +20,11 @@ LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) +LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauTinyUnifyNormalsFix, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauNegatedFunctionTypes) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAG(LuauNegatedTableTypes) @@ -475,16 +477,27 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->anyType); - if (log.get(superTy)) + if (!FFlag::LuauTransitiveSubtyping && log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->errorType); - if (log.get(superTy)) + if (!FFlag::LuauTransitiveSubtyping && log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->unknownType); if (log.get(subTy)) + { + if (FFlag::LuauTransitiveSubtyping && normalize) + { + // TODO: there are probably cheaper ways to check if any <: T. + const NormalizedType* superNorm = normalizer->normalize(superTy); + if (!log.get(superNorm->tops)) + failure = true; + } + else + failure = true; return tryUnifyWithAny(superTy, builtinTypes->anyType); + } - if (log.get(subTy)) + if (!FFlag::LuauTransitiveSubtyping && log.get(subTy)) return tryUnifyWithAny(superTy, builtinTypes->errorType); if (log.get(subTy)) @@ -539,6 +552,35 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } + else if (FFlag::LuauTransitiveSubtyping && log.get(subTy)) + { + tryUnifyWithAny(superTy, builtinTypes->unknownType); + failure = true; + } + else if (FFlag::LuauTransitiveSubtyping && log.get(subTy) && log.get(superTy)) + { + // error <: error + } + else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + { + tryUnifyWithAny(subTy, builtinTypes->errorType); + failure = true; + } + else if (FFlag::LuauTransitiveSubtyping && log.get(subTy)) + { + tryUnifyWithAny(superTy, builtinTypes->errorType); + failure = true; + } + else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + { + // At this point, all the supertypes of `error` have been handled, + // and if `error unknownType); + } + else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + { + tryUnifyWithAny(subTy, builtinTypes->unknownType); + } else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyPrimitives(subTy, superTy); @@ -611,6 +653,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ { // A | B <: T if and only if A <: T and B <: T bool failed = false; + bool errorsSuppressed = true; std::optional unificationTooComplex; std::optional firstFailedOption; @@ -626,13 +669,17 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; - else if (!innerState.errors.empty()) + else if (FFlag::LuauTransitiveSubtyping ? innerState.failure : !innerState.errors.empty()) { + // If errors were suppressed, we store the log up, so we can commit it if no other option succeeds. + if (FFlag::LuauTransitiveSubtyping && innerState.errors.empty()) + logs.push_back(std::move(innerState.log)); // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' - if (!firstFailedOption && !isNil(type)) + else if (!firstFailedOption && !isNil(type)) firstFailedOption = {innerState.errors.front()}; failed = true; + errorsSuppressed &= innerState.errors.empty(); } } @@ -684,12 +731,13 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ { if (firstFailedOption) reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption, mismatchContext()}); - else + else if (!FFlag::LuauTransitiveSubtyping || !errorsSuppressed) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); + failure = true; } } -struct BlockedTypeFinder : TypeOnceVisitor +struct DEPRECATED_BlockedTypeFinder : TypeOnceVisitor { std::unordered_set blockedTypes; @@ -700,9 +748,10 @@ struct BlockedTypeFinder : TypeOnceVisitor } }; -bool Unifier::blockOnBlockedTypes(TypeId subTy, TypeId superTy) +bool Unifier::DEPRECATED_blockOnBlockedTypes(TypeId subTy, TypeId superTy) { - BlockedTypeFinder blockedTypeFinder; + LUAU_ASSERT(!FFlag::LuauNormalizeBlockedTypes); + DEPRECATED_BlockedTypeFinder blockedTypeFinder; blockedTypeFinder.traverse(subTy); blockedTypeFinder.traverse(superTy); if (!blockedTypeFinder.blockedTypes.empty()) @@ -718,6 +767,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { // T <: A | B if T <: A or T <: B bool found = false; + bool errorsSuppressed = false; std::optional unificationTooComplex; size_t failedOptionCount = 0; @@ -754,6 +804,21 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } + if (FFlag::LuauTransitiveSubtyping && !foundHeuristic) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (subTy == type) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + if (!foundHeuristic && cacheEnabled) { auto& cache = sharedState.cachedUnify; @@ -779,7 +844,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp innerState.normalize = false; innerState.tryUnify_(subTy, type, isFunctionCall); - if (innerState.errors.empty()) + if (FFlag::LuauTransitiveSubtyping ? !innerState.failure : innerState.errors.empty()) { found = true; if (FFlag::DebugLuauDeferredConstraintResolution) @@ -790,6 +855,10 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp break; } } + else if (FFlag::LuauTransitiveSubtyping && innerState.errors.empty()) + { + errorsSuppressed = true; + } else if (auto e = hasUnificationTooComplex(innerState.errors)) { unificationTooComplex = e; @@ -810,11 +879,32 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { reportError(*unificationTooComplex); } + else if (FFlag::LuauTransitiveSubtyping && !found && normalize) + { + // It is possible that T <: A | B even though T normalize(subTy); + const NormalizedType* superNorm = normalizer->normalize(superTy); + Unifier innerState = makeChildUnifier(); + if (!subNorm || !superNorm) + return reportError(location, UnificationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + if (!innerState.failure) + log.concat(std::move(innerState.log)); + else if (errorsSuppressed || innerState.errors.empty()) + failure = true; + else + reportError(std::move(innerState.errors.front())); + } else if (!found && normalize) { // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. - if (blockOnBlockedTypes(subTy, superTy)) + if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) return; // It is possible that T <: A | B even though T unificationTooComplex; size_t startIndex = 0; @@ -919,7 +1013,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* { // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. - if (blockOnBlockedTypes(subTy, superTy)) + if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) return; // Sometimes a negation type is inside one of the types, e.g. { p: number } & { p: ~number }. @@ -951,13 +1045,18 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* innerState.normalize = false; innerState.tryUnify_(type, superTy, isFunctionCall); + // TODO: This sets errorSuppressed to true if any of the parts is error-suppressing, + // in paricular any & T is error-suppressing. Really, errorSuppressed should be true if + // all of the parts are error-suppressing, but that fails to typecheck lua-apps. if (innerState.errors.empty()) { found = true; - if (FFlag::DebugLuauDeferredConstraintResolution) + errorsSuppressed = innerState.failure; + if (FFlag::DebugLuauDeferredConstraintResolution || (FFlag::LuauTransitiveSubtyping && innerState.failure)) logs.push_back(std::move(innerState.log)); else { + errorsSuppressed = false; log.concat(std::move(innerState.log)); break; } @@ -970,6 +1069,8 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* if (FFlag::DebugLuauDeferredConstraintResolution) log.concat(combineLogsIntoIntersection(std::move(logs))); + else if (FFlag::LuauTransitiveSubtyping && errorsSuppressed) + log.concat(std::move(logs.front())); if (unificationTooComplex) reportError(*unificationTooComplex); @@ -977,7 +1078,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* { // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. - if (blockOnBlockedTypes(subTy, superTy)) + if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) return; // It is possible that A & B <: T even though A error) { - if (get(superNorm.tops) || get(superNorm.tops) || get(subNorm.tops)) + if (!FFlag::LuauTransitiveSubtyping && get(superNorm.tops)) return; - else if (get(subNorm.tops)) + else if (get(superNorm.tops)) + return; + else if (get(subNorm.tops)) + { + failure = true; + return; + } + else if (!FFlag::LuauTransitiveSubtyping && get(subNorm.tops)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.errors)) if (!get(superNorm.errors)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + { + failure = true; + if (!FFlag::LuauTransitiveSubtyping) + reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); + return; + } + + if (FFlag::LuauTransitiveSubtyping && get(superNorm.tops)) + return; + + if (FFlag::LuauTransitiveSubtyping && get(subNorm.tops)) + return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.booleans)) { @@ -1911,6 +2032,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { @@ -1926,6 +2048,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (subTable->state == TableState::Unsealed && isOptional(prop.type)) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` @@ -1988,6 +2111,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (superTable->state == TableState::Unsealed) { @@ -2059,6 +2183,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (superTable->indexer) { @@ -2234,6 +2359,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()}); log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else if (TableType* subTable = log.getMutable(subTy)) { @@ -2274,6 +2400,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { log.concat(std::move(innerState.log)); log.bindTable(subTy, superTy); + failure |= innerState.failure; } } else @@ -2367,6 +2494,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (innerState.errors.empty()) { log.concat(std::move(innerState.log)); + failure |= innerState.failure; } else { @@ -2398,7 +2526,7 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) // We cannot normalize a type that contains blocked types. We have to // stop for now if we find any. - if (blockOnBlockedTypes(subTy, superTy)) + if (!FFlag::LuauNormalizeBlockedTypes && DEPRECATED_blockOnBlockedTypes(subTy, superTy)) return; const NormalizedType* subNorm = normalizer->normalize(subTy); @@ -2726,6 +2854,7 @@ Unifier Unifier::makeChildUnifier() void Unifier::reportError(Location location, TypeErrorData data) { errors.emplace_back(std::move(location), std::move(data)); + failure = true; } // A utility function that appends the given error to the unifier's error log. @@ -2736,6 +2865,7 @@ void Unifier::reportError(Location location, TypeErrorData data) void Unifier::reportError(TypeError err) { errors.push_back(std::move(err)); + failure = true; } diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 8b7eb73..0b9d8c4 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -123,13 +123,13 @@ private: // return [explist] AstStat* parseReturn(); - // type Name `=' typeannotation + // type Name `=' Type AstStat* parseTypeAlias(const Location& start, bool exported); AstDeclaredClassProp parseDeclaredClassMethod(); - // `declare global' Name: typeannotation | - // `declare function' Name`(' [parlist] `)' [`:` TypeAnnotation] + // `declare global' Name: Type | + // `declare function' Name`(' [parlist] `)' [`:` Type] AstStat* parseDeclaration(const Location& start); // varlist `=' explist @@ -140,7 +140,7 @@ private: std::pair> prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args); - // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` TypeAnnotation] + // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` Type] // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); @@ -148,21 +148,21 @@ private: // explist ::= {exp `,'} exp void parseExprList(TempVector& result); - // binding ::= Name [`:` TypeAnnotation] + // binding ::= Name [`:` Type] Binding parseBinding(); // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); - AstType* parseOptionalTypeAnnotation(); + AstType* parseOptionalType(); - // TypeList ::= TypeAnnotation [`,' TypeList] - // ReturnType ::= TypeAnnotation | `(' TypeList `)' - // TableProp ::= Name `:' TypeAnnotation - // TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation + // TypeList ::= Type [`,' TypeList] + // ReturnType ::= Type | `(' TypeList `)' + // TableProp ::= Name `:' Type + // TableIndexer ::= `[' Type `]' `:' Type // PropList ::= (TableProp | TableIndexer) [`,' PropList] - // TypeAnnotation + // Type // ::= Name // | `nil` // | `{' [PropList] `}' @@ -171,24 +171,25 @@ private: // Returns the variadic annotation, if it exists. AstTypePack* parseTypeList(TempVector& result, TempVector>& resultNames); - std::optional parseOptionalReturnTypeAnnotation(); - std::pair parseReturnTypeAnnotation(); + std::optional parseOptionalReturnType(); + std::pair parseReturnType(); - AstTableIndexer* parseTableIndexerAnnotation(); + AstTableIndexer* parseTableIndexer(); - AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); + AstTypeOrPack parseFunctionType(bool allowPack); + AstType* parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation); - AstType* parseTableTypeAnnotation(); - AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstType* parseTableType(); + AstTypeOrPack parseSimpleType(bool allowPack); - AstTypeOrPack parseTypeOrPackAnnotation(); - AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); - AstType* parseTypeAnnotation(); + AstTypeOrPack parseTypeOrPack(); + AstType* parseType(); - AstTypePack* parseTypePackAnnotation(); - AstTypePack* parseVariadicArgumentAnnotation(); + AstTypePack* parseTypePack(); + AstTypePack* parseVariadicArgumentTypePack(); + + AstType* parseTypeSuffix(AstType* type, const Location& begin); static std::optional parseUnaryOp(const Lexeme& l); static std::optional parseBinaryOp(const Lexeme& l); @@ -215,7 +216,7 @@ private: // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } AstExpr* parsePrimaryExpr(bool asStatement); - // asexp -> simpleexp [`::' typeAnnotation] + // asexp -> simpleexp [`::' Type] AstExpr* parseAssertionExpr(); // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp @@ -244,7 +245,7 @@ private: // `<' namelist `>' std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); - // `<' typeAnnotation[, ...] `>' + // `<' Type[, ...] `>' AstArray parseTypeParams(); std::optional> parseCharArray(); @@ -302,13 +303,12 @@ private: AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); - AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) - LUAU_PRINTF_ATTR(4, 5); + AstTypeError* reportTypeError(const Location& location, const AstArray& types, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); // `parseErrorLocation` is associated with the parser error // `astErrorLocation` is associated with the AstTypeError created // It can be useful to have different error locations so that the parse error can include the next lexeme, while the AstTypeError can precisely // define the location (possibly of zero size) where a type annotation is expected. - AstTypeError* reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) + AstTypeError* reportMissingTypeError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); AstExpr* reportFunctionArgsError(AstExpr* func, bool self); @@ -401,8 +401,8 @@ private: std::vector scratchBinding; std::vector scratchLocal; std::vector scratchTableTypeProps; - std::vector scratchAnnotation; - std::vector scratchTypeOrPackAnnotation; + std::vector scratchType; + std::vector scratchTypeOrPack; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4c34771..40fa754 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -130,7 +130,7 @@ void TempVector::push_back(const T& item) size_++; } -static bool shouldParseTypePackAnnotation(Lexer& lexer) +static bool shouldParseTypePack(Lexer& lexer) { if (lexer.current().type == Lexeme::Dot3) return true; @@ -330,11 +330,12 @@ AstStat* Parser::parseStat() if (options.allowTypeAnnotations) { if (ident == "type") - return parseTypeAlias(expr->location, /* exported =*/false); + return parseTypeAlias(expr->location, /* exported= */ false); + if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") { nextLexeme(); - return parseTypeAlias(expr->location, /* exported =*/true); + return parseTypeAlias(expr->location, /* exported= */ true); } } @@ -742,7 +743,7 @@ AstStat* Parser::parseReturn() return allocator.alloc(Location(start, end), copy(list)); } -// type Name [`<' varlist `>'] `=' typeannotation +// type Name [`<' varlist `>'] `=' Type AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // note: `type` token is already parsed for us, so we just need to parse the rest @@ -757,7 +758,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) expectAndConsume('=', "type alias"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); return allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); } @@ -789,16 +790,16 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() expectMatchAndConsume(')', matchParen); - AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0), nullptr}); + AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0), nullptr}); Location end = lexer.current().location; - TempVector vars(scratchAnnotation); + TempVector vars(scratchType); TempVector> varNames(scratchOptArgName); if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) { return AstDeclaredClassProp{ - fnName.name, reportTypeAnnotationError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; + fnName.name, reportTypeError(Location(start, end), {}, "'self' must be present as the unannotated first parameter"), true}; } // Skip the first index. @@ -809,7 +810,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() if (args[i].annotation) vars.push_back(args[i].annotation); else - vars.push_back(reportTypeAnnotationError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); + vars.push_back(reportTypeError(Location(start, end), {}, "All declaration parameters aside from 'self' must be annotated")); } if (vararg && !varargAnnotation) @@ -846,10 +847,10 @@ AstStat* Parser::parseDeclaration(const Location& start) expectMatchAndConsume(')', matchParen); - AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0)}); + AstTypeList retTypes = parseOptionalReturnType().value_or(AstTypeList{copy(nullptr, 0)}); Location end = lexer.current().location; - TempVector vars(scratchAnnotation); + TempVector vars(scratchType); TempVector varNames(scratchArgName); for (size_t i = 0; i < args.size(); ++i) @@ -898,7 +899,7 @@ AstStat* Parser::parseDeclaration(const Location& start) expectMatchAndConsume(']', begin); expectAndConsume(':', "property type annotation"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // TODO: since AstName conains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); @@ -912,7 +913,7 @@ AstStat* Parser::parseDeclaration(const Location& start) { Name propName = parseName("property name"); expectAndConsume(':', "property type annotation"); - AstType* propType = parseTypeAnnotation(); + AstType* propType = parseType(); props.push_back(AstDeclaredClassProp{propName.name, propType, false}); } } @@ -926,7 +927,7 @@ AstStat* Parser::parseDeclaration(const Location& start) { expectAndConsume(':', "global variable declaration"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); return allocator.alloc(Location(start, type->location), globalName->name, type); } else @@ -1027,7 +1028,7 @@ std::pair Parser::parseFunctionBody( expectMatchAndConsume(')', matchParen, true); - std::optional typelist = parseOptionalReturnTypeAnnotation(); + std::optional typelist = parseOptionalReturnType(); AstLocal* funLocal = nullptr; @@ -1085,7 +1086,7 @@ Parser::Binding Parser::parseBinding() if (!name) name = Name(nameError, lexer.current().location); - AstType* annotation = parseOptionalTypeAnnotation(); + AstType* annotation = parseOptionalType(); return Binding(*name, annotation); } @@ -1104,7 +1105,7 @@ std::tuple Parser::parseBindingList(TempVector Parser::parseBindingList(TempVector& result, TempVector>& resultNames) { while (true) { - if (shouldParseTypePackAnnotation(lexer)) - return parseTypePackAnnotation(); + if (shouldParseTypePack(lexer)) + return parseTypePack(); if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':') { @@ -1156,7 +1157,7 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector& result, TempVector Parser::parseOptionalReturnTypeAnnotation() +std::optional Parser::parseOptionalReturnType() { if (options.allowTypeAnnotations && (lexer.current().type == ':' || lexer.current().type == Lexeme::SkinnyArrow)) { @@ -1183,7 +1184,7 @@ std::optional Parser::parseOptionalReturnTypeAnnotation() unsigned int oldRecursionCount = recursionCounter; - auto [_location, result] = parseReturnTypeAnnotation(); + auto [_location, result] = parseReturnType(); // At this point, if we find a , character, it indicates that there are multiple return types // in this type annotation, but the list wasn't wrapped in parentheses. @@ -1202,27 +1203,27 @@ std::optional Parser::parseOptionalReturnTypeAnnotation() return std::nullopt; } -// ReturnType ::= TypeAnnotation | `(' TypeList `)' -std::pair Parser::parseReturnTypeAnnotation() +// ReturnType ::= Type | `(' TypeList `)' +std::pair Parser::parseReturnType() { incrementRecursionCounter("type annotation"); - TempVector result(scratchAnnotation); - TempVector> resultNames(scratchOptArgName); - AstTypePack* varargAnnotation = nullptr; - Lexeme begin = lexer.current(); if (lexer.current().type != '(') { - if (shouldParseTypePackAnnotation(lexer)) - varargAnnotation = parseTypePackAnnotation(); + if (shouldParseTypePack(lexer)) + { + AstTypePack* typePack = parseTypePack(); + + return {typePack->location, AstTypeList{{}, typePack}}; + } else - result.push_back(parseTypeAnnotation()); + { + AstType* type = parseType(); - Location resultLocation = result.size() == 0 ? varargAnnotation->location : result[0]->location; - - return {resultLocation, AstTypeList{copy(result), varargAnnotation}}; + return {type->location, AstTypeList{copy(&type, 1), nullptr}}; + } } nextLexeme(); @@ -1231,6 +1232,10 @@ std::pair Parser::parseReturnTypeAnnotation() matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; + TempVector result(scratchType); + TempVector> resultNames(scratchOptArgName); + AstTypePack* varargAnnotation = nullptr; + // possibly () -> ReturnType if (lexer.current().type != ')') varargAnnotation = parseTypeList(result, resultNames); @@ -1246,9 +1251,9 @@ std::pair Parser::parseReturnTypeAnnotation() // If it turns out that it's just '(A)', it's possible that there are unions/intersections to follow, so fold over it. if (result.size() == 1) { - AstType* returnType = parseTypeAnnotation(result, innerBegin); + AstType* returnType = parseTypeSuffix(result[0], innerBegin); - // If parseTypeAnnotation parses nothing, then returnType->location.end only points at the last non-type-pack + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack // type to successfully parse. We need the span of the whole annotation. Position endPos = result.size() == 1 ? location.end : returnType->location.end; @@ -1258,39 +1263,33 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; - AstArray types = copy(result); - AstArray> names = copy(resultNames); + AstType* tail = parseFunctionTypeTail(begin, {}, {}, copy(result), copy(resultNames), varargAnnotation); - TempVector fallbackReturnTypes(scratchAnnotation); - fallbackReturnTypes.push_back(parseFunctionTypeAnnotationTail(begin, generics, genericPacks, types, names, varargAnnotation)); - - return {Location{location, fallbackReturnTypes[0]->location}, AstTypeList{copy(fallbackReturnTypes), varargAnnotation}}; + return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } -// TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation -AstTableIndexer* Parser::parseTableIndexerAnnotation() +// TableIndexer ::= `[' Type `]' `:' Type +AstTableIndexer* Parser::parseTableIndexer() { const Lexeme begin = lexer.current(); nextLexeme(); // [ - AstType* index = parseTypeAnnotation(); + AstType* index = parseType(); expectMatchAndConsume(']', begin); expectAndConsume(':', "table field"); - AstType* result = parseTypeAnnotation(); + AstType* result = parseType(); return allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location)}); } -// TableProp ::= Name `:' TypeAnnotation +// TableProp ::= Name `:' Type // TablePropOrIndexer ::= TableProp | TableIndexer // PropList ::= TablePropOrIndexer {fieldsep TablePropOrIndexer} [fieldsep] -// TableTypeAnnotation ::= `{' PropList `}' -AstType* Parser::parseTableTypeAnnotation() +// TableType ::= `{' PropList `}' +AstType* Parser::parseTableType() { incrementRecursionCounter("type annotation"); @@ -1313,7 +1312,7 @@ AstType* Parser::parseTableTypeAnnotation() expectMatchAndConsume(']', begin); expectAndConsume(':', "table field"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // TODO: since AstName conains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); @@ -1329,19 +1328,19 @@ AstType* Parser::parseTableTypeAnnotation() { // maybe we don't need to parse the entire badIndexer... // however, we either have { or [ to lint, not the entire table type or the bad indexer. - AstTableIndexer* badIndexer = parseTableIndexerAnnotation(); + AstTableIndexer* badIndexer = parseTableIndexer(); // we lose all additional indexer expressions from the AST after error recovery here report(badIndexer->location, "Cannot have more than one table indexer"); } else { - indexer = parseTableIndexerAnnotation(); + indexer = parseTableIndexer(); } } else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) { - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); // array-like table type: {T} desugars into {[number]: T} AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber); @@ -1358,7 +1357,7 @@ AstType* Parser::parseTableTypeAnnotation() expectAndConsume(':', "table field"); - AstType* type = parseTypeAnnotation(); + AstType* type = parseType(); props.push_back({name->name, name->location, type}); } @@ -1382,9 +1381,9 @@ AstType* Parser::parseTableTypeAnnotation() return allocator.alloc(Location(start, end), copy(props), indexer); } -// ReturnType ::= TypeAnnotation | `(' TypeList `)' -// FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) +// ReturnType ::= Type | `(' TypeList `)' +// FunctionType ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType +AstTypeOrPack Parser::parseFunctionType(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1400,7 +1399,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; - TempVector params(scratchAnnotation); + TempVector params(scratchType); TempVector> names(scratchOptArgName); AstTypePack* varargAnnotation = nullptr; @@ -1432,12 +1431,11 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray> paramNames = copy(names); - return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; + return {parseFunctionTypeTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, - AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) - +AstType* Parser::parseFunctionTypeTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation) { incrementRecursionCounter("type annotation"); @@ -1458,21 +1456,22 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); } -// typeannotation ::= +// Type ::= // nil | // Name[`.' Name] [`<' namelist `>'] | // `{' [PropList] `}' | // `(' [TypeList] `)' `->` ReturnType -// `typeof` typeannotation -AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location& begin) +// `typeof` Type +AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) { - LUAU_ASSERT(!parts.empty()); + TempVector parts(scratchType); + parts.push_back(type); incrementRecursionCounter("type annotation"); @@ -1487,7 +1486,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + parts.push_back(parseSimpleType(/* allowPack= */ false).type); isUnion = true; } else if (c == '?') @@ -1500,7 +1499,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + parts.push_back(parseSimpleType(/* allowPack= */ false).type); isIntersection = true; } else if (c == Lexeme::Dot3) @@ -1513,11 +1512,11 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location } if (parts.size() == 1) - return parts[0]; + return type; if (isUnion && isIntersection) { - return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), + return reportTypeError(Location(begin, parts.back()->location), copy(parts), "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } @@ -1533,16 +1532,14 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } -AstTypeOrPack Parser::parseTypeOrPackAnnotation() +AstTypeOrPack Parser::parseTypeOrPack() { unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("type annotation"); Location begin = lexer.current().location; - TempVector parts(scratchAnnotation); - - auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); + auto [type, typePack] = parseSimpleType(/* allowPack= */ true); if (typePack) { @@ -1550,31 +1547,28 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation() return {{}, typePack}; } - parts.push_back(type); - recursionCounter = oldRecursionCount; - return {parseTypeAnnotation(parts, begin), {}}; + return {parseTypeSuffix(type, begin), {}}; } -AstType* Parser::parseTypeAnnotation() +AstType* Parser::parseType() { unsigned int oldRecursionCount = recursionCounter; incrementRecursionCounter("type annotation"); Location begin = lexer.current().location; - TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); + AstType* type = parseSimpleType(/* allowPack= */ false).type; recursionCounter = oldRecursionCount; - return parseTypeAnnotation(parts, begin); + return parseTypeSuffix(type, begin); } -// typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' +// Type ::= nil | Name[`.' Name] [ `<' Type [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) +AstTypeOrPack Parser::parseSimpleType(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1603,18 +1597,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(start, svalue)}; } else - return {reportTypeAnnotationError(start, {}, "String literal contains malformed escape sequence")}; + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { parseInterpString(); - return {reportTypeAnnotationError(start, {}, "Interpolated string literals cannot be used as types")}; + return {reportTypeError(start, {}, "Interpolated string literals cannot be used as types")}; } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, "Malformed string")}; + return {reportTypeError(start, {}, "Malformed string")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1663,17 +1657,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) } else if (lexer.current().type == '{') { - return {parseTableTypeAnnotation(), {}}; + return {parseTableType(), {}}; } else if (lexer.current().type == '(' || lexer.current().type == '<') { - return parseFunctionTypeAnnotation(allowPack); + return parseFunctionType(allowPack); } else if (lexer.current().type == Lexeme::ReservedFunction) { nextLexeme(); - return {reportTypeAnnotationError(start, {}, + return {reportTypeError(start, {}, "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> " "...any'"), {}}; @@ -1685,12 +1679,11 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) // The parse error includes the next lexeme to make it easier to display where the error is (e.g. in an IDE or a CLI error message). // Including the current lexeme also makes the parse error consistent with other parse errors returned by Luau. Location parseErrorLocation(lexer.previousLocation().end, start.end); - return { - reportMissingTypeAnnotationError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; + return {reportMissingTypeError(parseErrorLocation, astErrorlocation, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } -AstTypePack* Parser::parseVariadicArgumentAnnotation() +AstTypePack* Parser::parseVariadicArgumentTypePack() { // Generic: a... if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == Lexeme::Dot3) @@ -1705,19 +1698,19 @@ AstTypePack* Parser::parseVariadicArgumentAnnotation() // Variadic: T else { - AstType* variadicAnnotation = parseTypeAnnotation(); + AstType* variadicAnnotation = parseType(); return allocator.alloc(variadicAnnotation->location, variadicAnnotation); } } -AstTypePack* Parser::parseTypePackAnnotation() +AstTypePack* Parser::parseTypePack() { // Variadic: ...T if (lexer.current().type == Lexeme::Dot3) { Location start = lexer.current().location; nextLexeme(); - AstType* varargTy = parseTypeAnnotation(); + AstType* varargTy = parseType(); return allocator.alloc(Location(start, varargTy->location), varargTy); } // Generic: a... @@ -2054,7 +2047,7 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) return expr; } -// asexp -> simpleexp [`::' typeannotation] +// asexp -> simpleexp [`::' Type] AstExpr* Parser::parseAssertionExpr() { Location start = lexer.current().location; @@ -2063,7 +2056,7 @@ AstExpr* Parser::parseAssertionExpr() if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) { nextLexeme(); - AstType* annotation = parseTypeAnnotation(); + AstType* annotation = parseType(); return allocator.alloc(Location(start, annotation->location), expr, annotation); } else @@ -2455,15 +2448,15 @@ std::pair, AstArray> Parser::parseG Lexeme packBegin = lexer.current(); - if (shouldParseTypePackAnnotation(lexer)) + if (shouldParseTypePack(lexer)) { - AstTypePack* typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePack(); namePacks.push_back({name, nameLocation, typePack}); } else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (type) report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); @@ -2472,7 +2465,7 @@ std::pair, AstArray> Parser::parseG } else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (type) report(type->location, "Expected type pack after '=', got type"); @@ -2495,7 +2488,7 @@ std::pair, AstArray> Parser::parseG seenDefault = true; nextLexeme(); - AstType* defaultType = parseTypeAnnotation(); + AstType* defaultType = parseType(); names.push_back({name, nameLocation, defaultType}); } @@ -2532,7 +2525,7 @@ std::pair, AstArray> Parser::parseG AstArray Parser::parseTypeParams() { - TempVector parameters{scratchTypeOrPackAnnotation}; + TempVector parameters{scratchTypeOrPack}; if (lexer.current().type == '<') { @@ -2541,15 +2534,15 @@ AstArray Parser::parseTypeParams() while (true) { - if (shouldParseTypePackAnnotation(lexer)) + if (shouldParseTypePack(lexer)) { - AstTypePack* typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePack(); parameters.push_back({{}, typePack}); } else if (lexer.current().type == '(') { - auto [type, typePack] = parseTypeOrPackAnnotation(); + auto [type, typePack] = parseTypeOrPack(); if (typePack) parameters.push_back({{}, typePack}); @@ -2562,7 +2555,7 @@ AstArray Parser::parseTypeParams() } else { - parameters.push_back({parseTypeAnnotation(), {}}); + parameters.push_back({parseType(), {}}); } if (lexer.current().type == ',') @@ -3018,7 +3011,7 @@ AstExprError* Parser::reportExprError(const Location& location, const AstArray(location, expressions, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, const char* format, ...) +AstTypeError* Parser::reportTypeError(const Location& location, const AstArray& types, const char* format, ...) { va_list args; va_start(args, format); @@ -3028,7 +3021,7 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const return allocator.alloc(location, types, false, unsigned(parseErrors.size() - 1)); } -AstTypeError* Parser::reportMissingTypeAnnotationError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) +AstTypeError* Parser::reportMissingTypeError(const Location& parseErrorLocation, const Location& astErrorLocation, const char* format, ...) { va_list args; va_start(args, format); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index d6f1822..4fdb044 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -14,6 +14,7 @@ #endif LUAU_FASTFLAG(DebugLuauTimeTracing) +LUAU_FASTFLAG(LuauLintInTypecheck) enum class ReportFormat { @@ -80,7 +81,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat for (auto& error : cr.errors) reportError(frontend, format, error); - Luau::LintResult lr = frontend.lint(name); + Luau::LintResult lr = FFlag::LuauLintInTypecheck ? cr.lintResult : frontend.lint_DEPRECATED(name); std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); for (auto& error : lr.errors) @@ -263,6 +264,7 @@ int main(int argc, char** argv) Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; + frontendOptions.runLintChecks = FFlag::LuauLintInTypecheck; CliFileResolver fileResolver; CliConfigResolver configResolver(mode); diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 94d8f81..0179967 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -93,6 +93,13 @@ public: // Assigns label position to the current location void setLabel(Label& label); + // Extracts code offset (in bytes) from label + uint32_t getLabelOffset(const Label& label) + { + LUAU_ASSERT(label.location != ~0u); + return label.location * 4; + } + void logAppend(const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); uint32_t getCodeSize() const; diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index 597f2b2..17076ed 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -155,6 +155,13 @@ public: // Assigns label position to the current location void setLabel(Label& label); + // Extracts code offset (in bytes) from label + uint32_t getLabelOffset(const Label& label) + { + LUAU_ASSERT(label.location != ~0u); + return label.location; + } + // Constant allocation (uses rip-relative addressing) OperandX64 i64(int64_t value); OperandX64 f32(float value); diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index a6cab4a..e0537b6 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -21,7 +21,8 @@ struct CodeAllocator // Places data and code into the executable page area // To allow allocation while previously allocated code is already running, allocation has page granularity // It's important to group functions together so that page alignment won't result in a lot of wasted space - bool allocate(uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart); + bool allocate( + const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart); // Provided to callbacks void* context = nullptr; diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index 21fa755..5c2bc4d 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -42,6 +42,7 @@ struct CfgInfo std::vector successorsOffsets; std::vector in; + std::vector def; std::vector out; RegisterSet captured; diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index 916c6ee..e6202c7 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -1,8 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Common.h" #include "Luau/Bytecode.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" #include "Luau/IrData.h" #include @@ -19,6 +20,8 @@ struct AssemblyOptions; struct IrBuilder { + IrBuilder(); + void buildFunctionIr(Proto* proto); void rebuildBytecodeBasicBlocks(Proto* proto); @@ -38,7 +41,7 @@ struct IrBuilder IrOp constUint(unsigned value); IrOp constDouble(double value); IrOp constTag(uint8_t value); - IrOp constAny(IrConst constant); + IrOp constAny(IrConst constant, uint64_t asCommonKey); IrOp cond(IrCondition cond); @@ -67,6 +70,45 @@ struct IrBuilder uint32_t activeBlockIdx = ~0u; std::vector instIndexToBlock; // Block index at the bytecode instruction + + // Similar to BytecodeBuilder, duplicate constants are removed used the same method + struct ConstantKey + { + IrConstKind kind; + // Note: this stores value* from IrConst; when kind is Double, this stores the same bits as double does but in uint64_t. + uint64_t value; + + bool operator==(const ConstantKey& key) const + { + return kind == key.kind && value == key.value; + } + }; + + struct ConstantKeyHash + { + size_t operator()(const ConstantKey& key) const + { + // finalizer from MurmurHash64B + const uint32_t m = 0x5bd1e995; + + uint32_t h1 = uint32_t(key.value); + uint32_t h2 = uint32_t(key.value >> 32) ^ (int(key.kind) * m); + + h1 ^= h2 >> 18; + h1 *= m; + h2 ^= h1 >> 22; + h2 *= m; + h1 ^= h2 >> 17; + h1 *= m; + h2 ^= h1 >> 19; + h2 *= m; + + // ... truncated to 32-bit output (normally hash is equal to (uint64_t(h1) << 32) | h2, but we only really need the lower 32-bit half) + return size_t(h2); + } + }; + + DenseHashMap constantMap; }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 439abb9..67e7063 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -22,7 +22,7 @@ namespace CodeGen // In the command description, following abbreviations are used: // * Rn - VM stack register slot, n in 0..254 // * Kn - VM proto constant slot, n in 0..2^23-1 -// * UPn - VM function upvalue slot, n in 0..254 +// * UPn - VM function upvalue slot, n in 0..199 // * A, B, C, D, E are instruction arguments enum class IrCmd : uint8_t { @@ -64,6 +64,11 @@ enum class IrCmd : uint8_t // A: pointer (Table) GET_SLOT_NODE_ADDR, + // Get pointer (LuaNode) to table node element at the main position of the specified key hash + // A: pointer (Table) + // B: unsigned int + GET_HASH_NODE_ADDR, + // Store a tag into TValue // A: Rn // B: tag @@ -173,6 +178,13 @@ enum class IrCmd : uint8_t // E: block (if false) JUMP_CMP_ANY, + // Perform a conditional jump based on cached table node slot matching the actual table node slot for a key + // A: pointer (LuaNode) + // B: Kn + // C: block (if matches) + // D: block (if it doesn't) + JUMP_SLOT_MATCH, + // Get table length // A: pointer (Table) TABLE_LEN, @@ -189,7 +201,13 @@ enum class IrCmd : uint8_t // Try to convert a double number into a table index (int) or jump if it's not an integer // A: double // B: block - NUM_TO_INDEX, + TRY_NUM_TO_INDEX, + + // Try to get pointer to tag method TValue inside the table's metatable or jump if there is no such value or metatable + // A: table + // B: int + // C: block + TRY_CALL_FASTGETTM, // Convert integer into a double number // A: int @@ -315,6 +333,11 @@ enum class IrCmd : uint8_t // C: block CHECK_SLOT_MATCH, + // Guard against table node with a linked next node to ensure that our lookup hits the main position of the key + // A: pointer (LuaNode) + // B: block + CHECK_NODE_NO_NEXT, + // Special operations // Check interrupt handler @@ -361,14 +384,6 @@ enum class IrCmd : uint8_t // E: unsigned int (table index to start from) LOP_SETLIST, - // Load function from source register using name into target register and copying source register into target register + 1 - // A: unsigned int (bytecode instruction index) - // B: Rn (target) - // C: Rn (source) - // D: block (next) - // E: block (fallback) - LOP_NAMECALL, - // Call specified function // A: unsigned int (bytecode instruction index) // B: Rn (function, followed by arguments) @@ -576,6 +591,16 @@ struct IrOp , index(index) { } + + bool operator==(const IrOp& rhs) const + { + return kind == rhs.kind && index == rhs.index; + } + + bool operator!=(const IrOp& rhs) const + { + return !(*this == rhs); + } }; static_assert(sizeof(IrOp) == 4); diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index a6329ec..ae517e8 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -37,5 +37,9 @@ std::string toString(IrFunction& function, bool includeUseInfo); std::string dump(IrFunction& function); +std::string toDot(IrFunction& function, bool includeInst); + +std::string dumpDot(IrFunction& function, bool includeInst); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 3b14a8c..0fc1402 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -98,7 +98,7 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_CMP_ANY: - case IrCmd::LOP_NAMECALL: + case IrCmd::JUMP_SLOT_MATCH: case IrCmd::LOP_RETURN: case IrCmd::LOP_FORGLOOP: case IrCmd::LOP_FORGLOOP_FALLBACK: @@ -125,6 +125,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::ADD_INT: case IrCmd::SUB_INT: case IrCmd::ADD_NUM: @@ -140,7 +141,8 @@ inline bool hasResult(IrCmd cmd) case IrCmd::TABLE_LEN: case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: - case IrCmd::NUM_TO_INDEX: + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::INT_TO_NUM: case IrCmd::SUBSTITUTE: case IrCmd::INVOKE_FASTCALL: diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index e1950db..4d04a24 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -112,7 +112,7 @@ CodeAllocator::~CodeAllocator() } bool CodeAllocator::allocate( - uint8_t* data, size_t dataSize, uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart) + const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart) { // 'Round up' to preserve code alignment size_t alignedDataSize = (dataSize + (kCodeAlignment - 1)) & ~(kCodeAlignment - 1); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index c794972..ce490f9 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -1,7 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/CodeGen.h" -#include "Luau/AssemblyBuilderX64.h" #include "Luau/Common.h" #include "Luau/CodeAllocator.h" #include "Luau/CodeBlockUnwind.h" @@ -9,12 +8,17 @@ #include "Luau/IrBuilder.h" #include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeFinalX64.h" + #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" #include "Luau/UnwindBuilderWin.h" +#include "Luau/AssemblyBuilderX64.h" +#include "Luau/AssemblyBuilderA64.h" + #include "CustomExecUtils.h" #include "CodeGenX64.h" +#include "CodeGenA64.h" #include "EmitCommonX64.h" #include "EmitInstructionX64.h" #include "IrLoweringX64.h" @@ -39,32 +43,55 @@ namespace Luau namespace CodeGen { -constexpr uint32_t kFunctionAlignment = 32; - -static void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) -{ - if (build.logText) - build.logAppend("; exitContinueVm\n"); - helpers.exitContinueVm = build.setLabel(); - emitExit(build, /* continueInVm */ true); - - if (build.logText) - build.logAppend("; exitNoContinueVm\n"); - helpers.exitNoContinueVm = build.setLabel(); - emitExit(build, /* continueInVm */ false); - - if (build.logText) - build.logAppend("; continueCallInVm\n"); - helpers.continueCallInVm = build.setLabel(); - emitContinueCallInVm(build); -} - -static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +static NativeProto* createNativeProto(Proto* proto, const IrBuilder& ir) { NativeProto* result = new NativeProto(); result->proto = proto; + result->instTargets = new uintptr_t[proto->sizecode]; + for (int i = 0; i < proto->sizecode; i++) + { + auto [irLocation, asmLocation] = ir.function.bcMapping[i]; + + result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation; + } + + return result; +} + +[[maybe_unused]] static void lowerIr( + X64::AssemblyBuilderX64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +{ + constexpr uint32_t kFunctionAlignment = 32; + + optimizeMemoryOperandsX64(ir.function); + + build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); + + X64::IrLoweringX64 lowering(build, helpers, data, proto, ir.function); + + lowering.lower(options); +} + +[[maybe_unused]] static void lowerIr( + A64::AssemblyBuilderA64& build, IrBuilder& ir, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +{ + Label start = build.setLabel(); + + build.mov(A64::x0, 1); // finish function in VM + build.ret(); + + // TODO: This is only needed while we don't support all IR opcodes + // When we can't translate some parts of the function, we instead encode a dummy assembly sequence that hands off control to VM + // In the future we could return nullptr from assembleFunction and handle it because there may be other reasons for why we refuse to assemble. + for (int i = 0; i < proto->sizecode; i++) + ir.function.bcMapping[i].asmLocation = build.getLabelOffset(start); +} + +template +static NativeProto* assembleFunction(AssemblyBuilder& build, NativeState& data, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +{ if (options.includeAssembly || options.includeIr) { if (proto->debugname) @@ -93,43 +120,24 @@ static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState build.logAppend("\n"); } - build.align(kFunctionAlignment, X64::AlignmentDataX64::Ud2); - - Label start = build.setLabel(); - - IrBuilder builder; - builder.buildFunctionIr(proto); + IrBuilder ir; + ir.buildFunctionIr(proto); if (!FFlag::DebugCodegenNoOpt) { - constPropInBlockChains(builder); + constPropInBlockChains(ir); } // TODO: cfg info has to be computed earlier to use in optimizations // It's done here to appear in text output and to measure performance impact on code generation - computeCfgInfo(builder.function); + computeCfgInfo(ir.function); - optimizeMemoryOperandsX64(builder.function); - - X64::IrLoweringX64 lowering(build, helpers, data, proto, builder.function); - - lowering.lower(options); - - result->instTargets = new uintptr_t[proto->sizecode]; - - for (int i = 0; i < proto->sizecode; i++) - { - auto [irLocation, asmLocation] = builder.function.bcMapping[i]; - - result->instTargets[i] = irLocation == ~0u ? 0 : asmLocation - start.location; - } - - result->location = start.location; + lowerIr(build, ir, data, helpers, proto, options); if (build.logText) build.logAppend("\n"); - return result; + return createNativeProto(proto, ir); } static void destroyNativeProto(NativeProto* nativeProto) @@ -207,6 +215,8 @@ bool isSupported() if ((cpuinfo[2] & (1 << 28)) == 0) return false; + return true; +#elif defined(__aarch64__) return true; #else return false; @@ -232,11 +242,19 @@ void create(lua_State* L) initFallbackTable(data); initHelperFunctions(data); +#if defined(__x86_64__) || defined(_M_X64) if (!X64::initEntryFunction(data)) { destroyNativeState(L); return; } +#elif defined(__aarch64__) + if (!A64::initEntryFunction(data)) + { + destroyNativeState(L); + return; + } +#endif lua_ExecutionCallbacks* ecb = getExecutionCallbacks(L); @@ -270,14 +288,21 @@ void compile(lua_State* L, int idx) if (!getNativeState(L)) return; +#if defined(__aarch64__) + A64::AssemblyBuilderA64 build(/* logText= */ false); +#else X64::AssemblyBuilderX64 build(/* logText= */ false); +#endif + NativeState* data = getNativeState(L); std::vector protos; gatherFunctions(protos, clvalue(func)->l.p); ModuleHelpers helpers; - assembleHelpers(build, helpers); +#if !defined(__aarch64__) + X64::assembleHelpers(build, helpers); +#endif std::vector results; results.reserve(protos.size()); @@ -292,8 +317,8 @@ void compile(lua_State* L, int idx) uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; uint8_t* codeStart = nullptr; - if (!data->codeAllocator.allocate( - build.data.data(), int(build.data.size()), build.code.data(), int(build.code.size()), nativeData, sizeNativeData, codeStart)) + if (!data->codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), + int(build.code.size() * sizeof(build.code[0])), nativeData, sizeNativeData, codeStart)) { for (NativeProto* result : results) destroyNativeProto(result); @@ -305,7 +330,7 @@ void compile(lua_State* L, int idx) for (NativeProto* result : results) { for (int i = 0; i < result->proto->sizecode; i++) - result->instTargets[i] += uintptr_t(codeStart + result->location); + result->instTargets[i] += uintptr_t(codeStart); LUAU_ASSERT(result->proto->sizecode); result->entryTarget = result->instTargets[0]; @@ -321,7 +346,11 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); +#if defined(__aarch64__) + A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly); +#else X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); +#endif NativeState data; initFallbackTable(data); @@ -330,7 +359,9 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) gatherFunctions(protos, clvalue(func)->l.p); ModuleHelpers helpers; - assembleHelpers(build, helpers); +#if !defined(__aarch64__) + X64::assembleHelpers(build, helpers); +#endif for (Proto* p : protos) if (p) @@ -342,7 +373,9 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) build.finalize(); if (options.outputBinary) - return std::string(build.code.begin(), build.code.end()) + std::string(build.data.begin(), build.data.end()); + return std::string( + reinterpret_cast(build.code.data()), reinterpret_cast(build.code.data() + build.code.size())) + + std::string(build.data.begin(), build.data.end()); else return build.text; } diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp new file mode 100644 index 0000000..94d6f2e --- /dev/null +++ b/CodeGen/src/CodeGenA64.cpp @@ -0,0 +1,69 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "CodeGenA64.h" + +#include "Luau/AssemblyBuilderA64.h" +#include "Luau/UnwindBuilder.h" + +#include "CustomExecUtils.h" +#include "NativeState.h" + +#include "lstate.h" + +namespace Luau +{ +namespace CodeGen +{ +namespace A64 +{ + +bool initEntryFunction(NativeState& data) +{ + AssemblyBuilderA64 build(/* logText= */ false); + UnwindBuilder& unwind = *data.unwindBuilder.get(); + + unwind.start(); + unwind.allocStack(8); // TODO: this is only necessary to align stack by 16 bytes, as start() allocates 8b return pointer + + // TODO: prologue goes here + + unwind.finish(); + + size_t prologueSize = build.setLabel().location; + + // Setup native execution environment + // TODO: figure out state layout + + // Jump to the specified instruction; further control flow will be handled with custom ABI with register setup from EmitCommonX64.h + build.br(x2); + + // Even though we jumped away, we will return here in the end + Label returnOff = build.setLabel(); + + // Cleanup and exit + // TODO: epilogue + + build.ret(); + + build.finalize(); + + LUAU_ASSERT(build.data.empty()); + + if (!data.codeAllocator.allocate(build.data.data(), int(build.data.size()), reinterpret_cast(build.code.data()), + int(build.code.size() * sizeof(build.code[0])), data.gateData, data.gateDataSize, data.context.gateEntry)) + { + LUAU_ASSERT(!"failed to create entry function"); + return false; + } + + // Set the offset at the begining so that functions in new blocks will not overlay the locations + // specified by the unwind information of the entry function + unwind.setBeginOffset(prologueSize); + + data.context.gateExit = data.context.gateEntry + returnOff.location; + + return true; +} + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGenA64.h b/CodeGen/src/CodeGenA64.h new file mode 100644 index 0000000..5043e5c --- /dev/null +++ b/CodeGen/src/CodeGenA64.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +namespace Luau +{ +namespace CodeGen +{ + +struct NativeState; + +namespace A64 +{ + +bool initEntryFunction(NativeState& data); + +} // namespace A64 +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index ac6c941..7df1a90 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -143,6 +143,24 @@ bool initEntryFunction(NativeState& data) return true; } +void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) +{ + if (build.logText) + build.logAppend("; exitContinueVm\n"); + helpers.exitContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ true); + + if (build.logText) + build.logAppend("; exitNoContinueVm\n"); + helpers.exitNoContinueVm = build.setLabel(); + emitExit(build, /* continueInVm */ false); + + if (build.logText) + build.logAppend("; continueCallInVm\n"); + helpers.continueCallInVm = build.setLabel(); + emitContinueCallInVm(build); +} + } // namespace X64 } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGenX64.h b/CodeGen/src/CodeGenX64.h index b82266a..1f48311 100644 --- a/CodeGen/src/CodeGenX64.h +++ b/CodeGen/src/CodeGenX64.h @@ -7,11 +7,15 @@ namespace CodeGen { struct NativeState; +struct ModuleHelpers; namespace X64 { +class AssemblyBuilderX64; + bool initEntryFunction(NativeState& data); +void assembleHelpers(AssemblyBuilderX64& build, ModuleHelpers& helpers); } // namespace X64 } // namespace CodeGen diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 05b6355..d70b6ed 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -286,6 +286,31 @@ void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, int npa build.vmovsd(luauRegValue(ra), tmp0.reg); } +void emitBuiltinType(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +{ + ScopedRegX64 tmp0{regs, SizeX64::qword}; + ScopedRegX64 tag{regs, SizeX64::dword}; + + build.mov(tag.reg, luauRegTag(arg)); + + build.mov(tmp0.reg, qword[rState + offsetof(lua_State, global)]); + build.mov(tmp0.reg, qword[tmp0.reg + qwordReg(tag.reg) * sizeof(TString*) + offsetof(global_State, ttname)]); + + build.mov(luauRegValue(ra), tmp0.reg); +} + +void emitBuiltinTypeof(IrRegAllocX64& regs, AssemblyBuilderX64& build, int nparams, int ra, int arg, OperandX64 args, int nresults) +{ + regs.assertAllFree(); + + build.mov(rArg1, rState); + build.lea(rArg2, luauRegAddress(arg)); + + build.call(qword[rNativeContext + offsetof(NativeContext, luaT_objtypenamestr)]); + + build.mov(luauRegValue(ra), rax); +} + void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults) { OperandX64 argsOp = 0; @@ -353,6 +378,10 @@ void emitBuiltin(IrRegAllocX64& regs, AssemblyBuilderX64& build, int bfid, int r return emitBuiltinMathModf(regs, build, nparams, ra, arg, argsOp, nresults); case LBF_MATH_SIGN: return emitBuiltinMathSign(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_TYPE: + return emitBuiltinType(regs, build, nparams, ra, arg, argsOp, nresults); + case LBF_TYPEOF: + return emitBuiltinTypeof(regs, build, nparams, ra, arg, argsOp, nresults); default: LUAU_ASSERT(!"missing x64 lowering"); break; diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index 3b0aa25..e8f61eb 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -18,51 +18,6 @@ namespace CodeGen namespace X64 { -void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback) -{ - int ra = LUAU_INSN_A(*pc); - int rb = LUAU_INSN_B(*pc); - uint32_t aux = pc[1]; - - Label secondfpath; - - jumpIfTagIsNot(build, rb, LUA_TTABLE, fallback); - - RegisterX64 table = r8; - build.mov(table, luauRegValue(rb)); - - // &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; - RegisterX64 node = rdx; - build.mov(node, qword[table + offsetof(Table, node)]); - build.mov(eax, 1); - build.mov(cl, byte[table + offsetof(Table, lsizenode)]); - build.shl(eax, cl); - build.dec(eax); - build.and_(eax, tsvalue(&k[aux])->hash); - build.shl(rax, kLuaNodeSizeLog2); - build.add(node, rax); - - jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), secondfpath); - - setLuauReg(build, xmm0, ra + 1, luauReg(rb)); - setLuauReg(build, xmm0, ra, luauNodeValue(node)); - build.jmp(next); - - build.setLabel(secondfpath); - - jumpIfNodeHasNext(build, node, fallback); - callGetFastTmOrFallback(build, table, TM_INDEX, fallback); - jumpIfTagIsNot(build, rax, LUA_TTABLE, fallback); - - build.mov(table, qword[rax + offsetof(TValue, value)]); - - getTableNodeAtCachedSlot(build, rax, node, table, pcpos); - jumpIfNodeKeyNotInExpectedSlot(build, rax, node, luauConstantValue(aux), fallback); - - setLuauReg(build, xmm0, ra + 1, luauReg(rb)); - setLuauReg(build, xmm0, ra, luauNodeValue(node)); -} - void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index dcca52a..6a8a3c0 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -21,7 +21,6 @@ namespace X64 class AssemblyBuilderX64; -void emitInstNameCall(AssemblyBuilderX64& build, const Instruction* pc, int pcpos, const TValue* k, Label& next, Label& fallback); void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, const Instruction* pc, int pcpos); void emitInstSetList(AssemblyBuilderX64& build, const Instruction* pc, Label& next); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index dc7d771..b998487 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -124,6 +124,10 @@ static void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& de { if (!defRs.varargSeq) { + // Peel away registers from variadic sequence that we define + while (defRs.regs.test(varargStart)) + varargStart++; + LUAU_ASSERT(!sourceRs.varargSeq || sourceRs.varargStart == varargStart); sourceRs.varargSeq = true; @@ -296,11 +300,6 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& use(inst.b); useRange(inst.c.index, function.intOp(inst.d)); break; - case IrCmd::LOP_NAMECALL: - use(inst.c); - - defRange(inst.b.index, 2); - break; case IrCmd::LOP_CALL: use(inst.b); useRange(inst.b.index + 1, function.intOp(inst.c)); @@ -411,6 +410,13 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& break; default: + // All instructions which reference registers have to be handled explicitly + LUAU_ASSERT(inst.a.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.b.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.c.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.d.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.e.kind != IrOpKind::VmReg); + LUAU_ASSERT(inst.f.kind != IrOpKind::VmReg); break; } } @@ -430,17 +436,20 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) { CfgInfo& info = function.cfg; + // Clear existing data + // 'in' and 'captured' data is not cleared because it will be overwritten below + info.def.clear(); + info.out.clear(); + // Try to compute Luau VM register use-def info info.in.resize(function.blocks.size()); + info.def.resize(function.blocks.size()); info.out.resize(function.blocks.size()); // Captured registers are tracked for the whole function // It should be possible to have a more precise analysis for them in the future std::bitset<256> capturedRegs; - std::vector defRss; - defRss.resize(function.blocks.size()); - // First we compute live-in set of each block for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) { @@ -449,7 +458,7 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) if (block.kind == IrBlockKind::Dead) continue; - info.in[blockIdx] = computeBlockLiveInRegSet(function, block, defRss[blockIdx], capturedRegs); + info.in[blockIdx] = computeBlockLiveInRegSet(function, block, info.def[blockIdx], capturedRegs); } info.captured.regs = capturedRegs; @@ -480,8 +489,8 @@ static void computeCfgLiveInOutRegSets(IrFunction& function) IrBlock& curr = function.blocks[blockIdx]; RegisterSet& inRs = info.in[blockIdx]; + RegisterSet& defRs = info.def[blockIdx]; RegisterSet& outRs = info.out[blockIdx]; - RegisterSet& defRs = defRss[blockIdx]; // Current block has to provide all registers in successor blocks for (uint32_t succIdx : successors(info, blockIdx)) @@ -547,6 +556,10 @@ static void computeCfgBlockEdges(IrFunction& function) { CfgInfo& info = function.cfg; + // Clear existing data + info.predecessorsOffsets.clear(); + info.successorsOffsets.clear(); + // Compute predecessors block edges info.predecessorsOffsets.reserve(function.blocks.size()); info.successorsOffsets.reserve(function.blocks.size()); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 0a700db..f1099cf 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrBuilder.h" -#include "Luau/Common.h" -#include "Luau/DenseHash.h" #include "Luau/IrAnalysis.h" #include "Luau/IrUtils.h" @@ -11,6 +9,8 @@ #include "lapi.h" +#include + namespace Luau { namespace CodeGen @@ -18,6 +18,11 @@ namespace CodeGen constexpr unsigned kNoAssociatedBlockIndex = ~0u; +IrBuilder::IrBuilder() + : constantMap({IrConstKind::Bool, ~0ull}) +{ +} + void IrBuilder::buildFunctionIr(Proto* proto) { function.proto = proto; @@ -377,19 +382,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstCapture(*this, pc, i); break; case LOP_NAMECALL: - { - IrOp next = blockAtInst(i + getOpLength(LOP_NAMECALL)); - IrOp fallback = block(IrBlockKind::Fallback); - - inst(IrCmd::LOP_NAMECALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), next, fallback); - - beginBlock(fallback); - inst(IrCmd::FALLBACK_NAMECALL, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), vmConst(pc[1])); - inst(IrCmd::JUMP, next); - - beginBlock(next); + translateInstNamecall(*this, pc, i); break; - } case LOP_PREPVARARGS: inst(IrCmd::FALLBACK_PREPVARARGS, constUint(i), constInt(LUAU_INSN_A(*pc))); break; @@ -501,7 +495,7 @@ IrOp IrBuilder::constBool(bool value) IrConst constant; constant.kind = IrConstKind::Bool; constant.valueBool = value; - return constAny(constant); + return constAny(constant, uint64_t(value)); } IrOp IrBuilder::constInt(int value) @@ -509,7 +503,7 @@ IrOp IrBuilder::constInt(int value) IrConst constant; constant.kind = IrConstKind::Int; constant.valueInt = value; - return constAny(constant); + return constAny(constant, uint64_t(value)); } IrOp IrBuilder::constUint(unsigned value) @@ -517,7 +511,7 @@ IrOp IrBuilder::constUint(unsigned value) IrConst constant; constant.kind = IrConstKind::Uint; constant.valueUint = value; - return constAny(constant); + return constAny(constant, uint64_t(value)); } IrOp IrBuilder::constDouble(double value) @@ -525,7 +519,12 @@ IrOp IrBuilder::constDouble(double value) IrConst constant; constant.kind = IrConstKind::Double; constant.valueDouble = value; - return constAny(constant); + + uint64_t asCommonKey; + static_assert(sizeof(asCommonKey) == sizeof(value), "Expecting double to be 64-bit"); + memcpy(&asCommonKey, &value, sizeof(value)); + + return constAny(constant, asCommonKey); } IrOp IrBuilder::constTag(uint8_t value) @@ -533,13 +532,21 @@ IrOp IrBuilder::constTag(uint8_t value) IrConst constant; constant.kind = IrConstKind::Tag; constant.valueTag = value; - return constAny(constant); + return constAny(constant, uint64_t(value)); } -IrOp IrBuilder::constAny(IrConst constant) +IrOp IrBuilder::constAny(IrConst constant, uint64_t asCommonKey) { + ConstantKey key{constant.kind, asCommonKey}; + + if (uint32_t* cache = constantMap.find(key)) + return {IrOpKind::Constant, *cache}; + uint32_t index = uint32_t(function.constants.size()); function.constants.push_back(constant); + + constantMap[key] = index; + return {IrOpKind::Constant, index}; } diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 2787fb1..3c4e420 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -90,6 +90,8 @@ const char* getCmdName(IrCmd cmd) return "GET_ARR_ADDR"; case IrCmd::GET_SLOT_NODE_ADDR: return "GET_SLOT_NODE_ADDR"; + case IrCmd::GET_HASH_NODE_ADDR: + return "GET_HASH_NODE_ADDR"; case IrCmd::STORE_TAG: return "STORE_TAG"; case IrCmd::STORE_POINTER: @@ -142,14 +144,18 @@ const char* getCmdName(IrCmd cmd) return "JUMP_CMP_NUM"; case IrCmd::JUMP_CMP_ANY: return "JUMP_CMP_ANY"; + case IrCmd::JUMP_SLOT_MATCH: + return "JUMP_SLOT_MATCH"; case IrCmd::TABLE_LEN: return "TABLE_LEN"; case IrCmd::NEW_TABLE: return "NEW_TABLE"; case IrCmd::DUP_TABLE: return "DUP_TABLE"; - case IrCmd::NUM_TO_INDEX: - return "NUM_TO_INDEX"; + case IrCmd::TRY_NUM_TO_INDEX: + return "TRY_NUM_TO_INDEX"; + case IrCmd::TRY_CALL_FASTGETTM: + return "TRY_CALL_FASTGETTM"; case IrCmd::INT_TO_NUM: return "INT_TO_NUM"; case IrCmd::ADJUST_STACK_TO_REG: @@ -192,6 +198,8 @@ const char* getCmdName(IrCmd cmd) return "CHECK_ARRAY_SIZE"; case IrCmd::CHECK_SLOT_MATCH: return "CHECK_SLOT_MATCH"; + case IrCmd::CHECK_NODE_NO_NEXT: + return "CHECK_NODE_NO_NEXT"; case IrCmd::INTERRUPT: return "INTERRUPT"; case IrCmd::CHECK_GC: @@ -210,8 +218,6 @@ const char* getCmdName(IrCmd cmd) return "CAPTURE"; case IrCmd::LOP_SETLIST: return "LOP_SETLIST"; - case IrCmd::LOP_NAMECALL: - return "LOP_NAMECALL"; case IrCmd::LOP_CALL: return "LOP_CALL"; case IrCmd::LOP_RETURN: @@ -397,7 +403,7 @@ static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) } } -static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs) +static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs, const char* separator) { bool comma = false; @@ -406,7 +412,7 @@ static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs) if (rs.regs.test(i)) { if (comma) - append(ctx.result, ", "); + ctx.result.append(separator); comma = true; append(ctx.result, "R%d", int(i)); @@ -416,7 +422,7 @@ static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs) if (rs.varargSeq) { if (comma) - append(ctx.result, ", "); + ctx.result.append(separator); append(ctx.result, "R%d...", rs.varargStart); } @@ -428,7 +434,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind if (block.useCount == 0 && block.kind != IrBlockKind::Dead && ctx.cfg.captured.regs.any()) { append(ctx.result, "; captured regs: "); - appendRegisterSet(ctx, ctx.cfg.captured); + appendRegisterSet(ctx, ctx.cfg.captured, ", "); append(ctx.result, "\n\n"); } @@ -484,7 +490,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind if (in.regs.any() || in.varargSeq) { append(ctx.result, "; in regs: "); - appendRegisterSet(ctx, in); + appendRegisterSet(ctx, in, ", "); append(ctx.result, "\n"); } } @@ -497,7 +503,7 @@ void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t ind if (out.regs.any() || out.varargSeq) { append(ctx.result, "; out regs: "); - appendRegisterSet(ctx, out); + appendRegisterSet(ctx, out, ", "); append(ctx.result, "\n"); } } @@ -551,5 +557,108 @@ std::string dump(IrFunction& function) return result; } +std::string toDot(IrFunction& function, bool includeInst) +{ + std::string result; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + + auto appendLabelRegset = [&ctx](std::vector& regSets, size_t blockIdx, const char* name) { + if (blockIdx < regSets.size()) + { + const RegisterSet& rs = regSets[blockIdx]; + + if (rs.regs.any() || rs.varargSeq) + { + append(ctx.result, "|{%s|", name); + appendRegisterSet(ctx, rs, "|"); + append(ctx.result, "}"); + } + } + }; + + append(ctx.result, "digraph CFG {\n"); + append(ctx.result, "node[shape=record]\n"); + + for (size_t i = 0; i < function.blocks.size(); i++) + { + IrBlock& block = function.blocks[i]; + + append(ctx.result, "b%u [", unsigned(i)); + + if (block.kind == IrBlockKind::Fallback) + append(ctx.result, "style=filled;fillcolor=salmon;"); + else if (block.kind == IrBlockKind::Bytecode) + append(ctx.result, "style=filled;fillcolor=palegreen;"); + + append(ctx.result, "label=\"{"); + toString(ctx, block, uint32_t(i)); + + appendLabelRegset(ctx.cfg.in, i, "in"); + + if (includeInst && block.start != ~0u) + { + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + IrInst& inst = function.instructions[instIdx]; + + // Skip pseudo instructions unless they are still referenced + if (isPseudo(inst.cmd) && inst.useCount == 0) + continue; + + append(ctx.result, "|"); + toString(ctx, inst, instIdx); + } + } + + appendLabelRegset(ctx.cfg.def, i, "def"); + appendLabelRegset(ctx.cfg.out, i, "out"); + + append(ctx.result, "}\"];\n"); + } + + for (size_t i = 0; i < function.blocks.size(); i++) + { + IrBlock& block = function.blocks[i]; + + if (block.start == ~0u) + continue; + + for (uint32_t instIdx = block.start; instIdx != ~0u && instIdx <= block.finish; instIdx++) + { + IrInst& inst = function.instructions[instIdx]; + + auto checkOp = [&](IrOp op) { + if (op.kind == IrOpKind::Block) + { + if (function.blocks[op.index].kind != IrBlockKind::Fallback) + append(ctx.result, "b%u -> b%u [weight=10];\n", unsigned(i), op.index); + else + append(ctx.result, "b%u -> b%u;\n", unsigned(i), op.index); + } + }; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); + } + } + + append(ctx.result, "}\n"); + + return result; +} + +std::string dumpDot(IrFunction& function, bool includeInst) +{ + std::string result = toDot(function, includeInst); + + printf("%s\n", result.c_str()); + + return result; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 3b27d09..b45ce22 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -200,6 +200,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(inst.regX64, luauRegValue(inst.a.index)); else if (inst.a.kind == IrOpKind::VmConst) build.mov(inst.regX64, luauConstantValue(inst.a.index)); + // If we have a register, we assume it's a pointer to TValue + // We might introduce explicit operand types in the future to make this more robust + else if (inst.a.kind == IrOpKind::Inst) + build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(TValue, value)]); else LUAU_ASSERT(!"Unsupported instruction form"); break; @@ -277,6 +281,25 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) getTableNodeAtCachedSlot(build, tmp.reg, inst.regX64, regOp(inst.a), uintOp(inst.b)); break; } + case IrCmd::GET_HASH_NODE_ADDR: + { + inst.regX64 = regs.allocGprReg(SizeX64::qword); + + // Custom bit shift value can only be placed in cl + ScopedRegX64 shiftTmp{regs, regs.takeGprReg(rcx)}; + + ScopedRegX64 tmp{regs, SizeX64::qword}; + + build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, node)]); + build.mov(dwordReg(tmp.reg), 1); + build.mov(byteReg(shiftTmp.reg), byte[regOp(inst.a) + offsetof(Table, lsizenode)]); + build.shl(dwordReg(tmp.reg), byteReg(shiftTmp.reg)); + build.dec(dwordReg(tmp.reg)); + build.and_(dwordReg(tmp.reg), uintOp(inst.b)); + build.shl(tmp.reg, kLuaNodeSizeLog2); + build.add(inst.regX64, tmp.reg); + break; + }; case IrCmd::STORE_TAG: LUAU_ASSERT(inst.a.kind == IrOpKind::VmReg); @@ -686,6 +709,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpOrFallthrough(blockOp(inst.e), next); break; } + case IrCmd::JUMP_SLOT_MATCH: + { + LUAU_ASSERT(inst.b.kind == IrOpKind::VmConst); + + ScopedRegX64 tmp{regs, SizeX64::qword}; + + jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.c), next); + break; + } case IrCmd::TABLE_LEN: inst.regX64 = regs.allocXmmReg(); @@ -715,7 +748,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.regX64 != rax) build.mov(inst.regX64, rax); break; - case IrCmd::NUM_TO_INDEX: + case IrCmd::TRY_NUM_TO_INDEX: { inst.regX64 = regs.allocGprReg(SizeX64::dword); @@ -724,6 +757,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) convertNumberToIndexOrJump(build, tmp.reg, regOp(inst.a), inst.regX64, labelOp(inst.b)); break; } + case IrCmd::TRY_CALL_FASTGETTM: + { + inst.regX64 = regs.allocGprReg(SizeX64::qword); + + callGetFastTmOrFallback(build, regOp(inst.a), TMS(intOp(inst.b)), labelOp(inst.c)); + + if (inst.regX64 != rax) + build.mov(inst.regX64, rax); + break; + } case IrCmd::INT_TO_NUM: inst.regX64 = regs.allocXmmReg(); @@ -1017,6 +1060,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) jumpIfNodeKeyNotInExpectedSlot(build, tmp.reg, regOp(inst.a), luauConstantValue(inst.b.index), labelOp(inst.c)); break; } + case IrCmd::CHECK_NODE_NO_NEXT: + jumpIfNodeHasNext(build, regOp(inst.a), labelOp(inst.b)); + break; case IrCmd::INTERRUPT: emitInterrupt(build, uintOp(inst.a)); break; @@ -1114,16 +1160,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.setLabel(next); break; } - case IrCmd::LOP_NAMECALL: - { - const Instruction* pc = proto->code + uintOp(inst.a); - LUAU_ASSERT(inst.b.kind == IrOpKind::VmReg); - LUAU_ASSERT(inst.c.kind == IrOpKind::VmReg); - - emitInstNameCall(build, pc, uintOp(inst.a), proto->k, blockOp(inst.d).label, blockOp(inst.e).label); - jumpOrFallthrough(blockOp(inst.d), next); - break; - } case IrCmd::LOP_CALL: { const Instruction* pc = proto->code + uintOp(inst.a); diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index bc90910..d9f935c 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -210,6 +210,34 @@ BuiltinImplResult translateBuiltinMathClamp(IrBuilder& build, int nparams, int r return {BuiltinImplType::UsesFallback, 1}; } +BuiltinImplResult translateBuiltinType(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.inst( + IrCmd::FASTCALL, build.constUint(LBF_TYPE), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +BuiltinImplResult translateBuiltinTypeof(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, IrOp fallback) +{ + if (nparams < 1 || nresults > 1) + return {BuiltinImplType::None, -1}; + + build.inst( + IrCmd::FASTCALL, build.constUint(LBF_TYPEOF), build.vmReg(ra), build.vmReg(arg), args, build.constInt(nparams), build.constInt(nresults)); + + // TODO: tag update might not be required, we place it here now because FASTCALL is not modeled in constant propagation yet + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TSTRING)); + + return {BuiltinImplType::UsesFallback, 1}; +} + BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, IrOp args, int nparams, int nresults, IrOp fallback) { switch (bfid) @@ -254,6 +282,10 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, case LBF_MATH_FREXP: case LBF_MATH_MODF: return translateBuiltinNumberTo2Number(build, LuauBuiltinFunction(bfid), nparams, ra, arg, args, nresults, fallback); + case LBF_TYPE: + return translateBuiltinType(build, nparams, ra, arg, args, nresults, fallback); + case LBF_TYPEOF: + return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults, fallback); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 48ca397..28c6aca 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -806,7 +806,7 @@ void translateInstGetTable(IrBuilder& build, const Instruction* pc, int pcpos) IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rc)); - IrOp index = build.inst(IrCmd::NUM_TO_INDEX, vc, fallback); + IrOp index = build.inst(IrCmd::TRY_NUM_TO_INDEX, vc, fallback); index = build.inst(IrCmd::SUB_INT, index, build.constInt(1)); @@ -843,7 +843,7 @@ void translateInstSetTable(IrBuilder& build, const Instruction* pc, int pcpos) IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); IrOp vc = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(rc)); - IrOp index = build.inst(IrCmd::NUM_TO_INDEX, vc, fallback); + IrOp index = build.inst(IrCmd::TRY_NUM_TO_INDEX, vc, fallback); index = build.inst(IrCmd::SUB_INT, index, build.constInt(1)); @@ -1035,5 +1035,63 @@ void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos) } } +void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) +{ + int ra = LUAU_INSN_A(*pc); + int rb = LUAU_INSN_B(*pc); + uint32_t aux = pc[1]; + + IrOp next = build.blockAtInst(pcpos + getOpLength(LOP_NAMECALL)); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp firstFastPathSuccess = build.block(IrBlockKind::Internal); + IrOp secondFastPath = build.block(IrBlockKind::Internal); + + build.loadAndCheckTag(build.vmReg(rb), LUA_TTABLE, fallback); + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + + LUAU_ASSERT(build.function.proto); + IrOp addrNodeEl = build.inst(IrCmd::GET_HASH_NODE_ADDR, table, build.constUint(tsvalue(&build.function.proto->k[aux])->hash)); + + // We use 'jump' version instead of 'check' guard because we are jumping away into a non-fallback block + // This is required by CFG live range analysis because both non-fallback blocks define the same registers + build.inst(IrCmd::JUMP_SLOT_MATCH, addrNodeEl, build.vmConst(aux), firstFastPathSuccess, secondFastPath); + + build.beginBlock(firstFastPathSuccess); + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra + 1), table); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TTABLE)); + + IrOp nodeEl = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrNodeEl); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), nodeEl); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(secondFastPath); + + build.inst(IrCmd::CHECK_NODE_NO_NEXT, addrNodeEl, fallback); + + IrOp indexPtr = build.inst(IrCmd::TRY_CALL_FASTGETTM, table, build.constInt(TM_INDEX), fallback); + + build.loadAndCheckTag(indexPtr, LUA_TTABLE, fallback); + IrOp index = build.inst(IrCmd::LOAD_POINTER, indexPtr); + + IrOp addrIndexNodeEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, index, build.constUint(pcpos)); + build.inst(IrCmd::CHECK_SLOT_MATCH, addrIndexNodeEl, build.vmConst(aux), fallback); + + // TODO: original 'table' was clobbered by a call inside 'FASTGETTM' + // Ideally, such calls should have to effect on SSA IR values, but simple register allocator doesn't support it + IrOp table2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(ra + 1), table2); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TTABLE)); + + IrOp indexNodeEl = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrIndexNodeEl); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), indexNodeEl); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(fallback); + build.inst(IrCmd::FALLBACK_NAMECALL, build.constUint(pcpos), build.vmReg(ra), build.vmReg(rb), build.vmConst(aux)); + build.inst(IrCmd::JUMP, next); + + build.beginBlock(next); +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index 0d4a509..0be111d 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -60,6 +60,7 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstConcat(IrBuilder& build, const Instruction* pc, int pcpos); void translateInstCapture(IrBuilder& build, const Instruction* pc, int pcpos); +void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos); } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index d8115be..e29a5b0 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -360,7 +360,7 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 replace(function, block, index, {IrCmd::JUMP, inst.e}); } break; - case IrCmd::NUM_TO_INDEX: + case IrCmd::TRY_NUM_TO_INDEX: if (inst.a.kind == IrOpKind::Constant) { double value = function.doubleOp(inst.a); diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index 33db54e..f79bcab 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -80,6 +80,7 @@ void initHelperFunctions(NativeState& data) data.context.luaF_close = luaF_close; data.context.luaT_gettm = luaT_gettm; + data.context.luaT_objtypenamestr = luaT_objtypenamestr; data.context.libm_exp = exp; data.context.libm_pow = pow; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index ad5aca6..bebf421 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -39,7 +39,6 @@ struct NativeProto uintptr_t* instTargets = nullptr; // TODO: NativeProto should be variable-size with all target embedded Proto* proto = nullptr; - uint32_t location = 0; }; struct NativeContext @@ -79,6 +78,7 @@ struct NativeContext void (*luaF_close)(lua_State* L, StkId level) = nullptr; const TValue* (*luaT_gettm)(Table* events, TMS event, TString* ename) = nullptr; + const TString* (*luaT_objtypenamestr)(lua_State* L, const TValue* o) = nullptr; double (*libm_exp)(double) = nullptr; double (*libm_pow)(double, double) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 956c96d..b12a9b9 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -319,10 +319,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { if (inst.b.kind == IrOpKind::Constant) { - std::optional oldValue = function.asDoubleOp(state.tryGetValue(inst.a)); - double newValue = function.doubleOp(inst.b); - - if (oldValue && *oldValue == newValue) + if (state.tryGetValue(inst.a) == inst.b) kill(function, inst); else state.saveValue(inst.a, inst.b); @@ -338,10 +335,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { if (inst.b.kind == IrOpKind::Constant) { - std::optional oldValue = function.asIntOp(state.tryGetValue(inst.a)); - int newValue = function.intOp(inst.b); - - if (oldValue && *oldValue == newValue) + if (state.tryGetValue(inst.a) == inst.b) kill(function, inst); else state.saveValue(inst.a, inst.b); @@ -504,6 +498,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: + case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::STORE_NODE_VALUE_TV: case IrCmd::ADD_INT: case IrCmd::SUB_INT: @@ -519,13 +514,16 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::NOT_ANY: case IrCmd::JUMP: case IrCmd::JUMP_EQ_POINTER: + case IrCmd::JUMP_SLOT_MATCH: case IrCmd::TABLE_LEN: case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: - case IrCmd::NUM_TO_INDEX: + case IrCmd::TRY_NUM_TO_INDEX: + case IrCmd::TRY_CALL_FASTGETTM: case IrCmd::INT_TO_NUM: case IrCmd::CHECK_ARRAY_SIZE: case IrCmd::CHECK_SLOT_MATCH: + case IrCmd::CHECK_NODE_NO_NEXT: case IrCmd::BARRIER_TABLE_BACK: case IrCmd::LOP_RETURN: case IrCmd::LOP_COVERAGE: @@ -552,7 +550,6 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::CONCAT: case IrCmd::PREPARE_FORN: case IrCmd::INTERRUPT: // TODO: it will be important to keep tag/value state, but we have to track register capture - case IrCmd::LOP_NAMECALL: case IrCmd::LOP_CALL: case IrCmd::LOP_FORGLOOP: case IrCmd::LOP_FORGLOOP_FALLBACK: @@ -633,7 +630,7 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st // * if the successor has multiple uses, it can't have such 'live in' values without phi nodes that we don't have yet // Another possibility is to have two paths from 'block' into the target through two intermediate blocks // Usually that would mean that we would have a conditional jump at the end of 'block' - // But using check guards and fallback clocks it becomes a possible setup + // But using check guards and fallback blocks it becomes a possible setup // We avoid this by making sure fallbacks rejoin the other immediate successor of 'block' LUAU_ASSERT(getLiveOutValueCount(function, *block) == 0); diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index a95ed09..0b3134b 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -201,6 +201,7 @@ void UnwindBuilderDwarf2::setupFrameReg(X64::RegisterX64 reg, int espOffset) void UnwindBuilderDwarf2::finish() { LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); + LUAU_ASSERT(fdeEntryStart != nullptr); pos = alignPosition(fdeEntryStart, pos); writeu32(fdeEntryStart, unsigned(pos - fdeEntryStart - 4)); // Length field itself is excluded from length @@ -220,7 +221,9 @@ void UnwindBuilderDwarf2::finalize(char* target, void* funcAddress, size_t funcS { memcpy(target, rawData, getSize()); + LUAU_ASSERT(fdeEntryStart != nullptr); unsigned fdeEntryStartPos = unsigned(fdeEntryStart - rawData); + writeu64((uint8_t*)target + fdeEntryStartPos + kFdeInitialLocationOffset, uintptr_t(funcAddress)); writeu64((uint8_t*)target + fdeEntryStartPos + kFdeAddressRangeOffset, funcSize); } diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 67f9fbe..82bf6e5 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -25,7 +25,7 @@ // Additionally, in some specific instructions such as ANDK, the limit on the encoded value is smaller; this means that if a value is larger, a different instruction must be selected. // // Registers: 0-254. Registers refer to the values on the function's stack frame, including arguments. -// Upvalues: 0-254. Upvalues refer to the values stored in the closure object. +// Upvalues: 0-199. Upvalues refer to the values stored in the closure object. // Constants: 0-2^23-1. Constants are stored in a table allocated with each proto; to allow for future bytecode tweaks the encodable value is limited to 23 bits. // Closures: 0-2^15-1. Closures are created from child protos via a child index; the limit is for the number of closures immediately referenced in each function. // Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more. Note that for jump instructions with AUX, the AUX word is included as part of the jump offset. @@ -93,12 +93,12 @@ enum LuauOpcode // GETUPVAL: load upvalue from the upvalue table for the current function // A: target register - // B: upvalue index (0..255) + // B: upvalue index LOP_GETUPVAL, // SETUPVAL: store value into the upvalue table for the current function // A: target register - // B: upvalue index (0..255) + // B: upvalue index LOP_SETUPVAL, // CLOSEUPVALS: close (migrate to heap) all upvalues that were captured for registers >= target diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 35e11ca..8eca105 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -11,8 +11,9 @@ inline bool isFlagExperimental(const char* flag) // Flags in this list are disabled by default in various command-line tools. They may have behavior that is not fully final, // or critical bugs that are found after the code has been submitted. static const char* const kList[] = { - "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code - "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) + "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code + "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) + "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins // makes sure we always have at least one entry nullptr, }; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 78896d3..03f4b3e 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -25,7 +25,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileTerminateBC, false) LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinArity, false) namespace Luau @@ -143,7 +142,7 @@ struct Compiler return stat->body.size > 0 && alwaysTerminates(stat->body.data[stat->body.size - 1]); else if (node->is()) return true; - else if (FFlag::LuauCompileTerminateBC && (node->is() || node->is())) + else if (node->is() || node->is()) return true; else if (AstStatIf* stat = node->as()) return stat->elsebody && alwaysTerminates(stat->thenbody) && alwaysTerminates(stat->elsebody); diff --git a/Makefile b/Makefile index 66d6016..5851229 100644 --- a/Makefile +++ b/Makefile @@ -143,6 +143,9 @@ aliases: $(EXECUTABLE_ALIASES) test: $(TESTS_TARGET) $(TESTS_TARGET) $(TESTS_ARGS) +conformance: $(TESTS_TARGET) + $(TESTS_TARGET) $(TESTS_ARGS) -ts=Conformance + clean: rm -rf $(BUILD) rm -rf $(EXECUTABLE_ALIASES) diff --git a/Sources.cmake b/Sources.cmake index 88c6e9b..6e0a32e 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -135,6 +135,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h + Analysis/include/Luau/ControlFlow.h Analysis/include/Luau/DataFlowGraph.h Analysis/include/Luau/DcrLogger.h Analysis/include/Luau/Def.h @@ -370,6 +371,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.annotations.test.cpp tests/TypeInfer.anyerror.test.cpp tests/TypeInfer.builtins.test.cpp + tests/TypeInfer.cfa.test.cpp tests/TypeInfer.classes.test.cpp tests/TypeInfer.definitions.test.cpp tests/TypeInfer.functions.test.cpp diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 8d59ecb..5eceea7 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -33,6 +33,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauArrBoundResizeFix, false) + // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -454,15 +456,43 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) int nasize = numusearray(t, nums); // count keys in array part int totaluse = nasize; // all those keys are integer keys totaluse += numusehash(t, nums, &nasize); // count keys in hash part + // count extra key if (ttisnumber(ek)) nasize += countint(nvalue(ek), nums); totaluse++; + // compute new size for array part int na = computesizes(nums, &nasize); int nh = totaluse - na; - // enforce the boundary invariant; for performance, only do hash lookups if we must - nasize = adjustasize(t, nasize, ek); + + if (FFlag::LuauArrBoundResizeFix) + { + // enforce the boundary invariant; for performance, only do hash lookups if we must + int nadjusted = adjustasize(t, nasize, ek); + + // count how many extra elements belong to array part instead of hash part + int aextra = nadjusted - nasize; + + if (aextra != 0) + { + // we no longer need to store those extra array elements in hash part + nh -= aextra; + + // because hash nodes are twice as large as array nodes, the memory we saved for hash parts can be used by array part + // this follows the general sparse array part optimization where array is allocated when 50% occupation is reached + nasize = nadjusted + aextra; + + // since the size was changed, it's again important to enforce the boundary invariant at the new size + nasize = adjustasize(t, nasize, ek); + } + } + else + { + // enforce the boundary invariant; for performance, only do hash lookups if we must + nasize = adjustasize(t, nasize, ek); + } + // resize the table to new computed sizes resize(L, t, nasize, nh); } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 135a555..c9d0c01 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1677,8 +1677,6 @@ RETURN R0 0 TEST_CASE("LoopBreak") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // default codegen: compile breaks as unconditional jumps CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"( L0: GETIMPORT R0 2 [math.random] @@ -1703,8 +1701,6 @@ L1: RETURN R0 0 TEST_CASE("LoopContinue") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // default codegen: compile continue as unconditional jumps CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"( L0: GETIMPORT R0 2 [math.random] @@ -2214,6 +2210,46 @@ TEST_CASE("RecursionParse") { CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); } + + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("(", 1500) + "nil" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: () " + rep("-> ()", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("{x:", 1500) + "nil" + rep("}", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "local f: " + rep("(nil & ", 1500) + "nil" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + } } TEST_CASE("ArrayIndexLiteral") @@ -6816,8 +6852,6 @@ RETURN R0 0 TEST_CASE("ElideJumpAfterIf") { - ScopedFastFlag sff("LuauCompileTerminateBC", true); - // break refers to outer loop => we can elide unconditional branches CHECK_EQ("\n" + compileFunction0(R"( local foo, bar = ... diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index 81e5c41..cc239b7 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -13,7 +13,7 @@ ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() { mainModule->reduction = std::make_unique(NotNull{&mainModule->internalTypes}, builtinTypes, NotNull{&ice}); - BlockedType::nextIndex = 0; + BlockedType::DEPRECATED_nextIndex = 0; BlockedTypePack::nextIndex = 0; } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index a9c94ee..4d2e83f 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -137,7 +137,7 @@ const Config& TestConfigResolver::getConfig(const ModuleName& name) const Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, - {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) + {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* runLintChecks */ false, /* randomConstraintResolutionSeed */ randomSeed}) , builtinTypes(frontend.builtinTypes) { configResolver.defaultConfig.mode = Mode::Strict; @@ -173,15 +173,19 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars // if AST is available, check how lint and typecheck handle error nodes if (result.root) { - frontend.lint(*sourceModule); - if (FFlag::DebugLuauDeferredConstraintResolution) { - Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, + ModulePtr module = Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, frontend.globals.globalScope, frontend.options); + + Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); } else - frontend.typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + { + ModulePtr module = frontend.typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + + Luau::lint(sourceModule->root, *sourceModule->names, frontend.globals.globalScope, module.get(), sourceModule->hotcomments, {}); + } } throw ParseErrors(result.errors); @@ -209,20 +213,23 @@ CheckResult Fixture::check(const std::string& source) LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) { - ParseOptions parseOptions; - parseOptions.captureComments = true; - configResolver.defaultConfig.mode = Mode::Nonstrict; - parse(source, parseOptions); + ModuleName mm = fromString(mainModuleName); + configResolver.defaultConfig.mode = Mode::Strict; + fileResolver.source[mm] = std::move(source); + frontend.markDirty(mm); - return frontend.lint(*sourceModule, lintOptions); + return lintModule(mm); } -LintResult Fixture::lintTyped(const std::string& source, const std::optional& lintOptions) +LintResult Fixture::lintModule(const ModuleName& moduleName, const std::optional& lintOptions) { - check(source); - ModuleName mm = fromString(mainModuleName); + FrontendOptions options = frontend.options; + options.runLintChecks = true; + options.enabledLintWarnings = lintOptions; - return frontend.lint(mm, lintOptions); + CheckResult result = frontend.check(moduleName, options); + + return result.lintResult; } ParseResult Fixture::parseEx(const std::string& source, const ParseOptions& options) diff --git a/tests/Fixture.h b/tests/Fixture.h index 5db6ed1..4c49593 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -66,7 +66,7 @@ struct Fixture CheckResult check(const std::string& source); LintResult lint(const std::string& source, const std::optional& lintOptions = {}); - LintResult lintTyped(const std::string& source, const std::optional& lintOptions = {}); + LintResult lintModule(const ModuleName& moduleName, const std::optional& lintOptions = {}); /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); @@ -94,6 +94,7 @@ struct Fixture TypeId requireTypeAlias(const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; + ScopedFastFlag luauLintInTypecheck{"LuauLintInTypecheck", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index e09990f..3b1ec4a 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -456,16 +456,16 @@ TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") end )"; - frontend.check("Modules/A"); + configResolver.defaultConfig.enabledLint.enableWarning(LintWarning::Code_ForRange); + + lintModule("Modules/A"); fileResolver.source["Modules/A"] = R"( -- We have fixed the lint error, but we did not tell the Frontend that the file is changed! - -- Therefore, we expect Frontend to reuse the parse tree. + -- Therefore, we expect Frontend to reuse the results from previous lint. )"; - configResolver.defaultConfig.enabledLint.enableWarning(LintWarning::Code_ForRange); - - LintResult lintResult = frontend.lint("Modules/A"); + LintResult lintResult = lintModule("Modules/A"); CHECK_EQ(1, lintResult.warnings.size()); } @@ -760,25 +760,49 @@ TEST_CASE_FIXTURE(FrontendFixture, "test_lint_uses_correct_config") configResolver.configFiles["Module/A"].enabledLint.enableWarning(LintWarning::Code_ForRange); - auto result = frontend.lint("Module/A"); + auto result = lintModule("Module/A"); CHECK_EQ(1, result.warnings.size()); configResolver.configFiles["Module/A"].enabledLint.disableWarning(LintWarning::Code_ForRange); + frontend.markDirty("Module/A"); - auto result2 = frontend.lint("Module/A"); + auto result2 = lintModule("Module/A"); CHECK_EQ(0, result2.warnings.size()); LintOptions overrideOptions; overrideOptions.enableWarning(LintWarning::Code_ForRange); - auto result3 = frontend.lint("Module/A", overrideOptions); + frontend.markDirty("Module/A"); + + auto result3 = lintModule("Module/A", overrideOptions); CHECK_EQ(1, result3.warnings.size()); overrideOptions.disableWarning(LintWarning::Code_ForRange); - auto result4 = frontend.lint("Module/A", overrideOptions); + frontend.markDirty("Module/A"); + + auto result4 = lintModule("Module/A", overrideOptions); CHECK_EQ(0, result4.warnings.size()); } +TEST_CASE_FIXTURE(FrontendFixture, "lint_results_are_only_for_checked_module") +{ + fileResolver.source["Module/A"] = R"( +local _ = 0b10000000000000000000000000000000000000000000000000000000000000000 + )"; + + fileResolver.source["Module/B"] = R"( +require(script.Parent.A) +local _ = 0x10000000000000000 + )"; + + LintResult lintResult = lintModule("Module/B"); + CHECK_EQ(1, lintResult.warnings.size()); + + // Check cached result + lintResult = lintModule("Module/B"); + CHECK_EQ(1, lintResult.warnings.size()); +} + TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") { Frontend fe{&fileResolver, &configResolver, {false}}; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 41146d7..37c12dc 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -62,22 +62,16 @@ public: build.inst(IrCmd::LOP_RETURN, build.constUint(2)); }; - void checkEq(IrOp lhs, IrOp rhs) - { - CHECK_EQ(lhs.kind, rhs.kind); - LUAU_ASSERT(lhs.kind != IrOpKind::Constant && "can't compare constants, each ref is unique"); - CHECK_EQ(lhs.index, rhs.index); - } - void checkEq(IrOp instOp, const IrInst& inst) { const IrInst& target = build.function.instOp(instOp); CHECK(target.cmd == inst.cmd); - checkEq(target.a, inst.a); - checkEq(target.b, inst.b); - checkEq(target.c, inst.c); - checkEq(target.d, inst.d); - checkEq(target.e, inst.e); + CHECK(target.a == inst.a); + CHECK(target.b == inst.b); + CHECK(target.c == inst.c); + CHECK(target.d == inst.d); + CHECK(target.e == inst.e); + CHECK(target.f == inst.f); } IrBuilder build; @@ -405,18 +399,18 @@ bb_11: TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") { withOneBlock([this](IrOp a) { - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, build.constDouble(4), a)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(4), a)); build.inst(IrCmd::LOP_RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, build.constDouble(1.2), a)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, build.constDouble(1.2), a)); build.inst(IrCmd::LOP_RETURN, build.constUint(0)); }); withOneBlock([this](IrOp a) { IrOp nan = build.inst(IrCmd::DIV_NUM, build.constDouble(0.0), build.constDouble(0.0)); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::NUM_TO_INDEX, nan, a)); + build.inst(IrCmd::STORE_INT, build.vmReg(0), build.inst(IrCmd::TRY_NUM_TO_INDEX, nan, a)); build.inst(IrCmd::LOP_RETURN, build.constUint(0)); }); @@ -1676,4 +1670,64 @@ bb_2: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequencePeeling") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + IrOp b = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); + build.inst(IrCmd::JUMP_EQ_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), a, b); + + build.beginBlock(a); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(b); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(2), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0, R1 +; out regs: R0, R1, R3... + FALLBACK_GETVARARGS 0u, R3, -1i + %1 = LOAD_TAG R0 + JUMP_EQ_TAG %1, tnumber, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R0, R3... +; out regs: R2... + %3 = LOAD_TVALUE R0 + STORE_TVALUE R2, %3 + JUMP bb_3 + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R3... +; out regs: R2... + %6 = LOAD_TVALUE R1 + STORE_TVALUE R2, %6 + JUMP bb_3 + +bb_3: +; predecessors: bb_1, bb_2 +; in regs: R2... + LOP_RETURN 0u, R2, -1i + +)"); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index ebd004d..0f13461 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -733,6 +733,7 @@ end TEST_CASE_FIXTURE(Fixture, "ImplicitReturn") { LintResult result = lint(R"( +--!nonstrict function f1(a) if not a then return 5 @@ -789,20 +790,21 @@ return f1,f2,f3,f4,f5,f6,f7 )"); REQUIRE(3 == result.warnings.size()); - CHECK_EQ(result.warnings[0].location.begin.line, 4); + CHECK_EQ(result.warnings[0].location.begin.line, 5); CHECK_EQ(result.warnings[0].text, - "Function 'f1' can implicitly return no values even though there's an explicit return at line 4; add explicit return to silence"); - CHECK_EQ(result.warnings[1].location.begin.line, 28); + "Function 'f1' can implicitly return no values even though there's an explicit return at line 5; add explicit return to silence"); + CHECK_EQ(result.warnings[1].location.begin.line, 29); CHECK_EQ(result.warnings[1].text, - "Function 'f4' can implicitly return no values even though there's an explicit return at line 25; add explicit return to silence"); - CHECK_EQ(result.warnings[2].location.begin.line, 44); + "Function 'f4' can implicitly return no values even though there's an explicit return at line 26; add explicit return to silence"); + CHECK_EQ(result.warnings[2].location.begin.line, 45); CHECK_EQ(result.warnings[2].text, - "Function can implicitly return no values even though there's an explicit return at line 44; add explicit return to silence"); + "Function can implicitly return no values even though there's an explicit return at line 45; add explicit return to silence"); } TEST_CASE_FIXTURE(Fixture, "ImplicitReturnInfiniteLoop") { LintResult result = lint(R"( +--!nonstrict function f1(a) while true do if math.random() > 0.5 then @@ -845,12 +847,12 @@ return f1,f2,f3,f4 )"); REQUIRE(2 == result.warnings.size()); - CHECK_EQ(result.warnings[0].location.begin.line, 25); + CHECK_EQ(result.warnings[0].location.begin.line, 26); CHECK_EQ(result.warnings[0].text, - "Function 'f3' can implicitly return no values even though there's an explicit return at line 21; add explicit return to silence"); - CHECK_EQ(result.warnings[1].location.begin.line, 36); + "Function 'f3' can implicitly return no values even though there's an explicit return at line 22; add explicit return to silence"); + CHECK_EQ(result.warnings[1].location.begin.line, 37); CHECK_EQ(result.warnings[1].text, - "Function 'f4' can implicitly return no values even though there's an explicit return at line 32; add explicit return to silence"); + "Function 'f4' can implicitly return no values even though there's an explicit return at line 33; add explicit return to silence"); } TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") @@ -1164,7 +1166,7 @@ os.date("!*t") TEST_CASE_FIXTURE(Fixture, "FormatStringTyped") { - LintResult result = lintTyped(R"~( + LintResult result = lint(R"~( local s: string, nons = ... string.match(s, "[]") @@ -1285,7 +1287,7 @@ TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") local _bar: typeof(os.clock) = os.clock )"; - LintResult result = frontend.lint("A"); + LintResult result = lintModule("A"); REQUIRE(0 == result.warnings.size()); } @@ -1471,7 +1473,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") freeze(frontend.globals.globalTypes); - LintResult result = lintTyped(R"( + LintResult result = lint(R"( return function (i: Instance) i:Wait(1.0) print(i.Name) @@ -1518,7 +1520,7 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "TableOperations") { - LintResult result = lintTyped(R"( + LintResult result = lint(R"( local t = {} local tt = {} diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index b86af0e..384a39f 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -28,6 +28,18 @@ struct IsSubtypeFixture : Fixture return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); } + + bool isConsistentSubtype(TypeId a, TypeId b) + { + Location location; + ModulePtr module = getMainModule(); + REQUIRE(module); + + if (!module->hasModuleScope()) + FAIL("isSubtype: module scope data is not available"); + + return ::Luau::isConsistentSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); + } }; } // namespace @@ -86,8 +98,8 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any") // any makes things work even when it makes no sense. - CHECK(isSubtype(b, a)); - CHECK(isSubtype(a, b)); + CHECK(isConsistentSubtype(b, a)); + CHECK(isConsistentSubtype(a, b)); } TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head") @@ -163,6 +175,10 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_union_prop") TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") { + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + check(R"( local a: {x: number} local b: {x: any} @@ -172,7 +188,8 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") TypeId b = requireType("b"); CHECK(isSubtype(a, b)); - CHECK(isSubtype(b, a)); + CHECK(!isSubtype(b, a)); + CHECK(isConsistentSubtype(b, a)); } TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection") @@ -216,6 +233,10 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "union_and_intersection") TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") { + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + check(R"( local a: {x: number} local b: {x: any} @@ -229,7 +250,8 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") TypeId d = requireType("d"); CHECK(isSubtype(a, b)); - CHECK(isSubtype(b, a)); + CHECK(!isSubtype(b, a)); + CHECK(isConsistentSubtype(b, a)); CHECK(!isSubtype(c, a)); CHECK(!isSubtype(a, c)); @@ -358,6 +380,92 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "metatable" * doctest::expected_failures{1}) } #endif +TEST_CASE_FIXTURE(IsSubtypeFixture, "any_is_unknown_union_error") +{ + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + + check(R"( + local err = 5.nope.nope -- err is now an error type + local a : any + local b : (unknown | typeof(err)) + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + + CHECK(isSubtype(a, b)); + CHECK(isSubtype(b, a)); + CHECK_EQ("*error-type*", toString(requireType("err"))); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "any_intersect_T_is_T") +{ + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + + check(R"( + local a : (any & string) + local b : string + local c : number + )"); + + TypeId a = requireType("a"); + TypeId b = requireType("b"); + TypeId c = requireType("c"); + + CHECK(isSubtype(a, b)); + CHECK(isSubtype(b, a)); + CHECK(!isSubtype(a, c)); + CHECK(!isSubtype(c, a)); +} + +TEST_CASE_FIXTURE(IsSubtypeFixture, "error_suppression") +{ + ScopedFastFlag sffs[] = { + {"LuauTransitiveSubtyping", true}, + }; + + check(""); + + TypeId any = builtinTypes->anyType; + TypeId err = builtinTypes->errorType; + TypeId str = builtinTypes->stringType; + TypeId unk = builtinTypes->unknownType; + + CHECK(!isSubtype(any, err)); + CHECK(isSubtype(err, any)); + CHECK(isConsistentSubtype(any, err)); + CHECK(isConsistentSubtype(err, any)); + + CHECK(!isSubtype(any, str)); + CHECK(isSubtype(str, any)); + CHECK(isConsistentSubtype(any, str)); + CHECK(isConsistentSubtype(str, any)); + + CHECK(!isSubtype(any, unk)); + CHECK(isSubtype(unk, any)); + CHECK(isConsistentSubtype(any, unk)); + CHECK(isConsistentSubtype(unk, any)); + + CHECK(!isSubtype(err, str)); + CHECK(!isSubtype(str, err)); + CHECK(isConsistentSubtype(err, str)); + CHECK(isConsistentSubtype(str, err)); + + CHECK(!isSubtype(err, unk)); + CHECK(!isSubtype(unk, err)); + CHECK(isConsistentSubtype(err, unk)); + CHECK(isConsistentSubtype(unk, err)); + + CHECK(isSubtype(str, unk)); + CHECK(!isSubtype(unk, str)); + CHECK(isConsistentSubtype(str, unk)); + CHECK(!isConsistentSubtype(unk, str)); +} + TEST_SUITE_END(); struct NormalizeFixture : Fixture @@ -692,4 +800,17 @@ TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_tables") CHECK("table" == toString(normal("Not>"))); } +TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") +{ + ScopedFastFlag sff[] { + {"LuauNormalizeBlockedTypes", true}, + }; + + Type blocked{BlockedType{}}; + + const NormalizedType* norm = normalizer.normalize(&blocked); + + CHECK_EQ(normalizer.typeFromNormal(*norm), &blocked); +} + TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index 7e50d5b..093570d 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -263,9 +263,8 @@ TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") )LUA"; CheckResult result = check(src); - CodeTooComplex ctc; - CHECK(hasError(result, &ctc)); + CHECK(hasError(result)); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 0488196..c6766ca 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -225,7 +225,10 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK_EQ("*error-type*", toString(requireType("a"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("any", toString(requireType("a"))); + else + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -234,7 +237,10 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - CHECK_EQ("*error-type*", toString(requireType("a"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("any", toString(requireType("a"))); + else + CHECK_EQ("*error-type*", toString(requireType("a"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") @@ -343,4 +349,19 @@ stat = stat and tonumber(stat) or stat LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "intersection_of_any_can_have_props") +{ + // *blocked-130* ~ hasProp any & ~(false?), "_status" + CheckResult result = check(R"( +function foo(x: any, y) + if x then + return x._status + end + return y +end +)"); + + CHECK("(any, any) -> any" == toString(requireType("foo"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 5318b40..49209a4 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -704,11 +704,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_variadic_typepack_tail_and_strin LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("string", toString(requireType("foo"))); - else - CHECK_EQ("any", toString(requireType("foo"))); - + CHECK_EQ("any", toString(requireType("foo"))); CHECK_EQ("any", toString(requireType("bar"))); CHECK_EQ("any", toString(requireType("baz"))); CHECK_EQ("any", toString(requireType("quux"))); @@ -996,11 +992,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); + CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location); CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); - CHECK_EQ("*error-type*", toString(requireType("d"))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("any", toString(requireType("d"))); + else + CHECK_EQ("*error-type*", toString(requireType("d"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") diff --git a/tests/TypeInfer.cfa.test.cpp b/tests/TypeInfer.cfa.test.cpp new file mode 100644 index 0000000..7374295 --- /dev/null +++ b/tests/TypeInfer.cfa.test.cpp @@ -0,0 +1,380 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" + +#include "Luau/Symbol.h" +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("ControlFlowAnalysis"); + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + if not x then + return + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({9, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_rand_return_elif_not_y_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif math.random() > 0.5 then + return + elseif not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_rand_return_elif_not_y_fallthrough") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + elseif math.random() > 0.5 then + return + elseif not y then + + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_not_z_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?, z: string?) + if not x then + return + elseif not y then + + elseif not z then + return + end + + local foo = x + local bar = y + local baz = z + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({11, 24}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({12, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_if_not_x_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + do + if not x then + return + end + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_isnt_guaranteed_to_run_first") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + while math.random() > 0.5 do + if not x then + return + end + + local foo = x + end + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_is_guaranteed_to_run_first") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + repeat + if not x then + return + end + + local foo = x + until math.random() > 0.5 + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); // TODO: This is wrong, should be `string`. +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "early_return_in_a_loop_which_is_guaranteed_to_run_first_2") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + for i = 1, 10 do + if not x then + return + end + + local foo = x + end + + local bar = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({7, 28}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({10, 24}))); // TODO: This is wrong, should be `string`. +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_then_error") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + if not x then + error("oops") + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_then_assert_false") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + if not x then + assert(false) + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_if_not_y_return") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?, y: string?) + if not x then + return + end + + if not y then + return + end + + local foo = x + local bar = y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireTypeAtPosition({10, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({11, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + if typeof(x) == "string" then + return + else + type Foo = number + end + + local foo: Foo = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Unknown type 'Foo'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({8, 29}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + // In CGB, we walk the block to prototype aliases. We then visit the block in-order, which will resolve the prototype to a real type. + // That second walk assumes that the name occurs in the same `Scope` that the prototype walk had. If we arbitrarily change scope midway + // through, we'd invoke UB. + CheckResult result = check(R"( + local function f(x: string?) + type Foo = number + + if typeof(x) == "string" then + return + end + + local foo: Foo = x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'nil' could not be converted into 'number'", toString(result.errors[0])); + + CHECK_EQ("nil", toString(requireTypeAtPosition({8, 29}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", error: E } + type Result = Ok | Err + + local function map(result: Result, f: (T) -> U): Result + if result.tag == "ok" then + local tag = result.tag + local val = result.value + + return { tag = "ok", value = f(result.value) } + end + + local tag = result.tag + local err = result.error + + return result + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("\"ok\"", toString(requireTypeAtPosition({7, 35}))); + CHECK_EQ("T", toString(requireTypeAtPosition({8, 35}))); + + CHECK_EQ("\"err\"", toString(requireTypeAtPosition({13, 31}))); + CHECK_EQ("E", toString(requireTypeAtPosition({14, 31}))); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("{| error: E, tag: \"err\" |}", toString(requireTypeAtPosition({16, 19}))); + else + CHECK_EQ("Err", toString(requireTypeAtPosition({16, 19}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "do_assert_x") +{ + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + + CheckResult result = check(R"( + local function f(x: string?) + do + assert(x) + end + + local foo = x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 50e9f80..511cbc7 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -689,4 +689,22 @@ TEST_CASE_FIXTURE(Fixture, "for_loop_lower_bound_is_string_3") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil") +{ + CheckResult result = check(R"( + local function makeEnum(members) + local enum = {} + for _, memberName in ipairs(members) do + enum[memberName] = memberName + end + return enum + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + // HACK (CLI-68453): We name this inner table `enum`. For now, use the + // exhaustive switch to see past it. + CHECK(toString(requireType("makeEnum"), {true}) == "({a}) -> {| [a]: a |}"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index ab07ee2..8670729 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -485,6 +485,8 @@ return unpack(l0[_]) TEST_CASE_FIXTURE(BuiltinsFixture, "check_imported_module_names") { + ScopedFastFlag sff{"LuauTinyControlFlowAnalysis", true}; + fileResolver.source["game/A"] = R"( return function(...) end )"; @@ -506,19 +508,10 @@ return l0 ModulePtr mod = getMainModule(); REQUIRE(mod); - if (FFlag::DebugLuauDeferredConstraintResolution) - { - REQUIRE(mod->scopes.size() >= 4); - CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); - CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); - } - else - { - REQUIRE(mod->scopes.size() >= 3); - CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); - CHECK(mod->scopes[2].second->importedModules["l1"] == "game/A"); - } + REQUIRE(mod->scopes.size() == 4); + CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); + CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index ab41ce3..0f540f6 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -309,4 +309,21 @@ TEST_CASE_FIXTURE(Fixture, "dont_bind_free_tables_to_themselves") )"); } +// We should probably flag an error on this. See CLI-68672 +TEST_CASE_FIXTURE(BuiltinsFixture, "flag_when_index_metamethod_returns_0_values") +{ + CheckResult result = check(R"( + local T = {} + function T.__index() + end + + local a = setmetatable({}, T) + local p = a.prop + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK("nil" == toString(requireType("p"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index dcdc2e3..720784c 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -1109,4 +1109,28 @@ local f1 = f or 'f' CHECK("string" == toString(requireType("f1"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "reducing_and") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + {"LuauReducingAndOr", true}, + }; + + CheckResult result = check(R"( +type Foo = { name: string?, flag: boolean? } +local arr: {Foo} = {} + +local function foo(arg: {name: string}?) + local name = if arg and arg.name then arg.name else nil + + table.insert(arr, { + name = name or "", + flag = name ~= nil and name ~= "", + }) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 0aacb8a..30f77d6 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -470,6 +470,10 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + TypeArena arena; TypeId nilType = builtinTypes->nilType; @@ -488,7 +492,7 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") u.tryUnify(option1, option2); - CHECK(u.errors.empty()); + CHECK(!u.failure); u.log.commit(); @@ -548,7 +552,10 @@ return wrapStrictTable(Constants, "Constants") std::optional result = first(m->returnType); REQUIRE(result); - CHECK(get(*result)); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("(any?) & ~table", toString(*result)); + else + CHECK_MESSAGE(get(*result), *result); } TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 064ec16..890e9b6 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1615,7 +1615,8 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") )"); LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ("*error-type*", toString(requireTypeAtPosition({4, 30}))); + + CHECK_EQ("~false & ~nil", toString(requireTypeAtPosition({4, 30}))); } TEST_CASE_FIXTURE(BuiltinsFixture, "dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 0f5e3d3..21ac642 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1590,8 +1590,16 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi" )"); - // This typechecks but shouldn't - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Type 'string' could not be converted into 'number'"); + } + else + { + // This typechecks but shouldn't + LUAU_REQUIRE_NO_ERRORS(result); + } } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 417f80a..7c4bfb2 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -103,6 +103,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "obvious_type_error_in_nocheck_mode") +{ + CheckResult result = check(R"( + --!nocheck + local x: string = 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "expr_statement") { CheckResult result = check("local foo = 5 foo()"); @@ -1185,6 +1195,9 @@ TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_ ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, {"LuauTinyUnifyNormalsFix", true}, + // If we run this with error-suppression, it triggers an assertion. + // FATAL ERROR: Assertion failed: !"Internal error: Trying to normalize a BlockedType" + {"LuauTransitiveSubtyping", false}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 66e0701..5a9c77d 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -27,16 +27,25 @@ TEST_SUITE_BEGIN("TryUnifyTests"); TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}}; Type numberTwo = numberOne; state.tryUnify(&numberTwo, &numberOne); + CHECK(!state.failure); CHECK(state.errors.empty()); } TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + Type functionOne{ TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; @@ -44,6 +53,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; state.tryUnify(&functionTwo, &functionOne); + CHECK(!state.failure); CHECK(state.errors.empty()); state.log.commit(); @@ -53,6 +63,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; Type functionOne{ TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; @@ -66,6 +80,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") Type functionTwoSaved = functionTwo; state.tryUnify(&functionTwo, &functionOne); + CHECK(state.failure); CHECK(!state.errors.empty()); CHECK_EQ(functionOne, functionOneSaved); @@ -74,6 +89,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + Type tableOne{TypeVariant{ TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; @@ -86,6 +105,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") state.tryUnify(&tableTwo, &tableOne); + CHECK(!state.failure); CHECK(state.errors.empty()); state.log.commit(); @@ -95,6 +115,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + Type tableOne{TypeVariant{ TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, @@ -109,6 +133,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") state.tryUnify(&tableTwo, &tableOne); + CHECK(state.failure); CHECK_EQ(1, state.errors.size()); CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); @@ -218,6 +243,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") TypePackVar variadicPack{VariadicTypePack{builtinTypes->numberType}}; state.tryUnify(&testPack, &variadicPack); + CHECK(state.failure); CHECK(!state.errors.empty()); } @@ -228,6 +254,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") TypePackVar b{TypePack{{builtinTypes->numberType, builtinTypes->stringType}, &variadicPack}}; state.tryUnify(&b, &a); + CHECK(!state.failure); CHECK(state.errors.empty()); } @@ -270,8 +297,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") arena.addTypePack(TypePack{{builtinTypes->numberType, builtinTypes->numberType, builtinTypes->numberType}, std::nullopt}); TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{builtinTypes->numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); - ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); - CHECK(unifyErrors.size() == 0); + CHECK(state.canUnify(numberAndFreeTail, threeNumbers).empty()); } TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") @@ -321,7 +347,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table") { - ScopedFastFlag sff("DebugLuauDeferredConstraintResolution", true); + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + {"DebugLuauDeferredConstraintResolution", true}, + }; TableType::Props freeProps{ {"foo", {builtinTypes->numberType}}, diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 704e2a3..d49f004 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -715,4 +715,62 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types_2") CHECK_EQ("({| x: number |} | {| x: string |}) -> number | string", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "union_table_any_property") +{ + CheckResult result = check(R"( + function f(x) + -- x : X + -- sup : { p : { q : X } }? + local sup = if true then { p = { q = x } } else nil + local sub : { p : any } + sup = nil + sup = sub + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "union_function_any_args") +{ + CheckResult result = check(R"( + local sup : ((...any) -> (...any))? + local sub : ((number) -> (...any)) + sup = sub + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_any") +{ + CheckResult result = check(R"( + local sup : any? + local sub : number + sup = sub + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_function_with_optional_arg") +{ + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + }; + + CheckResult result = check(R"( + function f(x : T?) : {T} + local result = {} + if x then + result[1] = x + end + return result + end + local t : {string} = f(nil) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index f17ada2..410fd52 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -303,6 +303,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i { ScopedFastFlag sff[]{ {"LuauTryhardAnd", true}, + {"LuauReducingAndOr", true}, }; CheckResult result = check(R"( @@ -313,13 +314,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); - else - { - // Widening doesn't normalize yet, so the result is a bit strange - CHECK_EQ("(nil, a) -> boolean | boolean", toString(requireType("ord"))); - } + CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); } TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.lua index 4b47ed2..596eed3 100644 --- a/tests/conformance/tables.lua +++ b/tests/conformance/tables.lua @@ -578,6 +578,21 @@ do assert(#t2 == 6) end +-- test boundary invariant in sparse arrays or various kinds +do + local function obscuredalloc() return {} end + + local bits = 16 + + for i = 1, 2^bits - 1 do + local t1 = obscuredalloc() -- to avoid NEWTABLE guessing correct size + + for k = 1, bits do + t1[k] = if bit32.extract(i, k - 1) == 1 then true else nil + end + end +end + -- test table.unpack fastcall for rejecting large unpacks do local ok, res = pcall(function() diff --git a/tools/faillist.txt b/tools/faillist.txt index bcc1777..d513af1 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -31,8 +31,6 @@ BuiltinTests.table_pack_reduce BuiltinTests.table_pack_variadic DefinitionTests.class_definition_overload_metamethods DefinitionTests.class_definition_string_props -FrontendTest.environments -FrontendTest.nocheck_cycle_used_by_checked GenericsTests.apply_type_function_nested_generics2 GenericsTests.better_mismatch_error_messages GenericsTests.bound_tables_do_not_clone_original_fields @@ -54,19 +52,6 @@ GenericsTests.self_recursive_instantiated_param IntersectionTypes.table_intersection_write_sealed IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect -ModuleTests.clone_self_property -NonstrictModeTests.for_in_iterator_variables_are_any -NonstrictModeTests.function_parameters_are_any -NonstrictModeTests.inconsistent_module_return_types_are_ok -NonstrictModeTests.inconsistent_return_types_are_ok -NonstrictModeTests.infer_nullary_function -NonstrictModeTests.infer_the_maximum_number_of_values_the_function_could_return -NonstrictModeTests.inline_table_props_are_also_any -NonstrictModeTests.local_tables_are_not_any -NonstrictModeTests.locals_are_any_by_default -NonstrictModeTests.offer_a_hint_if_you_use_a_dot_instead_of_a_colon -NonstrictModeTests.parameters_having_type_any_are_optional -NonstrictModeTests.table_props_are_any ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.bail_early_if_unification_is_too_complicated ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack @@ -85,9 +70,7 @@ RefinementTest.typeguard_in_assert_position RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RuntimeLimits.typescript_port_of_Result_type TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible -TableTests.any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode TableTests.casting_tables_with_props_into_table_with_indexer3 -TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early TableTests.disallow_indexing_into_an_unsealed_table_with_no_indexer_in_strict_mode TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar @@ -117,7 +100,6 @@ TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys TableTests.nil_assign_doesnt_hit_indexer TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr -TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_polymorphic TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table @@ -138,7 +120,6 @@ ToString.named_metatable_toStringNamedFunction ToString.toStringDetailed2 ToString.toStringErrorPack ToString.toStringNamedFunction_generic_pack -ToString.toStringNamedFunction_map TryUnifyTests.members_of_failed_typepack_unification_are_unified_with_errorType TryUnifyTests.result_of_failed_typepack_unification_is_constrained TryUnifyTests.typepack_unification_should_trim_free_tails @@ -154,15 +135,11 @@ TypeAliases.type_alias_local_rename TypeAliases.type_alias_locations TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeInfer.check_type_infer_recursion_count -TypeInfer.checking_should_not_ice TypeInfer.cli_50041_committing_txnlog_in_apollo_client_error TypeInfer.dont_report_type_errors_within_an_AstExprError TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.fuzz_free_table_type_change_during_index_check -TypeInfer.globals -TypeInfer.globals2 TypeInfer.infer_assignment_value_types_mutable_lval -TypeInfer.it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict TypeInfer.no_stack_overflow_from_isoptional TypeInfer.no_stack_overflow_from_isoptional2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error @@ -173,17 +150,13 @@ TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.index_instance_property TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferClasses.warn_when_prop_almost_matches -TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.cannot_hoist_interior_defns_into_signature -TypeInferFunctions.check_function_before_lambda_that_uses_it TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists -TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict TypeInferFunctions.function_cast_error_uses_correct_language TypeInferFunctions.function_decl_non_self_sealed_overwrite_2 TypeInferFunctions.function_decl_non_self_unsealed_overwrite TypeInferFunctions.function_does_not_return_enough_values TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer -TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_that_function_does_not_return_a_table @@ -191,7 +164,6 @@ TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type TypeInferFunctions.occurs_check_failure_in_function_return_type TypeInferFunctions.record_matching_overload -TypeInferFunctions.report_exiting_without_return_nonstrict TypeInferFunctions.report_exiting_without_return_strict TypeInferFunctions.return_type_by_overload TypeInferFunctions.too_few_arguments_variadic @@ -204,11 +176,9 @@ TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_va TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next TypeInferLoops.loop_iter_metamethod_ok_with_inference -TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.unreachable_code_after_infinite_loop -TypeInferModules.custom_require_global TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated @@ -220,18 +190,14 @@ TypeInferOperators.cli_38355_recursive_union TypeInferOperators.compound_assign_metatable TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops -TypeInferOperators.in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators -TypeInferOperators.infer_any_in_all_modes_when_lhs_is_unknown TypeInferOperators.operator_eq_completely_incompatible TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection TypeInferOperators.typecheck_overloaded_multiply_that_is_an_intersection_on_rhs TypeInferOperators.typecheck_unary_len_error -TypeInferOperators.UnknownGlobalCompoundAssign TypeInferOperators.unrelated_classes_cannot_be_compared TypeInferOperators.unrelated_primitives_cannot_be_compared TypeInferPrimitives.CheckMethodsOfNumber TypeInferPrimitives.string_index -TypeInferUnknownNever.assign_to_global_which_is_never TypeInferUnknownNever.dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators TypeInferUnknownNever.math_operators_and_never TypePackTests.detect_cyclic_typepacks2 @@ -250,6 +216,7 @@ TypeSingletons.table_properties_type_error_escapes TypeSingletons.taking_the_length_of_union_of_string_singleton TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeSingletons.widening_happens_almost_everywhere +UnionTypes.generic_function_with_optional_arg UnionTypes.index_on_a_union_type_with_missing_property UnionTypes.optional_assignment_errors UnionTypes.optional_call_error