diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 059e97c..fe6a025 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -191,6 +191,9 @@ private: **/ void unblock_(BlockedConstraintId progressed); + TypeId errorRecoveryType() const; + TypePackId errorRecoveryTypePack() const; + ToStringOptions opts; }; diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 4c81d33..eab6a21 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -95,9 +95,11 @@ struct CountMismatch Return, }; size_t expected; + std::optional maximum; size_t actual; Context context = Arg; bool isVariadic = false; + std::string function; bool operator==(const CountMismatch& rhs) const; }; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 61e07e9..dd2d709 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -92,6 +92,8 @@ inline std::string toString(const Constraint& c) return toString(c, ToStringOptions{}); } +std::string toString(const LValue& lvalue); + std::string toString(const TypeVar& tv, ToStringOptions& opts); std::string toString(const TypePackVar& tp, ToStringOptions& opts); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index b0b3f3a..bbb9bd6 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -127,8 +127,8 @@ struct TypeChecker std::optional originalNameLoc, std::optional selfType, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); - void checkArgumentList( - const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); + void checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId paramPack, TypePackId argPack, + const std::vector& argLocations); WithPredicate checkExprPack(const ScopePtr& scope, const AstExpr& expr); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 8269230..2968809 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -185,6 +185,9 @@ std::pair, std::optional> flatten(TypePackId tp, bool isVariadic(TypePackId tp); bool isVariadic(TypePackId tp, const TxnLog& log); +// Returns true if the TypePack is Generic or Variadic. Does not walk TypePacks!! +bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadics = false); + bool containsNever(TypePackId tp); } // namespace Luau diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 6890f88..efc1d88 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -23,6 +23,6 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro TypeId type, const std::string& prop, const Location& location, bool addErrors, InternalErrorReporter& handle); // Returns the minimum and maximum number of types the argument list can accept. -std::pair> getParameterExtents(const TxnLog* log, TypePackId tp); +std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index e9c61e4..6a65ab9 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -729,6 +729,11 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp if (AstExprCall* call = expr->as()) { + TypeId fnType = check(scope, call->func); + + const size_t constraintIndex = scope->constraints.size(); + const size_t scopeIndex = scopes.size(); + std::vector args; for (AstExpr* arg : call->args) @@ -738,7 +743,8 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp // TODO self - TypeId fnType = check(scope, call->func); + const size_t constraintEndIndex = scope->constraints.size(); + const size_t scopeEndIndex = scopes.size(); astOriginalCallTypes[call->func] = fnType; @@ -753,7 +759,23 @@ TypePackId ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* exp scope->unqueuedConstraints.push_back( std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); + NotNull sc(scope->unqueuedConstraints.back().get()); + + // We force constraints produced by checking function arguments to wait + // until after we have resolved the constraint on the function itself. + // This ensures, for instance, that we start inferring the contents of + // lambdas under the assumption that their arguments and return types + // will be compatible with the enclosing function call. + for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) + scope->constraints[ci]->dependencies.push_back(sc); + + for (size_t si = scopeIndex; si < scopeEndIndex; ++si) + { + for (auto& c : scopes[si].second->constraints) + { + c->dependencies.push_back(sc); + } + } addConstraint(scope, call->func->location, FunctionCallConstraint{ @@ -1080,7 +1102,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS signatureScope = bodyScope; } - std::optional varargPack; + TypePackId varargPack = nullptr; if (fn->vararg) { @@ -1096,6 +1118,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS signatureScope->varargPack = varargPack; } + else + { + varargPack = arena->addTypePack(VariadicTypePack{singletonTypes->anyType, /*hidden*/ true}); + // We do not add to signatureScope->varargPack because ... is not valid + // in functions without an explicit ellipsis. + } + + LUAU_ASSERT(nullptr != varargPack); if (fn->returnAnnotation) { diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index f964a85..6fb57b1 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -240,6 +240,12 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); printf("\t%d\t\t%s\n", blockCount, toString(*dep, opts).c_str()); } + + if (auto fcc = get(*c)) + { + for (NotNull inner : fcc->innerConstraints) + printf("\t\t\t%s\n", toString(*inner, opts).c_str()); + } } } @@ -522,7 +528,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(singletonTypes->errorRecoveryType()); + asMutable(resultType)->ty.emplace(errorRecoveryType()); // reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation}); return true; } @@ -574,10 +580,24 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope, singletonTypes, &iceReporter, singletonTypes->errorRecoveryType(), singletonTypes->errorRecoveryTypePack()}; + Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; std::optional anyified = anyify.substitute(c.variables); LUAU_ASSERT(anyified); unify(*anyified, c.variables, constraint->scope); @@ -704,7 +724,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (!tf.has_value()) { reportError(UnknownSymbol{petv->name.value, UnknownSymbol::Context::Type}, constraint->location); - bindResult(singletonTypes->errorRecoveryType()); + bindResult(errorRecoveryType()); return true; } @@ -763,7 +783,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (itf.foundInfiniteType) { // TODO (CLI-56761): Report an error. - bindResult(singletonTypes->errorRecoveryType()); + bindResult(errorRecoveryType()); return true; } @@ -786,7 +806,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (!maybeInstantiated.has_value()) { // TODO (CLI-56761): Report an error. - bindResult(singletonTypes->errorRecoveryType()); + bindResult(errorRecoveryType()); return true; } @@ -863,13 +883,21 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicFunction(NotNull(this), result, c.astFragment); } - if (!usedMagic) + if (usedMagic) { - for (const auto& inner : c.innerConstraints) - { - unsolvedConstraints.push_back(inner); - } - + // There are constraints that are blocked on these constraints. If we + // are never going to even examine them, then we should not block + // anything else on them. + // + // TODO CLI-58842 +#if 0 + for (auto& c: c.innerConstraints) + unblock(c); +#endif + } + else + { + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); asMutable(c.result)->ty.emplace(constraint->scope); } @@ -909,8 +937,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl }; auto errorify = [&](auto ty) { - Anyification anyify{ - arena, constraint->scope, singletonTypes, &iceReporter, singletonTypes->errorRecoveryType(), singletonTypes->errorRecoveryTypePack()}; + Anyification anyify{arena, constraint->scope, singletonTypes, &iceReporter, errorRecoveryType(), errorRecoveryTypePack()}; std::optional errorified = anyify.substitute(ty); if (!errorified) reportError(CodeTooComplex{}, constraint->location); @@ -1119,7 +1146,7 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc if (!u.errors.empty()) { - TypeId errorType = singletonTypes->errorRecoveryType(); + TypeId errorType = errorRecoveryType(); u.tryUnify(subType, errorType); u.tryUnify(superType, errorType); } @@ -1160,7 +1187,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l if (info.name.empty()) { reportError(UnknownRequire{}, location); - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); } std::string humanReadableName = moduleResolver->getHumanReadableModuleName(info.name); @@ -1177,24 +1204,24 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l if (!moduleResolver->moduleExists(info.name) && !info.optional) reportError(UnknownRequire{humanReadableName}, location); - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); } if (module->type != SourceCode::Type::Module) { reportError(IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}, location); - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); } TypePackId modulePack = module->getModuleScope()->returnType; if (get(modulePack)) - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); std::optional moduleType = first(modulePack); if (!moduleType) { reportError(IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}, location); - return singletonTypes->errorRecoveryType(); + return errorRecoveryType(); } return *moduleType; @@ -1212,4 +1239,14 @@ void ConstraintSolver::reportError(TypeError e) errors.back().moduleName = currentModuleName; } +TypeId ConstraintSolver::errorRecoveryType() const +{ + return singletonTypes->errorRecoveryType(); +} + +TypePackId ConstraintSolver::errorRecoveryTypePack() const +{ + return singletonTypes->errorRecoveryTypePack(); +} + } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 93cb65b..9ecdc82 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -10,7 +10,8 @@ LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleNameResolution, false) LUAU_FASTFLAGVARIABLE(LuauUseInternalCompilerErrorException, false) -static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) +static std::string wrongNumberOfArgsString( + size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { std::string s = "expects "; @@ -19,11 +20,14 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo s += std::to_string(expectedCount) + " "; + if (maximumCount && expectedCount != *maximumCount) + s += "to " + std::to_string(*maximumCount) + " "; + if (argPrefix) s += std::string(argPrefix) + " "; s += "argument"; - if (expectedCount != 1) + if ((maximumCount ? *maximumCount : expectedCount) != 1) s += "s"; s += ", but "; @@ -185,7 +189,12 @@ struct ErrorConverter return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual, /*argPrefix*/ nullptr, e.isVariadic); + if (!e.function.empty()) + return "Argument count mismatch. Function '" + e.function + "' " + + wrongNumberOfArgsString(e.expected, e.maximum, e.actual, /*argPrefix*/ nullptr, e.isVariadic); + else + return "Argument count mismatch. Function " + + wrongNumberOfArgsString(e.expected, e.maximum, e.actual, /*argPrefix*/ nullptr, e.isVariadic); } LUAU_ASSERT(!"Unknown context"); @@ -247,10 +256,10 @@ struct ErrorConverter if (e.typeFun.typeParams.size() != e.actualParameters) return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + wrongNumberOfArgsString(e.typeFun.typeParams.size(), std::nullopt, e.actualParameters, "type", !e.typeFun.typePackParams.empty()); return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), std::nullopt, e.actualPackParameters, "type pack", /*isVariadic*/ false); } std::string operator()(const Luau::SyntaxError& e) const @@ -547,7 +556,7 @@ bool DuplicateTypeDefinition::operator==(const DuplicateTypeDefinition& rhs) con bool CountMismatch::operator==(const CountMismatch& rhs) const { - return expected == rhs.expected && actual == rhs.actual && context == rhs.context; + return expected == rhs.expected && maximum == rhs.maximum && actual == rhs.actual && context == rhs.context && function == rhs.function; } bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 62b7afd..4ec2371 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1529,4 +1529,20 @@ std::string dump(const Constraint& c) return s; } +std::string toString(const LValue& lvalue) +{ + std::string s; + for (const LValue* current = &lvalue; current; current = baseof(*current)) + { + if (auto field = get(*current)) + s = "." + field->key + s; + else if (auto symbol = get(*current)) + s = toString(*symbol) + s; + else + LUAU_ASSERT(!"Unknown LValue"); + } + + return s; +} + } // namespace Luau diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 88363b4..76b27ac 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -408,7 +408,7 @@ struct TypeChecker2 return; NotNull scope = stack.back(); - TypeArena tempArena; + TypeArena& arena = module->internalTypes; std::vector variableTypes; for (AstLocal* var : forInStatement->vars) @@ -424,10 +424,10 @@ struct TypeChecker2 for (size_t i = 0; i < forInStatement->values.size - 1; ++i) valueTypes.emplace_back(lookupType(forInStatement->values.data[i])); TypePackId iteratorTail = lookupPack(forInStatement->values.data[forInStatement->values.size - 1]); - TypePackId iteratorPack = tempArena.addTypePack(valueTypes, iteratorTail); + TypePackId iteratorPack = arena.addTypePack(valueTypes, iteratorTail); // ... and then expand it out to 3 values (if possible) - const std::vector iteratorTypes = flatten(tempArena, iteratorPack, 3); + const std::vector iteratorTypes = flatten(arena, iteratorPack, 3); if (iteratorTypes.empty()) { reportError(GenericError{"for..in loops require at least one value to iterate over. Got zero"}, getLocation(forInStatement->values)); @@ -456,7 +456,7 @@ struct TypeChecker2 reportError(GenericError{"for..in loops must be passed (next, [table[, state]])"}, getLocation(forInStatement->values)); // It is okay if there aren't enough iterators, but the iteratee must provide enough. - std::vector expectedVariableTypes = flatten(tempArena, nextFn->retTypes, variableTypes.size()); + std::vector expectedVariableTypes = flatten(arena, nextFn->retTypes, variableTypes.size()); if (expectedVariableTypes.size() < variableTypes.size()) reportError(GenericError{"next() does not return enough values"}, forInStatement->vars.data[0]->location); @@ -475,23 +475,23 @@ struct TypeChecker2 // If iteratorTypes is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too short to be a valid call to nextFn, we have to report a count mismatch error. // If 2 is too long to be a valid call to nextFn, we have to report a count mismatch error. - auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), nextFn->argTypes); + auto [minCount, maxCount] = getParameterExtents(TxnLog::empty(), nextFn->argTypes, /*includeHiddenVariadics*/ true); if (minCount > 2) - reportError(CountMismatch{2, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, minCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); if (maxCount && *maxCount < 2) - reportError(CountMismatch{2, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, *maxCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); - const std::vector flattenedArgTypes = flatten(tempArena, nextFn->argTypes, 2); + const std::vector flattenedArgTypes = flatten(arena, nextFn->argTypes, 2); const auto [argTypes, argsTail] = Luau::flatten(nextFn->argTypes); size_t firstIterationArgCount = iteratorTypes.empty() ? 0 : iteratorTypes.size() - 1; size_t actualArgCount = expectedVariableTypes.size(); if (firstIterationArgCount < minCount) - reportError(CountMismatch{2, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, firstIterationArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); else if (actualArgCount < minCount) - reportError(CountMismatch{2, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); + reportError(CountMismatch{2, std::nullopt, actualArgCount, CountMismatch::Arg}, forInStatement->vars.data[0]->location); if (iteratorTypes.size() >= 2 && flattenedArgTypes.size() > 0) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index c140819..b6f20ac 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,13 +32,15 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) +LUAU_FASTFLAGVARIABLE(LuauFunctionArgMismatchDetails, false) LUAU_FASTFLAGVARIABLE(LuauInplaceDemoteSkipAllBound, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix3, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); +LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) +LUAU_FASTFLAGVARIABLE(LuauCallUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) @@ -1346,7 +1348,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) } Unifier state = mkUnifier(loopScope, firstValue->location); - checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); + checkArgumentList(loopScope, *firstValue, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); state.log.commit(); @@ -3666,8 +3668,8 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope } } -void TypeChecker::checkArgumentList( - const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) +void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funName, Unifier& state, TypePackId argPack, TypePackId paramPack, + const std::vector& argLocations) { /* Important terminology refresher: * A function requires parameters. @@ -3679,14 +3681,27 @@ void TypeChecker::checkArgumentList( size_t paramIndex = 0; - auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack]() { + auto reportCountMismatchError = [&state, &argLocations, paramPack, argPack, &funName]() { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - size_t minParams = getParameterExtents(&state.log, paramPack).first; - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + if (FFlag::LuauFunctionArgMismatchDetails) + { + std::string namePath; + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + + auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); + state.reportError(TypeError{location, + CountMismatch{minParams, optMaxParams, std::distance(begin(argPack), end(argPack)), CountMismatch::Context::Arg, false, namePath}}); + } + else + { + size_t minParams = getParameterExtents(&state.log, paramPack).first; + state.reportError(TypeError{location, CountMismatch{minParams, std::nullopt, std::distance(begin(argPack), end(argPack))}}); + } }; while (true) @@ -3698,11 +3713,8 @@ void TypeChecker::checkArgumentList( std::optional argTail = argIter.tail(); std::optional paramTail = paramIter.tail(); - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. + // If we hit the end of both type packs simultaneously, we have to unify them. + // But if one side has a free tail and the other has none at all, we create an empty pack and bind the free tail to that. if (argTail) { @@ -3713,6 +3725,10 @@ void TypeChecker::checkArgumentList( else state.log.replace(*argTail, TypePackVar(TypePack{{}})); } + else if (FFlag::LuauCallUnifyPackTails && paramTail) + { + state.tryUnify(*argTail, *paramTail); + } } else if (paramTail) { @@ -3784,12 +3800,25 @@ void TypeChecker::checkArgumentList( } // ok else { - size_t minParams = getParameterExtents(&state.log, paramPack).first; + auto [minParams, optMaxParams] = getParameterExtents(&state.log, paramPack); std::optional tail = flatten(paramPack, state.log).second; bool isVariadic = tail && Luau::isVariadic(*tail); - state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); + if (FFlag::LuauFunctionArgMismatchDetails) + { + std::string namePath; + if (std::optional lValue = tryGetLValue(funName)) + namePath = toString(*lValue); + + state.reportError(TypeError{ + state.location, CountMismatch{minParams, optMaxParams, paramIndex, CountMismatch::Context::Arg, isVariadic, namePath}}); + } + else + { + state.reportError( + TypeError{state.location, CountMismatch{minParams, std::nullopt, paramIndex, CountMismatch::Context::Arg, isVariadic}}); + } return; } ++paramIter; @@ -4185,13 +4214,13 @@ std::optional> TypeChecker::checkCallOverload(const Sc Unifier state = mkUnifier(scope, expr.location); // Unify return types - checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {}); + checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, /*argLocations*/ {}); if (!state.errors.empty()) { return {}; } - checkArgumentList(scope, state, argPack, ftv->argTypes, *argLocations); + checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, *argLocations); if (!state.errors.empty()) { @@ -4245,7 +4274,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4276,7 +4305,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal Unifier editedState = mkUnifier(scope, expr.location); - checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations); if (editedState.errors.empty()) { @@ -4345,8 +4374,8 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast // Unify return types if (const FunctionTypeVar* ftv = get(overload)) { - checkArgumentList(scope, state, retPack, ftv->retTypes, {}); - checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + checkArgumentList(scope, *expr.func, state, retPack, ftv->retTypes, {}); + checkArgumentList(scope, *expr.func, state, argPack, ftv->argTypes, argLocations); } if (state.errors.empty()) diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 2fa9413..0fa4df6 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -357,10 +357,15 @@ bool isVariadic(TypePackId tp, const TxnLog& log) if (!tail) return false; - if (log.get(*tail)) + return isVariadicTail(*tail, log); +} + +bool isVariadicTail(TypePackId tp, const TxnLog& log, bool includeHiddenVariadics) +{ + if (log.get(tp)) return true; - if (auto vtp = log.get(*tail); vtp && !vtp->hidden) + if (auto vtp = log.get(tp); vtp && (includeHiddenVariadics || !vtp->hidden)) return true; return false; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index a96820d..6ea04ea 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -6,6 +6,8 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" +LUAU_FASTFLAG(LuauFunctionArgMismatchDetails) + namespace Luau { @@ -193,7 +195,7 @@ std::optional getIndexTypeFromType(const ScopePtr& scope, ErrorVec& erro return std::nullopt; } -std::pair> getParameterExtents(const TxnLog* log, TypePackId tp) +std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics) { size_t minCount = 0; size_t optionalCount = 0; @@ -216,7 +218,7 @@ std::pair> getParameterExtents(const TxnLog* log, ++it; } - if (it.tail()) + if (it.tail() && (!FFlag::LuauFunctionArgMismatchDetails || isVariadicTail(*it.tail(), *log, includeHiddenVariadics))) return {minCount, std::nullopt}; else return {minCount, minCount + optionalCount}; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4f6603f..3a820ea 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -1133,14 +1133,14 @@ std::optional> magicFunctionFormat( size_t numExpectedParams = expected.size() + 1; // + 1 for the format string if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, numActualParams}}); + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); } else { size_t actualParamSize = params.size() - paramOffset; if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) - typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); + typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), std::nullopt, actualParamSize}}); } return WithPredicate{arena.addTypePack({typechecker.stringType})}; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index fd67843..505e9e4 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) +LUAU_FASTFLAG(LuauCallUnifyPackTails) namespace Luau { @@ -1159,7 +1160,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(TypeError{location, CountMismatch{expectedSize, std::nullopt, actualSize, ctx}}); while (superIter.good()) { @@ -2118,6 +2119,15 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel de for (; superIter != superEndIter; ++superIter) tp->head.push_back(*superIter); } + else if (const VariadicTypePack* subVariadic = log.getMutable(subTailPack); + subVariadic && FFlag::LuauCallUnifyPackTails) + { + while (superIter != superEndIter) + { + tryUnify_(subVariadic->ty, *superIter); + ++superIter; + } + } } else { @@ -2125,7 +2135,7 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel de { if (!isOptional(*superIter)) { - errors.push_back(TypeError{location, CountMismatch{size(superTy), size(subTy), CountMismatch::Return}}); + errors.push_back(TypeError{location, CountMismatch{size(superTy), std::nullopt, size(subTy), CountMismatch::Return}}); return; } ++superIter; diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index c80b5c3..01e1312 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -24,7 +24,7 @@ struct CodeAllocator void* context = nullptr; // Called when new block is created to create and setup the unwinding information for all the code in the block - // Some platforms require this data to be placed inside the block itself, so we also return 'unwindDataSizeInBlock' + // If data is placed inside the block itself (some platforms require this), we also return 'unwindDataSizeInBlock' void* (*createBlockUnwindInfo)(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) = nullptr; // Called to destroy unwinding information returned by 'createBlockUnwindInfo' diff --git a/CodeGen/include/Luau/CodeBlockUnwind.h b/CodeGen/include/Luau/CodeBlockUnwind.h new file mode 100644 index 0000000..ddae33a --- /dev/null +++ b/CodeGen/include/Luau/CodeBlockUnwind.h @@ -0,0 +1,17 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +// context must be an UnwindBuilder +void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock); +void destroyBlockUnwindInfo(void* context, void* unwindData); + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h new file mode 100644 index 0000000..c6f611b --- /dev/null +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -0,0 +1,35 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterX64.h" + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +class UnwindBuilder +{ +public: + virtual ~UnwindBuilder() {} + + virtual void start() = 0; + + virtual void spill(int espOffset, RegisterX64 reg) = 0; + virtual void save(RegisterX64 reg) = 0; + virtual void allocStack(int size) = 0; + virtual void setupFrameReg(RegisterX64 reg, int espOffset) = 0; + + virtual void finish() = 0; + + virtual size_t getSize() const = 0; + + // This will place the unwinding data at the target address and might update values of some fields + virtual void finalize(char* target, void* funcAddress, size_t funcSize) const = 0; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h new file mode 100644 index 0000000..09c91d4 --- /dev/null +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -0,0 +1,40 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterX64.h" +#include "UnwindBuilder.h" + +namespace Luau +{ +namespace CodeGen +{ + +class UnwindBuilderDwarf2 : public UnwindBuilder +{ +public: + void start() override; + + void spill(int espOffset, RegisterX64 reg) override; + void save(RegisterX64 reg) override; + void allocStack(int size) override; + void setupFrameReg(RegisterX64 reg, int espOffset) override; + + void finish() override; + + size_t getSize() const override; + + void finalize(char* target, void* funcAddress, size_t funcSize) const override; + +private: + static const unsigned kRawDataLimit = 128; + char rawData[kRawDataLimit]; + char* pos = rawData; + + uint32_t stackOffset = 0; + + // We will remember the FDE location to write some of the fields like entry length, function start and size later + char* fdeEntryStart = nullptr; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h new file mode 100644 index 0000000..801eb6e --- /dev/null +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -0,0 +1,51 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/RegisterX64.h" +#include "UnwindBuilder.h" + +#include + +namespace Luau +{ +namespace CodeGen +{ + +// This struct matches the layout of UNWIND_CODE from ehdata.h +struct UnwindCodeWin +{ + uint8_t offset; + uint8_t opcode : 4; + uint8_t opinfo : 4; +}; + +class UnwindBuilderWin : public UnwindBuilder +{ +public: + void start() override; + + void spill(int espOffset, RegisterX64 reg) override; + void save(RegisterX64 reg) override; + void allocStack(int size) override; + void setupFrameReg(RegisterX64 reg, int espOffset) override; + + void finish() override; + + size_t getSize() const override; + + void finalize(char* target, void* funcAddress, size_t funcSize) const override; + +private: + // Windows unwind codes are written in reverse, so we have to collect them all first + std::vector unwindCodes; + + uint8_t prologSize = 0; + RegisterX64 frameReg = rax; // rax means that frame register is not used + uint8_t frameRegOffset = 0; + uint32_t stackOffset = 0; + + size_t infoSize = 0; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index f743206..aacf40a 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -6,8 +6,13 @@ #include #if defined(_WIN32) + +#ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX #define NOMINMAX +#endif #include const size_t kPageSize = 4096; @@ -135,17 +140,29 @@ bool CodeAllocator::allocate( if (codeSize) memcpy(blockPos + codeOffset, code, codeSize); - size_t pageSize = alignToPageSize(unwindInfoSize + totalSize); + size_t pageAlignedSize = alignToPageSize(unwindInfoSize + totalSize); - makePagesExecutable(blockPos, pageSize); + makePagesExecutable(blockPos, pageAlignedSize); flushInstructionCache(blockPos + codeOffset, codeSize); result = blockPos + unwindInfoSize; resultSize = totalSize; resultCodeStart = blockPos + codeOffset; - blockPos += pageSize; - LUAU_ASSERT((uintptr_t(blockPos) & (kPageSize - 1)) == 0); // Allocation ends on page boundary + // Ensure that future allocations from the block start from a page boundary. + // This is important since we use W^X, and writing to the previous page would require briefly removing + // executable bit from it, which may result in access violations if that code is being executed concurrently. + if (pageAlignedSize <= size_t(blockEnd - blockPos)) + { + blockPos += pageAlignedSize; + LUAU_ASSERT((uintptr_t(blockPos) & (kPageSize - 1)) == 0); + LUAU_ASSERT(blockPos <= blockEnd); + } + else + { + // Future allocations will need to allocate fresh blocks + blockPos = blockEnd; + } return true; } diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp new file mode 100644 index 0000000..6191cee --- /dev/null +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/CodeBlockUnwind.h" + +#include "Luau/UnwindBuilder.h" + +#include + +#if defined(_WIN32) && defined(_M_X64) + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include + +#elif !defined(_WIN32) + +// Defined in unwind.h which may not be easily discoverable on various platforms +extern "C" void __register_frame(const void*); +extern "C" void __deregister_frame(const void*); + +#endif + +#if defined(__APPLE__) +// On Mac, each FDE inside eh_frame section has to be handled separately +static void visitFdeEntries(char* pos, void (*cb)(const void*)) +{ + for (;;) + { + unsigned partLength; + memcpy(&partLength, pos, sizeof(partLength)); + + if (partLength == 0) // Zero-length section signals completion + break; + + unsigned partId; + memcpy(&partId, pos + 4, sizeof(partId)); + + if (partId != 0) // Skip CIE part + cb(pos); // CIE is found using an offset in FDE + + pos += partLength + 4; + } +} +#endif + +namespace Luau +{ +namespace CodeGen +{ + +void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, size_t& unwindDataSizeInBlock) +{ +#if defined(_WIN32) && defined(_M_X64) + UnwindBuilder* unwind = (UnwindBuilder*)context; + + // All unwinding related data is placed together at the start of the block + size_t unwindSize = sizeof(RUNTIME_FUNCTION) + unwind->getSize(); + unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes + LUAU_ASSERT(blockSize >= unwindSize); + + RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)block; + runtimeFunc->BeginAddress = DWORD(unwindSize); // Code will start after the unwind info + runtimeFunc->EndAddress = DWORD(blockSize); // Whole block is a part of a 'single function' + runtimeFunc->UnwindInfoAddress = DWORD(sizeof(RUNTIME_FUNCTION)); // Unwind info is placed at the start of the block + + char* unwindData = (char*)block + runtimeFunc->UnwindInfoAddress; + unwind->finalize(unwindData, block + unwindSize, blockSize - unwindSize); + + if (!RtlAddFunctionTable(runtimeFunc, 1, uintptr_t(block))) + { + LUAU_ASSERT(!"failed to allocate function table"); + return nullptr; + } + + unwindDataSizeInBlock = unwindSize; + return block; +#elif !defined(_WIN32) + UnwindBuilder* unwind = (UnwindBuilder*)context; + + // All unwinding related data is placed together at the start of the block + size_t unwindSize = unwind->getSize(); + unwindSize = (unwindSize + 15) & ~15; // Align to 16 bytes + LUAU_ASSERT(blockSize >= unwindSize); + + char* unwindData = (char*)block; + unwind->finalize(unwindData, block, blockSize); + +#if defined(__APPLE__) + visitFdeEntries(unwindData, __register_frame); +#else + __register_frame(unwindData); +#endif + + unwindDataSizeInBlock = unwindSize; + return block; +#endif + + return nullptr; +} + +void destroyBlockUnwindInfo(void* context, void* unwindData) +{ +#if defined(_WIN32) && defined(_M_X64) + RUNTIME_FUNCTION* runtimeFunc = (RUNTIME_FUNCTION*)unwindData; + + if (!RtlDeleteFunctionTable(runtimeFunc)) + LUAU_ASSERT(!"failed to deallocate function table"); +#elif !defined(_WIN32) + +#if defined(__APPLE__) + visitFdeEntries((char*)unwindData, __deregister_frame); +#else + __deregister_frame(unwindData); +#endif + +#endif +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp new file mode 100644 index 0000000..38e3e71 --- /dev/null +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -0,0 +1,253 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/UnwindBuilderDwarf2.h" + +#include + +// General information about Dwarf2 format can be found at: +// https://dwarfstd.org/doc/dwarf-2.0.0.pdf [DWARF Debugging Information Format] +// Main part for async exception unwinding is in section '6.4 Call Frame Information' + +// Information about System V ABI (AMD64) can be found at: +// https://refspecs.linuxbase.org/elf/x86_64-abi-0.99.pdf [System V Application Binary Interface (AMD64 Architecture Processor Supplement)] +// Interaction between Dwarf2 and System V ABI can be found in sections '3.6.2 DWARF Register Number Mapping' and '4.2.4 EH_FRAME sections' + +static char* writeu8(char* target, uint8_t value) +{ + memcpy(target, &value, sizeof(value)); + return target + sizeof(value); +} + +static char* writeu32(char* target, uint32_t value) +{ + memcpy(target, &value, sizeof(value)); + return target + sizeof(value); +} + +static char* writeu64(char* target, uint64_t value) +{ + memcpy(target, &value, sizeof(value)); + return target + sizeof(value); +} + +static char* writeuleb128(char* target, uint64_t value) +{ + do + { + char byte = value & 0x7f; + value >>= 7; + + if (value) + byte |= 0x80; + + *target++ = byte; + } while (value); + + return target; +} + +// Call frame instruction opcodes +#define DW_CFA_advance_loc 0x40 +#define DW_CFA_offset 0x80 +#define DW_CFA_restore 0xc0 +#define DW_CFA_nop 0x00 +#define DW_CFA_set_loc 0x01 +#define DW_CFA_advance_loc1 0x02 +#define DW_CFA_advance_loc2 0x03 +#define DW_CFA_advance_loc4 0x04 +#define DW_CFA_offset_extended 0x05 +#define DW_CFA_restore_extended 0x06 +#define DW_CFA_undefined 0x07 +#define DW_CFA_same_value 0x08 +#define DW_CFA_register 0x09 +#define DW_CFA_remember_state 0x0a +#define DW_CFA_restore_state 0x0b +#define DW_CFA_def_cfa 0x0c +#define DW_CFA_def_cfa_register 0x0d +#define DW_CFA_def_cfa_offset 0x0e +#define DW_CFA_def_cfa_expression 0x0f +#define DW_CFA_expression 0x10 +#define DW_CFA_offset_extended_sf 0x11 +#define DW_CFA_def_cfa_sf 0x12 +#define DW_CFA_def_cfa_offset_sf 0x13 +#define DW_CFA_val_offset 0x14 +#define DW_CFA_val_offset_sf 0x15 +#define DW_CFA_val_expression 0x16 +#define DW_CFA_lo_user 0x1c +#define DW_CFA_hi_user 0x3f + +// Register numbers for x64 +#define DW_REG_RAX 0 +#define DW_REG_RDX 1 +#define DW_REG_RCX 2 +#define DW_REG_RBX 3 +#define DW_REG_RSI 4 +#define DW_REG_RDI 5 +#define DW_REG_RBP 6 +#define DW_REG_RSP 7 +#define DW_REG_R8 8 +#define DW_REG_R9 9 +#define DW_REG_R10 10 +#define DW_REG_R11 11 +#define DW_REG_R12 12 +#define DW_REG_R13 13 +#define DW_REG_R14 14 +#define DW_REG_R15 15 +#define DW_REG_RA 16 + +const int regIndexToDwRegX64[16] = {DW_REG_RAX, DW_REG_RCX, DW_REG_RDX, DW_REG_RBX, DW_REG_RSP, DW_REG_RBP, DW_REG_RSI, DW_REG_RDI, DW_REG_R8, + DW_REG_R9, DW_REG_R10, DW_REG_R11, DW_REG_R12, DW_REG_R13, DW_REG_R14, DW_REG_R15}; + +const int kCodeAlignFactor = 1; +const int kDataAlignFactor = 8; +const int kDwarfAlign = 8; +const int kFdeInitialLocationOffset = 8; +const int kFdeAddressRangeOffset = 16; + +// Define canonical frame address expression as [reg + offset] +static char* defineCfaExpression(char* pos, int dwReg, uint32_t stackOffset) +{ + pos = writeu8(pos, DW_CFA_def_cfa); + pos = writeuleb128(pos, dwReg); + pos = writeuleb128(pos, stackOffset); + return pos; +} + +// Update offset value in canonical frame address expression +static char* defineCfaExpressionOffset(char* pos, uint32_t stackOffset) +{ + pos = writeu8(pos, DW_CFA_def_cfa_offset); + pos = writeuleb128(pos, stackOffset); + return pos; +} + +static char* defineSavedRegisterLocation(char* pos, int dwReg, uint32_t stackOffset) +{ + LUAU_ASSERT(stackOffset % kDataAlignFactor == 0 && "stack offsets have to be measured in kDataAlignFactor units"); + + if (dwReg <= 15) + { + pos = writeu8(pos, DW_CFA_offset + dwReg); + } + else + { + pos = writeu8(pos, DW_CFA_offset_extended); + pos = writeuleb128(pos, dwReg); + } + + pos = writeuleb128(pos, stackOffset / kDataAlignFactor); + return pos; +} + +static char* advanceLocation(char* pos, uint8_t offset) +{ + pos = writeu8(pos, DW_CFA_advance_loc1); + pos = writeu8(pos, offset); + return pos; +} + +static char* alignPosition(char* start, char* pos) +{ + size_t size = pos - start; + size_t pad = ((size + kDwarfAlign - 1) & ~(kDwarfAlign - 1)) - size; + + for (size_t i = 0; i < pad; i++) + pos = writeu8(pos, DW_CFA_nop); + + return pos; +} + +namespace Luau +{ +namespace CodeGen +{ + +void UnwindBuilderDwarf2::start() +{ + char* cieLength = pos; + pos = writeu32(pos, 0); // Length (to be filled later) + + pos = writeu32(pos, 0); // CIE id. 0 -- .eh_frame + pos = writeu8(pos, 1); // Version + + pos = writeu8(pos, 0); // CIE augmentation String "" + + pos = writeuleb128(pos, kCodeAlignFactor); // Code align factor + pos = writeuleb128(pos, -kDataAlignFactor & 0x7f); // Data align factor of (as signed LEB128) + pos = writeu8(pos, DW_REG_RA); // Return address register + + // Optional CIE augmentation section (not present) + + // Call frame instructions (common for all FDEs, of which we have 1) + stackOffset = 8; // Return address was pushed by calling the function + + pos = defineCfaExpression(pos, DW_REG_RSP, stackOffset); // Define CFA to be the rsp + 8 + pos = defineSavedRegisterLocation(pos, DW_REG_RA, 8); // Define return address register (RA) to be located at CFA - 8 + + pos = alignPosition(cieLength, pos); + writeu32(cieLength, unsigned(pos - cieLength - 4)); // Length field itself is excluded from length + + fdeEntryStart = pos; // Will be written at the end + pos = writeu32(pos, 0); // Length (to be filled later) + pos = writeu32(pos, unsigned(pos - rawData)); // CIE pointer + pos = writeu64(pos, 0); // Initial location (to be filled later) + pos = writeu64(pos, 0); // Address range (to be filled later) + + // Optional CIE augmentation section (not present) + + // Function call frame instructions to follow +} + +void UnwindBuilderDwarf2::spill(int espOffset, RegisterX64 reg) +{ + pos = advanceLocation(pos, 5); // REX.W mov [rsp + imm8], reg +} + +void UnwindBuilderDwarf2::save(RegisterX64 reg) +{ + stackOffset += 8; + pos = advanceLocation(pos, 2); // REX.W push reg + pos = defineCfaExpressionOffset(pos, stackOffset); + pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset); +} + +void UnwindBuilderDwarf2::allocStack(int size) +{ + stackOffset += size; + pos = advanceLocation(pos, 4); // REX.W sub rsp, imm8 + pos = defineCfaExpressionOffset(pos, stackOffset); +} + +void UnwindBuilderDwarf2::setupFrameReg(RegisterX64 reg, int espOffset) +{ + // Not required for unwinding +} + +void UnwindBuilderDwarf2::finish() +{ + LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); + + pos = alignPosition(fdeEntryStart, pos); + writeu32(fdeEntryStart, unsigned(pos - fdeEntryStart - 4)); // Length field itself is excluded from length + + // Terminate section + pos = writeu32(pos, 0); + + LUAU_ASSERT(getSize() <= kRawDataLimit); +} + +size_t UnwindBuilderDwarf2::getSize() const +{ + return size_t(pos - rawData); +} + +void UnwindBuilderDwarf2::finalize(char* target, void* funcAddress, size_t funcSize) const +{ + memcpy(target, rawData, getSize()); + + unsigned fdeEntryStartPos = unsigned(fdeEntryStart - rawData); + writeu64(target + fdeEntryStartPos + kFdeInitialLocationOffset, uintptr_t(funcAddress)); + writeu64(target + fdeEntryStartPos + kFdeAddressRangeOffset, funcSize); +} + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp new file mode 100644 index 0000000..5405fcf --- /dev/null +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -0,0 +1,120 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/UnwindBuilderWin.h" + +#include + +// Information about the Windows x64 unwinding data setup can be found at: +// https://docs.microsoft.com/en-us/cpp/build/exception-handling-x64 [x64 exception handling] + +#define UWOP_PUSH_NONVOL 0 +#define UWOP_ALLOC_LARGE 1 +#define UWOP_ALLOC_SMALL 2 +#define UWOP_SET_FPREG 3 +#define UWOP_SAVE_NONVOL 4 +#define UWOP_SAVE_NONVOL_FAR 5 +#define UWOP_SAVE_XMM128 8 +#define UWOP_SAVE_XMM128_FAR 9 +#define UWOP_PUSH_MACHFRAME 10 + +namespace Luau +{ +namespace CodeGen +{ + +// This struct matches the layout of UNWIND_INFO from ehdata.h +struct UnwindInfoWin +{ + uint8_t version : 3; + uint8_t flags : 5; + uint8_t prologsize; + uint8_t unwindcodecount; + uint8_t framereg : 4; + uint8_t frameregoff : 4; +}; + +void UnwindBuilderWin::start() +{ + stackOffset = 8; // Return address was pushed by calling the function + + unwindCodes.reserve(16); +} + +void UnwindBuilderWin::spill(int espOffset, RegisterX64 reg) +{ + prologSize += 5; // REX.W mov [rsp + imm8], reg +} + +void UnwindBuilderWin::save(RegisterX64 reg) +{ + prologSize += 2; // REX.W push reg + stackOffset += 8; + unwindCodes.push_back({prologSize, UWOP_PUSH_NONVOL, reg.index}); +} + +void UnwindBuilderWin::allocStack(int size) +{ + LUAU_ASSERT(size >= 8 && size <= 128 && size % 8 == 0); + + prologSize += 4; // REX.W sub rsp, imm8 + stackOffset += size; + unwindCodes.push_back({prologSize, UWOP_ALLOC_SMALL, uint8_t((size - 8) / 8)}); +} + +void UnwindBuilderWin::setupFrameReg(RegisterX64 reg, int espOffset) +{ + LUAU_ASSERT(espOffset < 256 && espOffset % 16 == 0); + + frameReg = reg; + frameRegOffset = uint8_t(espOffset / 16); + + prologSize += 5; // REX.W lea rbp, [rsp + imm8] + unwindCodes.push_back({prologSize, UWOP_SET_FPREG, frameRegOffset}); +} + +void UnwindBuilderWin::finish() +{ + // Windows unwind code count is stored in uint8_t, so we can't have more + LUAU_ASSERT(unwindCodes.size() < 256); + + LUAU_ASSERT(stackOffset % 16 == 0 && "stack has to be aligned to 16 bytes after prologue"); + + size_t codeArraySize = unwindCodes.size(); + codeArraySize = (codeArraySize + 1) & ~1; // Size has to be even, but unwind code count doesn't have to + + infoSize = sizeof(UnwindInfoWin) + sizeof(UnwindCodeWin) * codeArraySize; +} + +size_t UnwindBuilderWin::getSize() const +{ + return infoSize; +} + +void UnwindBuilderWin::finalize(char* target, void* funcAddress, size_t funcSize) const +{ + UnwindInfoWin info; + info.version = 1; + info.flags = 0; // No EH + info.prologsize = prologSize; + info.unwindcodecount = uint8_t(unwindCodes.size()); + info.framereg = frameReg.index; + info.frameregoff = frameRegOffset; + + memcpy(target, &info, sizeof(info)); + target += sizeof(UnwindInfoWin); + + if (!unwindCodes.empty()) + { + // Copy unwind codes in reverse order + // Some unwind codes take up two array slots, but we don't use those atm + char* pos = target + sizeof(UnwindCodeWin) * (unwindCodes.size() - 1); + + for (size_t i = 0; i < unwindCodes.size(); i++) + { + memcpy(pos, &unwindCodes[i], sizeof(UnwindCodeWin)); + pos -= sizeof(UnwindCodeWin); + } + } +} + +} // namespace CodeGen +} // namespace Luau diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 8b6ccdd..f865243 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -411,7 +411,7 @@ enum LuauOpcode enum LuauBytecodeTag { // Bytecode version; runtime supports [MIN, MAX], compiler emits TARGET by default but may emit a higher version when flags are enabled - LBC_VERSION_MIN = 2, + LBC_VERSION_MIN = 3, LBC_VERSION_MAX = 3, LBC_VERSION_TARGET = 3, // Types of constant table entries diff --git a/Sources.cmake b/Sources.cmake index ff0b5a6..580c8b3 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -57,13 +57,20 @@ target_sources(Luau.Compiler PRIVATE target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/AssemblyBuilderX64.h CodeGen/include/Luau/CodeAllocator.h + CodeGen/include/Luau/CodeBlockUnwind.h CodeGen/include/Luau/Condition.h CodeGen/include/Luau/Label.h CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/RegisterX64.h + CodeGen/include/Luau/UnwindBuilder.h + CodeGen/include/Luau/UnwindBuilderDwarf2.h + CodeGen/include/Luau/UnwindBuilderWin.h CodeGen/src/AssemblyBuilderX64.cpp CodeGen/src/CodeAllocator.cpp + CodeGen/src/CodeBlockUnwind.cpp + CodeGen/src/UnwindBuilderDwarf2.cpp + CodeGen/src/UnwindBuilderWin.cpp ) # Luau.Analysis Sources @@ -258,9 +265,13 @@ endif() if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE + tests/AstQueryDsl.h + tests/ConstraintGraphBuilderFixture.h tests/Fixture.h tests/IostreamOptional.h tests/ScopedFlags.h + tests/AstQueryDsl.cpp + tests/ConstraintGraphBuilderFixture.cpp tests/Fixture.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 7603a79..cbcaa3c 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -34,8 +34,6 @@ * therefore call luaC_checkGC before luaC_threadbarrier to guarantee the object is pushed to a gray thread. */ -LUAU_FASTFLAG(LuauSimplerUpval) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1298,8 +1296,6 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) L->top--; setobj(L, val, L->top); luaC_barrier(L, clvalue(fi), L->top); - if (!FFlag::LuauSimplerUpval) - luaC_upvalbarrier(L, cast_to(UpVal*, NULL), val); } return name; } diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 51f63d3..ecd2fcb 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -243,14 +243,14 @@ void luaD_call(lua_State* L, StkId func, int nResults) { // is a Lua function? L->ci->flags |= LUA_CALLINFO_RETURN; // luau_execute will stop after returning from the stack frame - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); + bool oldactive = L->isactive; + L->isactive = true; luaC_threadbarrier(L); luau_execute(L); // call it if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = false; } L->nCcalls--; @@ -427,7 +427,7 @@ static int resume_error(lua_State* L, const char* msg) static void resume_finish(lua_State* L, int status) { L->nCcalls = L->baseCcalls; - resetbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = false; if (status != 0) { // error? @@ -452,7 +452,7 @@ int lua_resume(lua_State* L, lua_State* from, int nargs) return resume_error(L, "C stack overflow"); L->baseCcalls = ++L->nCcalls; - l_setbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = true; luaC_threadbarrier(L); @@ -481,7 +481,7 @@ int lua_resumeerror(lua_State* L, lua_State* from) return resume_error(L, "C stack overflow"); L->baseCcalls = ++L->nCcalls; - l_setbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = true; luaC_threadbarrier(L); @@ -546,7 +546,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e { unsigned short oldnCcalls = L->nCcalls; ptrdiff_t old_ci = saveci(L, L->ci); - int oldactive = luaC_threadactive(L); + bool oldactive = L->isactive; int status = luaD_rawrunprotected(L, func, u); if (status != 0) { @@ -560,7 +560,7 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e // since the call failed with an error, we might have to reset the 'active' thread state if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + L->isactive = false; // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. L->nCcalls = oldnCcalls; diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 8c78083..3c1869b 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,9 +6,6 @@ #include "lmem.h" #include "lgc.h" -LUAU_FASTFLAG(LuauSimplerUpval) -LUAU_FASTFLAG(LuauNoSleepBit) - Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); @@ -74,21 +71,16 @@ UpVal* luaF_findupval(lua_State* L, StkId level) UpVal* p; while (*pp != NULL && (p = *pp)->v >= level) { - LUAU_ASSERT(!FFlag::LuauSimplerUpval || !isdead(g, obj2gco(p))); + LUAU_ASSERT(!isdead(g, obj2gco(p))); LUAU_ASSERT(upisopen(p)); if (p->v == level) - { // found a corresponding upvalue? - if (!FFlag::LuauSimplerUpval && isdead(g, obj2gco(p))) // is it dead? - changewhite(obj2gco(p)); // resurrect it return p; - } pp = &p->u.open.threadnext; } - LUAU_ASSERT(luaC_threadactive(L)); - LUAU_ASSERT(!luaC_threadsleeping(L)); - LUAU_ASSERT(!FFlag::LuauNoSleepBit || !isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black + LUAU_ASSERT(L->isactive); + LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); // not found: create a new one luaC_init(L, uv, LUA_TUPVAL); @@ -96,22 +88,8 @@ UpVal* luaF_findupval(lua_State* L, StkId level) uv->v = level; // current value lives in the stack // chain the upvalue in the threads open upvalue list at the proper position - if (FFlag::LuauSimplerUpval) - { - uv->u.open.threadnext = *pp; - *pp = uv; - } - else - { - UpVal* next = *pp; - uv->u.open.threadnext = next; - - uv->u.open.threadprev = pp; - if (next) - next->u.open.threadprev = &uv->u.open.threadnext; - - *pp = uv; - } + uv->u.open.threadnext = *pp; + *pp = uv; // double link the upvalue in the global open upvalue list uv->u.open.prev = &g->uvhead; @@ -123,26 +101,8 @@ UpVal* luaF_findupval(lua_State* L, StkId level) return uv; } -void luaF_unlinkupval(UpVal* uv) -{ - LUAU_ASSERT(!FFlag::LuauSimplerUpval); - - // unlink upvalue from the global open upvalue list - LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); - uv->u.open.next->u.open.prev = uv->u.open.prev; - uv->u.open.prev->u.open.next = uv->u.open.next; - - // unlink upvalue from the thread open upvalue list - *uv->u.open.threadprev = uv->u.open.threadnext; - - if (UpVal* next = uv->u.open.threadnext) - next->u.open.threadprev = uv->u.open.threadprev; -} - void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) { - if (!FFlag::LuauSimplerUpval && uv->v != &uv->u.value) // is it open? - luaF_unlinkupval(uv); // remove from open list luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); // free upvalue } @@ -154,41 +114,17 @@ void luaF_close(lua_State* L, StkId level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && upisopen(uv)); + LUAU_ASSERT(!isdead(g, o)); - if (FFlag::LuauSimplerUpval) - { - LUAU_ASSERT(!isdead(g, o)); + // unlink value *before* closing it since value storage overlaps + L->openupval = uv->u.open.threadnext; - // unlink value *before* closing it since value storage overlaps - L->openupval = uv->u.open.threadnext; - - luaF_closeupval(L, uv, /* dead= */ false); - } - else - { - // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue - luaF_unlinkupval(uv); - - if (isdead(g, o)) - { - // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again - uv->v = &uv->u.value; - } - else - { - setobj(L, &uv->u.value, uv->v); - uv->v = &uv->u.value; - // GC state of a new closed upvalue has to be initialized - luaC_upvalclosed(L, uv); - } - } + luaF_closeupval(L, uv, /* dead= */ false); } } void luaF_closeupval(lua_State* L, UpVal* uv, bool dead) { - LUAU_ASSERT(FFlag::LuauSimplerUpval); - // unlink value from all lists *before* closing it since value storage overlaps LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); uv->u.open.next->u.open.prev = uv->u.open.prev; diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 899d040..679836e 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -15,7 +15,6 @@ LUAI_FUNC void luaF_close(lua_State* L, StkId level); LUAI_FUNC void luaF_closeupval(lua_State* L, UpVal* uv, bool dead); LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f, struct lua_Page* page); LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page); -LUAI_FUNC void luaF_unlinkupval(UpVal* uv); LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page); LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); LUAI_FUNC const LocVar* luaF_findlocal(const Proto* func, int local_reg, int pc); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index fb610e1..c2a672e 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBetterThreadMark, false) + /* * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. * @@ -108,21 +110,14 @@ * API calls that can write to thread stacks outside of execution (which implies active) uses a thread barrier that checks if the thread is * black, and if it is it marks it as gray and puts it on a gray list to be rescanned during atomic phase. * - * NOTE: The above is only true when LuauNoSleepBit is enabled. - * * Upvalues are special objects that can be closed, in which case they contain the value (acting as a reference cell) and can be dealt * with using the regular algorithm, or open, in which case they refer to a stack slot in some other thread. These are difficult to deal * with because the stack writes are not monitored. Because of this open upvalues are treated in a somewhat special way: they are never marked * as black (doing so would violate the GC invariant), and they are kept in a special global list (global_State::uvhead) which is traversed * during atomic phase. This is needed because an open upvalue might point to a stack location in a dead thread that never marked the stack * slot - upvalues like this are identified since they don't have `markedopen` bit set during thread traversal and closed in `clearupvals`. - * - * NOTE: The above is only true when LuauSimplerUpval is enabled. */ -LUAU_FASTFLAGVARIABLE(LuauSimplerUpval, false) -LUAU_FASTFLAGVARIABLE(LuauNoSleepBit, false) -LUAU_FASTFLAGVARIABLE(LuauEagerShrink, false) LUAU_FASTFLAGVARIABLE(LuauFasterSweep, false) #define GC_SWEEPPAGESTEPCOST 16 @@ -408,14 +403,11 @@ static void traversestack(global_State* g, lua_State* l) stringmark(l->namecall); for (StkId o = l->stack; o < l->top; o++) markvalue(g, o); - if (FFlag::LuauSimplerUpval) + for (UpVal* uv = l->openupval; uv; uv = uv->u.open.threadnext) { - for (UpVal* uv = l->openupval; uv; uv = uv->u.open.threadnext) - { - LUAU_ASSERT(upisopen(uv)); - uv->markedopen = 1; - markobject(g, uv); - } + LUAU_ASSERT(upisopen(uv)); + uv->markedopen = 1; + markobject(g, uv); } } @@ -426,8 +418,29 @@ static void clearstack(lua_State* l) setnilvalue(o); } -// TODO: pull function definition here when FFlag::LuauEagerShrink is removed -static void shrinkstack(lua_State* L); +static void shrinkstack(lua_State* L) +{ + // compute used stack - note that we can't use th->top if we're in the middle of vararg call + StkId lim = L->top; + for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) + { + LUAU_ASSERT(ci->top <= L->stack_last); + if (lim < ci->top) + lim = ci->top; + } + + // shrink stack and callinfo arrays if we aren't using most of the space + int ci_used = cast_int(L->ci - L->base_ci); // number of `ci' in use + int s_used = cast_int(lim - L->stack); // part of stack in use + if (L->size_ci > LUAI_MAXCALLS) // handling overflow? + return; // do not touch the stacks + if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) + luaD_reallocCI(L, L->size_ci / 2); // still big enough... + condhardstacktests(luaD_reallocCI(L, ci_used + 1)); + if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) + luaD_reallocstack(L, L->stacksize / 2); // still big enough... + condhardstacktests(luaD_reallocstack(L, s_used)); +} /* ** traverse one gray object, turning it to black. @@ -460,37 +473,56 @@ static size_t propagatemark(global_State* g) lua_State* th = gco2th(o); g->gray = th->gclist; - LUAU_ASSERT(!luaC_threadsleeping(th)); + bool active = th->isactive || th == th->global->mainthread; - // threads that are executing and the main thread remain gray - bool active = luaC_threadactive(th) || th == th->global->mainthread; - - // TODO: Refactor this logic after LuauNoSleepBit is removed - if (!active && g->gcstate == GCSpropagate) + if (FFlag::LuauBetterThreadMark) { traversestack(g, th); - clearstack(th); - if (!FFlag::LuauNoSleepBit) - l_setbit(th->stackstate, THREAD_SLEEPINGBIT); + // active threads will need to be rescanned later to mark new stack writes so we mark them gray again + if (active) + { + th->gclist = g->grayagain; + g->grayagain = o; + + black2gray(o); + } + + // the stack needs to be cleared after the last modification of the thread state before sweep begins + // if the thread is inactive, we might not see the thread in this cycle so we must clear it now + if (!active || g->gcstate == GCSatomic) + clearstack(th); + + // we could shrink stack at any time but we opt to do it during initial mark to do that just once per cycle + if (g->gcstate == GCSpropagate) + shrinkstack(th); } else { - th->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); - - traversestack(g, th); - - // final traversal? - if (g->gcstate == GCSatomic) + // TODO: Refactor this logic! + if (!active && g->gcstate == GCSpropagate) + { + traversestack(g, th); clearstack(th); - } + } + else + { + th->gclist = g->grayagain; + g->grayagain = o; - // we could shrink stack at any time but we opt to skip it during atomic since it's redundant to do that more than once per cycle - if (FFlag::LuauEagerShrink && g->gcstate != GCSatomic) - shrinkstack(th); + black2gray(o); + + traversestack(g, th); + + // final traversal? + if (g->gcstate == GCSatomic) + clearstack(th); + } + + // we could shrink stack at any time but we opt to skip it during atomic since it's redundant to do that more than once per cycle + if (g->gcstate != GCSatomic) + shrinkstack(th); + } return sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; } @@ -593,30 +625,6 @@ static size_t cleartable(lua_State* L, GCObject* l) return work; } -static void shrinkstack(lua_State* L) -{ - // compute used stack - note that we can't use th->top if we're in the middle of vararg call - StkId lim = L->top; - for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) - { - LUAU_ASSERT(ci->top <= L->stack_last); - if (lim < ci->top) - lim = ci->top; - } - - // shrink stack and callinfo arrays if we aren't using most of the space - int ci_used = cast_int(L->ci - L->base_ci); // number of `ci' in use - int s_used = cast_int(lim - L->stack); // part of stack in use - if (L->size_ci > LUAI_MAXCALLS) // handling overflow? - return; // do not touch the stacks - if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) - luaD_reallocCI(L, L->size_ci / 2); // still big enough... - condhardstacktests(luaD_reallocCI(L, ci_used + 1)); - if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); // still big enough... - condhardstacktests(luaD_reallocstack(L, s_used)); -} - static void freeobj(lua_State* L, GCObject* o, lua_Page* page) { switch (o->gch.tt) @@ -669,21 +677,6 @@ static void shrinkbuffersfull(lua_State* L) static bool deletegco(void* context, lua_Page* page, GCObject* gco) { - // we are in the process of deleting everything - // threads with open upvalues will attempt to close them all on removal - // but those upvalues might point to stack values that were already deleted - if (!FFlag::LuauSimplerUpval && gco->gch.tt == LUA_TTHREAD) - { - lua_State* th = gco2th(gco); - - while (UpVal* uv = th->openupval) - { - luaF_unlinkupval(uv); - // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again - uv->v = &uv->u.value; - } - } - lua_State* L = (lua_State*)context; freeobj(L, gco, page); return true; @@ -701,7 +694,6 @@ void luaC_freeall(lua_State* L) LUAU_ASSERT(g->strt.hash[i] == NULL); LUAU_ASSERT(L->global->strt.nuse == 0); - LUAU_ASSERT(g->strbufgc == NULL); } static void markmt(global_State* g) @@ -829,15 +821,12 @@ static size_t atomic(lua_State* L) g->gcmetrics.currcycle.atomictimeclear += recordGcDeltaTime(currts); #endif - if (FFlag::LuauSimplerUpval) - { - // close orphaned live upvalues of dead threads and clear dead upvalues - work += clearupvals(L); + // close orphaned live upvalues of dead threads and clear dead upvalues + work += clearupvals(L); #ifdef LUAI_GCMETRICS - g->gcmetrics.currcycle.atomictimeupval += recordGcDeltaTime(currts); + g->gcmetrics.currcycle.atomictimeupval += recordGcDeltaTime(currts); #endif - } // flip current white g->currentwhite = cast_byte(otherwhite(g)); @@ -857,20 +846,6 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) int alive = (gco->gch.marked ^ WHITEBITS) & deadmask; - if (gco->gch.tt == LUA_TTHREAD) - { - lua_State* th = gco2th(gco); - - if (alive) - { - if (!FFlag::LuauNoSleepBit) - resetbit(th->stackstate, THREAD_SLEEPINGBIT); - - if (!FFlag::LuauEagerShrink) - shrinkstack(th); - } - } - if (alive) { LUAU_ASSERT(!isdead(g, gco)); @@ -896,8 +871,6 @@ static int sweepgcopage(lua_State* L, lua_Page* page) if (FFlag::LuauFasterSweep) { - LUAU_ASSERT(FFlag::LuauNoSleepBit && FFlag::LuauEagerShrink); - global_State* g = L->global; int deadmask = otherwhite(g); @@ -1183,7 +1156,7 @@ void luaC_fullgc(lua_State* L) startGcCycleMetrics(g); #endif - if (FFlag::LuauSimplerUpval ? keepinvariant(g) : g->gcstate <= GCSatomic) + if (keepinvariant(g)) { // reset sweep marks to sweep all elements (returning them to white) g->sweepgcopage = g->allgcopages; @@ -1201,14 +1174,11 @@ void luaC_fullgc(lua_State* L) gcstep(L, SIZE_MAX); } - if (FFlag::LuauSimplerUpval) + // clear markedopen bits for all open upvalues; these might be stuck from half-finished mark prior to full gc + for (UpVal* uv = g->uvhead.u.open.next; uv != &g->uvhead; uv = uv->u.open.next) { - // clear markedopen bits for all open upvalues; these might be stuck from half-finished mark prior to full gc - for (UpVal* uv = g->uvhead.u.open.next; uv != &g->uvhead; uv = uv->u.open.next) - { - LUAU_ASSERT(upisopen(uv)); - uv->markedopen = 0; - } + LUAU_ASSERT(upisopen(uv)); + uv->markedopen = 0; } #ifdef LUAI_GCMETRICS @@ -1245,16 +1215,6 @@ void luaC_fullgc(lua_State* L) #endif } -void luaC_barrierupval(lua_State* L, GCObject* v) -{ - LUAU_ASSERT(!FFlag::LuauSimplerUpval); - global_State* g = L->global; - LUAU_ASSERT(iswhite(v) && !isdead(g, v)); - - if (keepinvariant(g)) - reallymarkobject(g, v); -} - void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) { global_State* g = L->global; @@ -1346,29 +1306,6 @@ int64_t luaC_allocationrate(lua_State* L) return int64_t((g->gcstats.atomicstarttotalsizebytes - g->gcstats.endtotalsizebytes) / duration); } -void luaC_wakethread(lua_State* L) -{ - LUAU_ASSERT(!FFlag::LuauNoSleepBit); - if (!luaC_threadsleeping(L)) - return; - - global_State* g = L->global; - - resetbit(L->stackstate, THREAD_SLEEPINGBIT); - - if (keepinvariant(g)) - { - GCObject* o = obj2gco(L); - - LUAU_ASSERT(isblack(o)); - - L->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); - } -} - const char* luaC_statename(int state) { switch (state) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 69379c8..51216bd 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,8 +6,6 @@ #include "lobject.h" #include "lstate.h" -LUAU_FASTFLAG(LuauNoSleepBit) - /* ** Default settings for GC tunables (settable via lua_gc) */ @@ -75,14 +73,6 @@ LUAU_FASTFLAG(LuauNoSleepBit) #define luaC_white(g) cast_to(uint8_t, ((g)->currentwhite) & WHITEBITS) -// Thread stack states -// TODO: Remove with FFlag::LuauNoSleepBit and replace with lua_State::threadactive -#define THREAD_ACTIVEBIT 0 // thread is currently active -#define THREAD_SLEEPINGBIT 1 // thread is not executing and stack should not be modified - -#define luaC_threadactive(L) (testbit((L)->stackstate, THREAD_ACTIVEBIT)) -#define luaC_threadsleeping(L) (testbit((L)->stackstate, THREAD_SLEEPINGBIT)) - #define luaC_checkGC(L) \ { \ condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); \ @@ -121,25 +111,10 @@ LUAU_FASTFLAG(LuauNoSleepBit) luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ } -// TODO: Remove with FFlag::LuauSimplerUpval -#define luaC_upvalbarrier(L, uv, tv) \ - { \ - if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || (uv)->v != &(uv)->u.value)) \ - luaC_barrierupval(L, gcvalue(tv)); \ - } - #define luaC_threadbarrier(L) \ { \ - if (FFlag::LuauNoSleepBit) \ - { \ - if (isblack(obj2gco(L))) \ - luaC_barrierback(L, obj2gco(L), &L->gclist); \ - } \ - else \ - { \ - if (luaC_threadsleeping(L)) \ - luaC_wakethread(L); \ - } \ + if (isblack(obj2gco(L))) \ + luaC_barrierback(L, obj2gco(L), &L->gclist); \ } #define luaC_init(L, o, tt_) \ @@ -154,12 +129,10 @@ LUAI_FUNC size_t luaC_step(lua_State* L, bool assist); LUAI_FUNC void luaC_fullgc(lua_State* L); LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); LUAI_FUNC void luaC_upvalclosed(lua_State* L, UpVal* uv); -LUAI_FUNC void luaC_barrierupval(lua_State* L, GCObject* v); LUAI_FUNC void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v); LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); LUAI_FUNC void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist); LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); -LUAI_FUNC void luaC_wakethread(lua_State* L); LUAI_FUNC const char* luaC_statename(int state); diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 778e22b..48aaf94 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -332,8 +332,6 @@ typedef struct UpVal // thread linked list (when open) struct UpVal* threadnext; - // note: this is the location of a pointer to this upvalue in the previous element that can be either an UpVal or a lua_State - struct UpVal** threadprev; // TODO: remove with FFlag::LuauSimplerUpval } open; } u; } UpVal; diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index e1cb2ab..fdd7fc2 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,8 +10,6 @@ #include "ldo.h" #include "ldebug.h" -LUAU_FASTFLAG(LuauSimplerUpval) - /* ** Main thread combines a thread state and the global state */ @@ -79,7 +77,7 @@ static void preinit_state(lua_State* L, global_State* g) L->namecall = NULL; L->cachedslot = 0; L->singlestep = false; - L->stackstate = 0; + L->isactive = false; L->activememcat = 0; L->userdata = NULL; } @@ -89,7 +87,6 @@ static void close_state(lua_State* L) global_State* g = L->global; luaF_close(L, L->stack); // close all upvalues for this thread luaC_freeall(L); // collect all objects - LUAU_ASSERT(g->strbufgc == NULL); LUAU_ASSERT(g->strt.nuse == 0); luaM_freearray(L, L->global->strt.hash, L->global->strt.size, TString*, 0); freestack(L, L); @@ -121,11 +118,6 @@ lua_State* luaE_newthread(lua_State* L) void luaE_freethread(lua_State* L, lua_State* L1, lua_Page* page) { - if (!FFlag::LuauSimplerUpval) - { - luaF_close(L1, L1->stack); // close all upvalues for this thread - LUAU_ASSERT(L1->openupval == NULL); - } global_State* g = L->global; if (g->cb.userthread) g->cb.userthread(NULL, L1); @@ -199,7 +191,6 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->gray = NULL; g->grayagain = NULL; g->weak = NULL; - g->strbufgc = NULL; g->totalbytes = sizeof(LG); g->gcgoal = LUAI_GCGOAL; g->gcstepmul = LUAI_GCSTEPMUL; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index df47ce7..0654446 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -167,8 +167,6 @@ typedef struct global_State GCObject* grayagain; // list of objects to be traversed atomically GCObject* weak; // list of weak tables (to be cleared) - TString* strbufgc; // list of all string buffer objects; TODO: remove with LuauNoStrbufLink - size_t GCthreshold; // when totalbytes > GCthreshold, run GC step size_t totalbytes; // number of bytes currently allocated @@ -222,8 +220,8 @@ struct lua_State uint8_t status; uint8_t activememcat; // memory category that is used for new GC object allocations - uint8_t stackstate; + bool isactive; // thread is currently executing, stack may be mutated without barriers bool singlestep; // call debugstep hook after each instruction diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index f43d03b..e57f6c2 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauNoStrbufLink, false) - unsigned int luaS_hash(const char* str, size_t len) { // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash @@ -96,51 +94,18 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) return ts; } -static void unlinkstrbuf(lua_State* L, TString* ts) -{ - LUAU_ASSERT(!FFlag::LuauNoStrbufLink); - global_State* g = L->global; - - TString** p = &g->strbufgc; - - while (TString* curr = *p) - { - if (curr == ts) - { - *p = curr->next; - return; - } - else - { - p = &curr->next; - } - } - - LUAU_ASSERT(!"failed to find string buffer"); -} - TString* luaS_bufstart(lua_State* L, size_t size) { if (size > MAXSSIZE) luaM_toobig(L); - global_State* g = L->global; - TString* ts = luaM_newgco(L, TString, sizestring(size), L->activememcat); luaC_init(L, ts, LUA_TSTRING); ts->atom = ATOM_UNDEF; ts->hash = 0; // computed in luaS_buffinish ts->len = unsigned(size); - if (FFlag::LuauNoStrbufLink) - { - ts->next = NULL; - } - else - { - ts->next = g->strbufgc; - g->strbufgc = ts; - } + ts->next = NULL; return ts; } @@ -164,10 +129,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) } } - if (FFlag::LuauNoStrbufLink) - LUAU_ASSERT(ts->next == NULL); - else - unlinkstrbuf(L, ts); + LUAU_ASSERT(ts->next == NULL); ts->hash = h; ts->data[ts->len] = '\0'; // ending 0 @@ -222,21 +184,10 @@ static bool unlinkstr(lua_State* L, TString* ts) void luaS_free(lua_State* L, TString* ts, lua_Page* page) { - if (FFlag::LuauNoStrbufLink) - { - if (unlinkstr(L, ts)) - L->global->strt.nuse--; - else - LUAU_ASSERT(ts->next == NULL); // orphaned string buffer - } + if (unlinkstr(L, ts)) + L->global->strt.nuse--; else - { - // Unchain from the string table - if (!unlinkstr(L, ts)) - unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands - else - L->global->strt.nuse--; - } + LUAU_ASSERT(ts->next == NULL); // orphaned string buffer luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index e348891..c3c744b 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,9 +16,6 @@ #include -LUAU_FASTFLAG(LuauSimplerUpval) -LUAU_FASTFLAG(LuauNoSleepBit) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -69,11 +66,6 @@ LUAU_FASTFLAG(LuauNoSleepBit) #define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) #define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) -// NOTE: If debugging the Luau code, disable this macro to prevent timeouts from -// occurring when tracing code in Visual Studio / XCode -#if 0 -#define VM_INTERRUPT() -#else #define VM_INTERRUPT() \ { \ void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; \ @@ -87,7 +79,6 @@ LUAU_FASTFLAG(LuauNoSleepBit) } \ } \ } -#endif #define VM_DISPATCH_OP(op) &&CASE_##op @@ -150,32 +141,6 @@ LUAU_NOINLINE static void luau_prepareFORN(lua_State* L, StkId plimit, StkId pst luaG_forerror(L, pstep, "step"); } -LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) -{ - // note: it's safe to push arguments past top for complicated reasons (see top of the file) - StkId ra = &L->base[a]; - LUAU_ASSERT(ra + 3 <= L->top); - - setobjs2s(L, ra + 3 + 2, ra + 2); - setobjs2s(L, ra + 3 + 1, ra + 1); - setobjs2s(L, ra + 3, ra); - - L->top = ra + 3 + 3; // func. + 2 args (state and index) - LUAU_ASSERT(L->top <= L->stack_last); - - luaD_call(L, ra + 3, c); - L->top = L->ci->top; - - // recompute ra since stack might have been reallocated - ra = &L->base[a]; - LUAU_ASSERT(ra < L->top); - - // copy first variable back into the iteration index - setobjs2s(L, ra + 2, ra + 3); - - return ttisnil(ra + 2); -} - // calls a C function f with no yielding support; optionally save one resulting value to the res register // the function and arguments have to already be pushed to L->top LUAU_NOINLINE static void luau_callTM(lua_State* L, int nparams, int res) @@ -316,9 +281,8 @@ static void luau_execute(lua_State* L) const Instruction* pc; LUAU_ASSERT(isLua(L->ci)); - LUAU_ASSERT(luaC_threadactive(L)); - LUAU_ASSERT(!luaC_threadsleeping(L)); - LUAU_ASSERT(!FFlag::LuauNoSleepBit || !isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black + LUAU_ASSERT(L->isactive); + LUAU_ASSERT(!isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black pc = L->ci->savedpc; cl = clvalue(L->ci->func); @@ -498,8 +462,6 @@ static void luau_execute(lua_State* L) setobj(L, uv->v, ra); luaC_barrier(L, uv, ra); - if (!FFlag::LuauSimplerUpval) - luaC_upvalbarrier(L, uv, uv->v); VM_NEXT(); } @@ -2403,52 +2365,8 @@ static void luau_execute(lua_State* L) VM_CASE(LOP_DEP_FORGLOOP_INEXT) { - VM_INTERRUPT(); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - // fast-path: ipairs/inext - if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) - { - Table* h = hvalue(ra + 1); - int index = int(reinterpret_cast(pvalue(ra + 2))); - - // if 1-based index of the last iteration is in bounds, this means 0-based index of the current iteration is in bounds - if (unsigned(index) < unsigned(h->sizearray)) - { - // note that nil elements inside the array terminate the traversal - if (!ttisnil(&h->array[index])) - { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, &h->array[index]); - - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - else - { - // fallthrough to exit - VM_NEXT(); - } - } - else - { - // fallthrough to exit - VM_NEXT(); - } - } - else - { - // slow-path; can call Lua/C generators - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), 2)); - - pc += stop ? 0 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); } VM_CASE(LOP_FORGPREP_NEXT) @@ -2475,68 +2393,8 @@ static void luau_execute(lua_State* L) VM_CASE(LOP_DEP_FORGLOOP_NEXT) { - VM_INTERRUPT(); - Instruction insn = *pc++; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - - // fast-path: pairs/next - if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) - { - Table* h = hvalue(ra + 1); - int index = int(reinterpret_cast(pvalue(ra + 2))); - - int sizearray = h->sizearray; - int sizenode = 1 << h->lsizenode; - - // first we advance index through the array portion - while (unsigned(index) < unsigned(sizearray)) - { - if (!ttisnil(&h->array[index])) - { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, &h->array[index]); - - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - - index++; - } - - // then we advance index through the hash portion - while (unsigned(index - sizearray) < unsigned(sizenode)) - { - LuaNode* n = &h->node[index - sizearray]; - - if (!ttisnil(gval(n))) - { - setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); - getnodekey(L, ra + 3, n); - setobj2s(L, ra + 4, gval(n)); - - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } - - index++; - } - - // fallthrough to exit - VM_NEXT(); - } - else - { - // slow-path; can call Lua/C generators - bool stop; - VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), 2)); - - pc += stop ? 0 : LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); } VM_CASE(LOP_GETVARARGS) @@ -2750,92 +2608,14 @@ static void luau_execute(lua_State* L) VM_CASE(LOP_DEP_JUMPIFEQK) { - Instruction insn = *pc++; - uint32_t aux = *pc; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - TValue* rb = VM_KV(aux); - - // Note that all jumps below jump by 1 in the "false" case to skip over aux - if (ttype(ra) == ttype(rb)) - { - switch (ttype(ra)) - { - case LUA_TNIL: - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TBOOLEAN: - pc += bvalue(ra) == bvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TNUMBER: - pc += nvalue(ra) == nvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TSTRING: - pc += gcvalue(ra) == gcvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - default:; - } - - LUAU_ASSERT(!"Constant is expected to be of primitive type"); - } - else - { - pc += 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); } VM_CASE(LOP_DEP_JUMPIFNOTEQK) { - Instruction insn = *pc++; - uint32_t aux = *pc; - StkId ra = VM_REG(LUAU_INSN_A(insn)); - TValue* rb = VM_KV(aux); - - // Note that all jumps below jump by 1 in the "true" case to skip over aux - if (ttype(ra) == ttype(rb)) - { - switch (ttype(ra)) - { - case LUA_TNIL: - pc += 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TBOOLEAN: - pc += bvalue(ra) != bvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TNUMBER: - pc += nvalue(ra) != nvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - case LUA_TSTRING: - pc += gcvalue(ra) != gcvalue(rb) ? LUAU_INSN_D(insn) : 1; - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - - default:; - } - - LUAU_ASSERT(!"Constant is expected to be of primitive type"); - } - else - { - pc += LUAU_INSN_D(insn); - LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); - VM_NEXT(); - } + LUAU_ASSERT(!"Unsupported deprecated opcode"); + LUAU_UNREACHABLE(); } VM_CASE(LOP_FASTCALL1) diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 2b650fa..4b21c44 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -1,9 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Fixture.h" #include "Luau/AstQuery.h" +#include "AstQueryDsl.h" #include "doctest.h" +#include "Fixture.h" using namespace Luau; diff --git a/tests/AstQueryDsl.cpp b/tests/AstQueryDsl.cpp new file mode 100644 index 0000000..0cf28f3 --- /dev/null +++ b/tests/AstQueryDsl.cpp @@ -0,0 +1,45 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "AstQueryDsl.h" + +namespace Luau +{ + +FindNthOccurenceOf::FindNthOccurenceOf(Nth nth) + : requestedNth(nth) +{ +} + +bool FindNthOccurenceOf::checkIt(AstNode* n) +{ + if (theNode) + return false; + + if (n->classIndex == requestedNth.classIndex) + { + // Human factor: the requestedNth starts from 1 because of the term `nth`. + if (currentOccurrence + 1 != requestedNth.nth) + ++currentOccurrence; + else + theNode = n; + } + + return !theNode; // once found, returns false and stops traversal +} + +bool FindNthOccurenceOf::visit(AstNode* n) +{ + return checkIt(n); +} + +bool FindNthOccurenceOf::visit(AstType* t) +{ + return checkIt(t); +} + +bool FindNthOccurenceOf::visit(AstTypePack* t) +{ + return checkIt(t); +} + +} diff --git a/tests/AstQueryDsl.h b/tests/AstQueryDsl.h new file mode 100644 index 0000000..6bf3bd3 --- /dev/null +++ b/tests/AstQueryDsl.h @@ -0,0 +1,83 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Common.h" + +#include +#include + +namespace Luau +{ + +struct Nth +{ + int classIndex; + int nth; +}; + +template +Nth nth(int nth = 1) +{ + static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); + LUAU_ASSERT(nth > 0); // Did you mean to use `nth(1)`? + + return Nth{T::ClassIndex(), nth}; +} + +struct FindNthOccurenceOf : public AstVisitor +{ + Nth requestedNth; + int currentOccurrence = 0; + AstNode* theNode = nullptr; + + FindNthOccurenceOf(Nth nth); + + bool checkIt(AstNode* n); + + bool visit(AstNode* n) override; + bool visit(AstType* n) override; + bool visit(AstTypePack* n) override; +}; + +/** DSL querying of the AST. + * + * Given an AST, one can query for a particular node directly without having to manually unwrap the tree, for example: + * + * ``` + * if a and b then + * print(a + b) + * end + * + * function f(x, y) + * return x + y + * end + * ``` + * + * There are numerous ways to access the second AstExprBinary. + * 1. Luau::query(block, {nth(), nth()}) + * 2. Luau::query(Luau::query(block)) + * 3. Luau::query(block, {nth(2)}) + */ +template +T* query(AstNode* node, const std::vector& nths = {nth(N)}) +{ + static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); + + // If a nested query call fails to find the node in question, subsequent calls can propagate rather than trying to do more. + // This supports `query(query(...))` + + for (Nth nth : nths) + { + if (!node) + return nullptr; + + FindNthOccurenceOf finder{nth}; + node->visit(&finder); + node = finder.theNode; + } + + return node ? node->as() : nullptr; +} + +} diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 41f3a90..758fb44 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -1,9 +1,16 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/AssemblyBuilderX64.h" #include "Luau/CodeAllocator.h" +#include "Luau/CodeBlockUnwind.h" +#include "Luau/UnwindBuilder.h" +#include "Luau/UnwindBuilderDwarf2.h" +#include "Luau/UnwindBuilderWin.h" #include "doctest.h" +#include +#include + #include using namespace Luau::CodeGen; @@ -41,8 +48,8 @@ TEST_CASE("CodeAllocation") TEST_CASE("CodeAllocationFailure") { - size_t blockSize = 16384; - size_t maxTotalSize = 32768; + size_t blockSize = 3000; + size_t maxTotalSize = 7000; CodeAllocator allocator(blockSize, maxTotalSize); uint8_t* nativeData; @@ -50,13 +57,13 @@ TEST_CASE("CodeAllocationFailure") uint8_t* nativeEntry; std::vector code; - code.resize(18000); + code.resize(4000); // allocation has to fit in a block REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); // each allocation exhausts a block, so third allocation fails - code.resize(10000); + code.resize(2000); REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); REQUIRE(!allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); @@ -120,19 +127,84 @@ TEST_CASE("CodeAllocationWithUnwindCallbacks") CHECK(info.destroyCalled); } -#if defined(__x86_64__) || defined(_M_X64) -TEST_CASE("GeneratedCodeExecution") +TEST_CASE("WindowsUnwindCodesX64") { + UnwindBuilderWin unwind; + + unwind.start(); + unwind.spill(16, rdx); + unwind.spill(8, rcx); + unwind.save(rdi); + unwind.save(rsi); + unwind.save(rbx); + unwind.save(rbp); + unwind.save(r12); + unwind.save(r13); + unwind.save(r14); + unwind.save(r15); + unwind.allocStack(72); + unwind.setupFrameReg(rbp, 48); + unwind.finish(); + + std::vector data; + data.resize(unwind.getSize()); + unwind.finalize(data.data(), nullptr, 0); + + std::vector expected{0x01, 0x23, 0x0a, 0x35, 0x23, 0x33, 0x1e, 0x82, 0x1a, 0xf0, 0x18, 0xe0, 0x16, 0xd0, 0x14, 0xc0, 0x12, 0x50, 0x10, + 0x30, 0x0e, 0x60, 0x0c, 0x70}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} + +TEST_CASE("Dwarf2UnwindCodesX64") +{ + UnwindBuilderDwarf2 unwind; + + unwind.start(); + unwind.save(rdi); + unwind.save(rsi); + unwind.save(rbx); + unwind.save(rbp); + unwind.save(r12); + unwind.save(r13); + unwind.save(r14); + unwind.save(r15); + unwind.allocStack(72); + unwind.setupFrameReg(rbp, 48); + unwind.finish(); + + std::vector data; + data.resize(unwind.getSize()); + unwind.finalize(data.data(), nullptr, 0); + + std::vector expected{0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x78, 0x10, 0x0c, 0x07, 0x08, 0x05, 0x10, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, + 0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, + 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + REQUIRE(data.size() == expected.size()); + CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); +} + +#if defined(__x86_64__) || defined(_M_X64) + #if defined(_WIN32) - // Windows x64 ABI - constexpr RegisterX64 rArg1 = rcx; - constexpr RegisterX64 rArg2 = rdx; +// Windows x64 ABI +constexpr RegisterX64 rArg1 = rcx; +constexpr RegisterX64 rArg2 = rdx; #else - // System V AMD64 ABI - constexpr RegisterX64 rArg1 = rdi; - constexpr RegisterX64 rArg2 = rsi; +// System V AMD64 ABI +constexpr RegisterX64 rArg1 = rdi; +constexpr RegisterX64 rArg2 = rsi; #endif +constexpr RegisterX64 rNonVol1 = r12; +constexpr RegisterX64 rNonVol2 = rbx; + +TEST_CASE("GeneratedCodeExecution") +{ AssemblyBuilderX64 build(/* logText= */ false); build.mov(rax, rArg1); @@ -157,6 +229,90 @@ TEST_CASE("GeneratedCodeExecution") int64_t result = f(10, 20); CHECK(result == 210); } + +void throwing(int64_t arg) +{ + CHECK(arg == 25); + + throw std::runtime_error("testing"); +} + +TEST_CASE("GeneratedCodeExecutionWithThrow") +{ + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->start(); + + // Prologue + build.push(rNonVol1); + unwind->save(rNonVol1); + build.push(rNonVol2); + unwind->save(rNonVol2); + build.push(rbp); + unwind->save(rbp); + + int stackSize = 32; + int localsSize = 16; + + build.sub(rsp, stackSize + localsSize); + unwind->allocStack(stackSize + localsSize); + + build.lea(rbp, qword[rsp + stackSize]); + unwind->setupFrameReg(rbp, stackSize); + + unwind->finish(); + + // Body + build.mov(rNonVol1, rArg1); + build.mov(rNonVol2, rArg2); + + build.add(rNonVol1, 15); + build.mov(rArg1, rNonVol1); + build.call(rNonVol2); + + // Epilogue + build.lea(rsp, qword[rbp + localsSize]); + build.pop(rbp); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + build.finalize(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + #endif TEST_SUITE_END(); diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index bbe2942..5c34e3d 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -1,8 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Fixture.h" #include "Luau/ConstraintGraphBuilder.h" +#include "Luau/NotNull.h" +#include "Luau/ToString.h" +#include "ConstraintGraphBuilderFixture.h" +#include "Fixture.h" #include "doctest.h" using namespace Luau; diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp new file mode 100644 index 0000000..1958d1d --- /dev/null +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -0,0 +1,17 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ConstraintGraphBuilderFixture.h" + +namespace Luau +{ + +ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() + : Fixture() + , mainModule(new Module) + , cgb("MainModule", mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice), frontend.getGlobalScope(), &logger) + , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} +{ + BlockedTypeVar::nextIndex = 0; + BlockedTypePack::nextIndex = 0; +} + +} diff --git a/tests/ConstraintGraphBuilderFixture.h b/tests/ConstraintGraphBuilderFixture.h new file mode 100644 index 0000000..262e390 --- /dev/null +++ b/tests/ConstraintGraphBuilderFixture.h @@ -0,0 +1,27 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/ConstraintGraphBuilder.h" +#include "Luau/DcrLogger.h" +#include "Luau/TypeArena.h" +#include "Luau/Module.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +namespace Luau +{ + +struct ConstraintGraphBuilderFixture : Fixture +{ + TypeArena arena; + ModulePtr mainModule; + ConstraintGraphBuilder cgb; + DcrLogger logger; + + ScopedFastFlag forceTheFlag; + + ConstraintGraphBuilderFixture(); +}; + +} diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index fba5782..9976bd2 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -1,12 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Fixture.h" - -#include "doctest.h" - #include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintSolver.h" +#include "ConstraintGraphBuilderFixture.h" +#include "Fixture.h" +#include "doctest.h" + using namespace Luau; static TypeId requireBinding(NotNull scope, const char* name) diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 6c4594f..3f77978 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -444,16 +444,6 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); } -ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() - : Fixture() - , mainModule(new Module) - , cgb(mainModuleName, mainModule, &arena, NotNull(&moduleResolver), singletonTypes, NotNull(&ice), frontend.getGlobalScope(), &logger) - , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} -{ - BlockedTypeVar::nextIndex = 0; - BlockedTypePack::nextIndex = 0; -} - ModuleName fromString(std::string_view name) { return ModuleName(name); @@ -516,41 +506,4 @@ void dump(const std::vector& constraints) printf("%s\n", toString(c, opts).c_str()); } -FindNthOccurenceOf::FindNthOccurenceOf(Nth nth) - : requestedNth(nth) -{ -} - -bool FindNthOccurenceOf::checkIt(AstNode* n) -{ - if (theNode) - return false; - - if (n->classIndex == requestedNth.classIndex) - { - // Human factor: the requestedNth starts from 1 because of the term `nth`. - if (currentOccurrence + 1 != requestedNth.nth) - ++currentOccurrence; - else - theNode = n; - } - - return !theNode; // once found, returns false and stops traversal -} - -bool FindNthOccurenceOf::visit(AstNode* n) -{ - return checkIt(n); -} - -bool FindNthOccurenceOf::visit(AstType* t) -{ - return checkIt(t); -} - -bool FindNthOccurenceOf::visit(AstTypePack* t) -{ - return checkIt(t); -} - } // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h index 03101bb..2fb4846 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -12,7 +12,6 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" -#include "Luau/DcrLogger.h" #include "IostreamOptional.h" #include "ScopedFlags.h" @@ -162,18 +161,6 @@ struct BuiltinsFixture : Fixture BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); }; -struct ConstraintGraphBuilderFixture : Fixture -{ - TypeArena arena; - ModulePtr mainModule; - ConstraintGraphBuilder cgb; - DcrLogger logger; - - ScopedFastFlag forceTheFlag; - - ConstraintGraphBuilderFixture(); -}; - ModuleName fromString(std::string_view name); template @@ -199,76 +186,6 @@ std::optional lookupName(ScopePtr scope, const std::string& name); // Wa std::optional linearSearchForBinding(Scope* scope, const char* name); -struct Nth -{ - int classIndex; - int nth; -}; - -template -Nth nth(int nth = 1) -{ - static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); - LUAU_ASSERT(nth > 0); // Did you mean to use `nth(1)`? - - return Nth{T::ClassIndex(), nth}; -} - -struct FindNthOccurenceOf : public AstVisitor -{ - Nth requestedNth; - int currentOccurrence = 0; - AstNode* theNode = nullptr; - - FindNthOccurenceOf(Nth nth); - - bool checkIt(AstNode* n); - - bool visit(AstNode* n) override; - bool visit(AstType* n) override; - bool visit(AstTypePack* n) override; -}; - -/** DSL querying of the AST. - * - * Given an AST, one can query for a particular node directly without having to manually unwrap the tree, for example: - * - * ``` - * if a and b then - * print(a + b) - * end - * - * function f(x, y) - * return x + y - * end - * ``` - * - * There are numerous ways to access the second AstExprBinary. - * 1. Luau::query(block, {nth(), nth()}) - * 2. Luau::query(Luau::query(block)) - * 3. Luau::query(block, {nth(2)}) - */ -template -T* query(AstNode* node, const std::vector& nths = {nth(N)}) -{ - static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); - - // If a nested query call fails to find the node in question, subsequent calls can propagate rather than trying to do more. - // This supports `query(query(...))` - - for (Nth nth : nths) - { - if (!node) - return nullptr; - - FindNthOccurenceOf finder{nth}; - node->visit(&finder); - node = finder.theNode; - } - - return node ? node->as() : nullptr; -} - } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 07a0436..6da3f56 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -459,7 +459,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.threadType, *requireType("co")); + CHECK("thread" == toString(requireType("co"))); } TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_resume_anything_goes") @@ -627,6 +627,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "select_with_decimal_argument_is_rounded_down // Could be flaky if the fix has regressed. TEST_CASE_FIXTURE(BuiltinsFixture, "bad_select_should_not_crash") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check(R"( do end local _ = function(l0,...) @@ -638,8 +640,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "bad_select_should_not_crash") )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Argument count mismatch. Function expects at least 1 argument, but none are specified", toString(result.errors[0])); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but none are specified", toString(result.errors[1])); + CHECK_EQ("Argument count mismatch. Function '_' expects at least 1 argument, but none are specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function 'select' expects 1 argument, but none are specified", toString(result.errors[1])); } TEST_CASE_FIXTURE(BuiltinsFixture, "select_way_out_of_range") @@ -824,12 +826,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_lib_self_noself") TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_definition") { CheckResult result = check(R"_( -local a, b, c = ("hey"):gmatch("(.)(.)(.)")() + local a, b, c = ("hey"):gmatch("(.)(.)(.)")() -for c in ("hey"):gmatch("(.)") do - print(c:upper()) -end -)_"); + for c in ("hey"):gmatch("(.)") do + print(c:upper()) + end + )_"); LUAU_REQUIRE_NO_ERRORS(result); } @@ -1008,6 +1010,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; CheckResult result = check(R"( local a = {b=setmetatable} @@ -1016,8 +1020,8 @@ a:b() a:b({}) )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function expects 2 arguments, but none are specified"); - CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function expects 2 arguments, but only 1 is specified"); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'a.b' expects 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'a.b' expects 2 arguments, but only 1 is specified"); } TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 35e67ec..bde28dc 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1732,4 +1732,56 @@ TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_cal LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_errors") +{ + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + + CheckResult result = check(R"( +local function foo1(a: number) end +foo1() + +local function foo2(a: number, b: string?) end +foo2() + +local function foo3(a: number, b: string?, c: any) end -- any is optional +foo3() + +string.find() + +local t = {} +function t.foo(x: number, y: string?, ...: any) end +function t:bar(x: number, y: string?) end +t.foo() + +t:bar() + +local u = { a = t } +u.a.foo() + )"); + + LUAU_REQUIRE_ERROR_COUNT(7, result); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo1' expects 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[1]), "Argument count mismatch. Function 'foo2' expects 1 to 2 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[2]), "Argument count mismatch. Function 'foo3' expects 1 to 3 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[3]), "Argument count mismatch. Function 'string.find' expects 2 to 4 arguments, but none are specified"); + CHECK_EQ(toString(result.errors[4]), "Argument count mismatch. Function 't.foo' expects at least 1 argument, but none are specified"); + CHECK_EQ(toString(result.errors[5]), "Argument count mismatch. Function 't.bar' expects 2 to 3 arguments, but only 1 is specified"); + CHECK_EQ(toString(result.errors[6]), "Argument count mismatch. Function 'u.a.foo' expects at least 1 argument, but none are specified"); +} + +// This might be surprising, but since 'any' became optional, unannotated functions in non-strict 'expect' 0 arguments +TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_error_nonstrict") +{ + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + + CheckResult result = check(R"( +--!nonstrict +local function foo(a, b) end +foo(string.find("hello", "e")) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Argument count mismatch. Function 'foo' expects 0 to 2 arguments, but 3 are specified"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 9ac259c..3c86777 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -782,6 +782,8 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_few") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check(R"( function test(a: number) return 1 @@ -794,11 +796,13 @@ wrapper(test) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function 'wrapper' expects 2 arguments, but only 1 is specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check(R"( function test2(a: number, b: string) return 1 @@ -811,7 +815,7 @@ wrapper(test2, 1, "", 3) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function 'wrapper' expects 3 arguments, but 4 are specified)"); } TEST_CASE_FIXTURE(Fixture, "generic_function") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 3482b75..40ea0ca 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -69,6 +69,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "xpcall_returns_what_f_returns") CHECK("string" == toString(requireType("c"))); CHECK(expected == decorateWithTypes(code)); + + LUAU_REQUIRE_NO_ERRORS(result); } // We had a bug where if you have two type packs that looks like: diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 0a130d4..7d98b5d 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -202,6 +202,15 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_has_a_boolean") +{ + CheckResult result = check(R"( + local t={a=1,b=false} + )"); + + CHECK("{ a: number, b: boolean }" == toString(requireType("t"), {true})); +} + TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 97c3da4..d183f65 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2615,9 +2615,11 @@ do end TEST_CASE_FIXTURE(BuiltinsFixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check("local x = setmetatable({})"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but only 1 is specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function 'setmetatable' expects 2 arguments, but only 1 is specified", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "instantiate_table_cloning") @@ -2695,6 +2697,8 @@ local baz = foo[bar] TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { + ScopedFastFlag luauFunctionArgMismatchDetails{"LuauFunctionArgMismatchDetails", true}; + CheckResult result = check(R"( local a = setmetatable({ x = 2 }, { __call = function(self) @@ -2706,7 +2710,7 @@ local c = a(2) -- too many arguments )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. Function 'a' expects 1 argument, but 2 are specified", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "access_index_metamethod_that_returns_variadic") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e1dc502..e8bfb67 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1105,4 +1105,21 @@ end CHECK_EQ(*getMainModule()->astResolvedTypes.find(annotation), *ty); } +TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_higher_order_function") +{ + CheckResult result = check(R"( + function higher(cb: (number) -> ()) end + + higher(function(n) -- no error here. n : number + local e: string = n -- error here. n /: string + end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + Location location = result.errors[0].location; + CHECK(location.begin.line == 4); + CHECK(location.end.line == 4); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index eaa8b05..7d33809 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -964,4 +964,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks2") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments") +{ + ScopedFastFlag luauCallUnifyPackTails{"LuauCallUnifyPackTails", true}; + + CheckResult result = check(R"( + function foo(...: string): number + return 1 + end + + function bar(...: number): number + return foo(...) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'string'"); +} + +TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") +{ + ScopedFastFlag luauCallUnifyPackTails{"LuauCallUnifyPackTails", true}; + + CheckResult result = check(R"( + function foo(...: T...): T... + return ... + end + + function bar(...: number): boolean + return foo(...) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type 'number' could not be converted into 'boolean'"); +} + TEST_SUITE_END(); diff --git a/tools/faillist.txt b/tools/faillist.txt index ef995aa..9c2df80 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -383,8 +383,6 @@ TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer3 TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.checked_prop_too_early -TableTests.common_table_element_union_in_call -TableTests.common_table_element_union_in_call_tail TableTests.confusing_indexing TableTests.defining_a_method_for_a_builtin_sealed_table_must_fail TableTests.defining_a_method_for_a_local_sealed_table_must_fail @@ -540,7 +538,6 @@ TypeInfer.type_infer_recursion_limit_no_ice TypeInferAnyError.assign_prop_to_table_by_calling_any_yields_any TypeInferAnyError.can_get_length_of_any TypeInferAnyError.for_in_loop_iterator_is_any2 -TypeInferAnyError.for_in_loop_iterator_returns_any TypeInferAnyError.length_of_error_type_does_not_produce_an_error TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferAnyError.union_of_types_regression_test @@ -561,7 +558,6 @@ TypeInferClasses.table_indexers_are_invariant TypeInferClasses.table_properties_are_invariant TypeInferClasses.warn_when_prop_almost_matches TypeInferClasses.we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class -TypeInferFunctions.another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments TypeInferFunctions.another_recursive_local_function TypeInferFunctions.call_o_with_another_argument_after_foo_was_quantified TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types @@ -586,13 +582,14 @@ TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer TypeInferFunctions.higher_order_function_2 TypeInferFunctions.higher_order_function_4 TypeInferFunctions.ignored_return_values +TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict +TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.inconsistent_higher_order_function TypeInferFunctions.inconsistent_return_types TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals -TypeInferFunctions.it_is_ok_to_oversaturate_a_higher_order_function_argument TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count TypeInferFunctions.no_lossy_function_type