Sync to upstream/release/512 (#330)

- Improve refinement support for unions, in particular it's now possible to implement tagged unions as a union of tables where individual branches use a string literal type for one of the fields.
- Fix `string.split` type information
- Optimize `select(_, ...)` to run in constant time (~2.7x faster on VariadicSelect benchmark)
- Improve debug line information for multi-line assignments
- Improve compilation of table literals when table keys are constant expressions/variables
- Use forward GC barrier for `setmetatable` which slightly accelerates GC progress
This commit is contained in:
Arseny Kapoulkine 2022-01-27 15:46:05 -08:00 committed by GitHub
parent 4b96f7efc1
commit 2f989fc049
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 1782 additions and 1133 deletions

View File

@ -42,6 +42,21 @@ struct ExprOrLocal
{
return expr ? expr->location : (local ? local->location : std::optional<Location>{});
}
std::optional<AstName> getName()
{
if (expr)
{
if (AstName name = getIdentifier(expr); name.value)
{
return name;
}
}
else if (local)
{
return local->name;
}
return std::nullopt;
}
private:
AstExpr* expr = nullptr;

View File

@ -13,6 +13,8 @@
#include <unordered_map>
#include <optional>
LUAU_FASTFLAG(LuauPrepopulateUnionOptionsBeforeAllocation)
namespace Luau
{
@ -58,6 +60,12 @@ struct TypeArena
template<typename T>
TypeId addType(T tv)
{
if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation)
{
if constexpr (std::is_same_v<T, UnionTypeVar>)
LUAU_ASSERT(tv.options.size() >= 2);
}
return addTV(TypeVar(std::move(tv)));
}

View File

@ -135,7 +135,8 @@ struct TypeChecker
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockTypeAliases(const ScopePtr& scope, std::vector<AstStat*>& sorted);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt);
ExprResult<TypeId> checkExpr(
const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType = std::nullopt, bool forceSingleton = false);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprLocal& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprGlobal& expr);
ExprResult<TypeId> checkExpr(const ScopePtr& scope, const AstExprVarargs& expr);
@ -160,14 +161,12 @@ struct TypeChecker
// Returns the type of the lvalue.
TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr);
// Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding).
// Note: the binding may be null.
// TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExpr& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr);
std::pair<TypeId, TypeId*> checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr);
// Returns the type of the lvalue.
TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr);
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr);
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr);
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr);
TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr);
TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level);
std::pair<TypeId, ScopePtr> checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr,
@ -322,8 +321,6 @@ private:
return addTV(TypeVar(tv));
}
TypeId addType(const UnionTypeVar& utv);
TypeId addTV(TypeVar&& tv);
TypePackId addTypePack(TypePackVar&& tp);
@ -349,6 +346,8 @@ public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
private:
void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate);
std::optional<TypeId> resolveLValue(const ScopePtr& scope, const LValue& lvalue);
std::optional<TypeId> DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue);
std::optional<TypeId> resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue);

View File

@ -111,16 +111,16 @@ struct PrimitiveTypeVar
// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md
// Types for true and false
struct BoolSingleton
struct BooleanSingleton
{
bool value;
bool operator==(const BoolSingleton& rhs) const
bool operator==(const BooleanSingleton& rhs) const
{
return value == rhs.value;
}
bool operator!=(const BoolSingleton& rhs) const
bool operator!=(const BooleanSingleton& rhs) const
{
return !(*this == rhs);
}
@ -145,7 +145,7 @@ struct StringSingleton
// No type for float singletons, partly because === isn't any equalivalence on floats
// (NaN != NaN).
using SingletonVariant = Luau::Variant<BoolSingleton, StringSingleton>;
using SingletonVariant = Luau::Variant<BooleanSingleton, StringSingleton>;
struct SingletonTypeVar
{

View File

@ -85,6 +85,13 @@ public:
Unifier makeChildUnifier();
// A utility function that appends the given error to the unifier's error log.
// This allows setting a breakpoint wherever the unifier reports an error.
void reportError(TypeError error)
{
errors.push_back(error);
}
private:
bool isNonstrictMode() const;

View File

@ -14,9 +14,9 @@
LUAU_FASTFLAG(LuauUseCommittingTxnLog)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false);
LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false);
LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false);
LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false);
LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false);
static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -194,8 +194,6 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve
static std::optional<TypeId> findExpectedTypeAt(const Module& module, AstNode* node, Position position)
{
LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg);
auto expr = node->asExpr();
if (!expr)
return std::nullopt;
@ -266,43 +264,63 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
}
};
TypeId expectedType;
auto typeAtPosition = findExpectedTypeAt(module, node, position);
if (FFlag::LuauAutocompleteFirstArg)
if (!typeAtPosition)
return TypeCorrectKind::None;
TypeId expectedType = follow(*typeAtPosition);
if (FFlag::PreferToCallFunctionsForIntersects)
{
auto typeAtPosition = findExpectedTypeAt(module, node, position);
auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) {
auto [retHead, retTail] = flatten(ftv->retType);
if (!typeAtPosition)
return TypeCorrectKind::None;
if (!retHead.empty() && canUnify(retHead.front(), expectedType))
return true;
expectedType = follow(*typeAtPosition);
// We might only have a variadic tail pack, check if the element is compatible
if (retTail)
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType))
return true;
}
return false;
};
// We also want to suggest functions that return compatible result
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty); ftv && checkFunctionType(ftv))
{
return TypeCorrectKind::CorrectFunctionResult;
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(ty))
{
for (TypeId id : itv->parts)
{
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(id); ftv && checkFunctionType(ftv))
{
return TypeCorrectKind::CorrectFunctionResult;
}
}
}
}
else
{
auto expr = node->asExpr();
if (!expr)
return TypeCorrectKind::None;
auto it = module.astExpectedTypes.find(expr);
if (!it)
return TypeCorrectKind::None;
expectedType = follow(*it);
}
// We also want to suggest functions that return compatible result
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
auto [retHead, retTail] = flatten(ftv->retType);
if (!retHead.empty() && canUnify(retHead.front(), expectedType))
return TypeCorrectKind::CorrectFunctionResult;
// We might only have a variadic tail pack, check if the element is compatible
if (retTail)
// We also want to suggest functions that return compatible result
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(ty))
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType))
auto [retHead, retTail] = flatten(ftv->retType);
if (!retHead.empty() && canUnify(retHead.front(), expectedType))
return TypeCorrectKind::CorrectFunctionResult;
// We might only have a variadic tail pack, check if the element is compatible
if (retTail)
{
if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType))
return TypeCorrectKind::CorrectFunctionResult;
}
}
}
@ -741,29 +759,12 @@ std::optional<const T*> returnFirstNonnullOptionOfType(const UnionTypeVar* utv)
static std::optional<bool> functionIsExpectedAt(const Module& module, AstNode* node, Position position)
{
TypeId expectedType;
auto typeAtPosition = findExpectedTypeAt(module, node, position);
if (FFlag::LuauAutocompleteFirstArg)
{
auto typeAtPosition = findExpectedTypeAt(module, node, position);
if (!typeAtPosition)
return std::nullopt;
if (!typeAtPosition)
return std::nullopt;
expectedType = follow(*typeAtPosition);
}
else
{
auto expr = node->asExpr();
if (!expr)
return std::nullopt;
auto it = module.astExpectedTypes.find(expr);
if (!it)
return std::nullopt;
expectedType = follow(*it);
}
TypeId expectedType = follow(*typeAtPosition);
if (get<FunctionTypeVar>(expectedType))
return true;

View File

@ -18,7 +18,6 @@
LUAU_FASTFLAG(LuauInferInNoCheckMode)
LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false)
LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
namespace Luau
{
@ -102,8 +101,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
generateDocumentationSymbols(globalTy, documentationSymbol);
targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol};
if (FFlag::LuauPersistDefinitionFileTypes)
persist(globalTy);
persist(globalTy);
}
for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings)
@ -113,8 +111,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t
generateDocumentationSymbols(globalTy.type, documentationSymbol);
targetScope->exportedTypeBindings[name] = globalTy;
if (FFlag::LuauPersistDefinitionFileTypes)
persist(globalTy.type);
persist(globalTy.type);
}
return LoadDefinitionFileResult{true, parseResult, checkedModule};

View File

@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300)
LUAU_FASTFLAG(LuauTypeAliasDefaults)
LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false)
namespace Luau
{
@ -377,14 +379,28 @@ void TypeCloner::operator()(const AnyTypeVar& t)
void TypeCloner::operator()(const UnionTypeVar& t)
{
TypeId result = dest.addType(UnionTypeVar{});
seenTypes[typeId] = result;
if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation)
{
std::vector<TypeId> options;
options.reserve(t.options.size());
UnionTypeVar* option = getMutable<UnionTypeVar>(result);
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.options)
options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
for (TypeId ty : t.options)
option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
TypeId result = dest.addType(UnionTypeVar{std::move(options)});
seenTypes[typeId] = result;
}
else
{
TypeId result = dest.addType(UnionTypeVar{});
seenTypes[typeId] = result;
UnionTypeVar* option = getMutable<UnionTypeVar>(result);
LUAU_ASSERT(option != nullptr);
for (TypeId ty : t.options)
option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState));
}
}
void TypeCloner::operator()(const IntersectionTypeVar& t)

View File

@ -10,7 +10,6 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
LUAU_FASTFLAG(LuauTypeAliasDefaults)
/*
@ -374,7 +373,7 @@ struct TypeVarStringifier
void operator()(TypeId, const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = Luau::get<BoolSingleton>(&stv))
if (const BooleanSingleton* bs = Luau::get<BooleanSingleton>(&stv))
state.emit(bs->value ? "true" : "false");
else if (const StringSingleton* ss = Luau::get<StringSingleton>(&stv))
{
@ -617,9 +616,7 @@ struct TypeVarStringifier
std::string saved = std::move(state.result.name);
bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions
? !state.cycleNames.count(el) && (get<IntersectionTypeVar>(el) || get<FunctionTypeVar>(el))
: get<IntersectionTypeVar>(el) || get<FunctionTypeVar>(el);
bool needParens = !state.cycleNames.count(el) && (get<IntersectionTypeVar>(el) || get<FunctionTypeVar>(el));
if (needParens)
state.emit("(");
@ -675,9 +672,7 @@ struct TypeVarStringifier
std::string saved = std::move(state.result.name);
bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions
? !state.cycleNames.count(el) && (get<UnionTypeVar>(el) || get<FunctionTypeVar>(el))
: get<UnionTypeVar>(el) || get<FunctionTypeVar>(el);
bool needParens = !state.cycleNames.count(el) && (get<UnionTypeVar>(el) || get<FunctionTypeVar>(el));
if (needParens)
state.emit("(");

View File

@ -97,7 +97,7 @@ public:
AstType* operator()(const SingletonTypeVar& stv)
{
if (const BoolSingleton* bs = get<BoolSingleton>(&stv))
if (const BooleanSingleton* bs = get<BooleanSingleton>(&stv))
return allocator->alloc<AstTypeSingletonBool>(Location(), bs->value);
else if (const StringSingleton* ss = get<StringSingleton>(&stv))
{

View File

@ -26,8 +26,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false)
LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false)
LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false.
LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false)
LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false)
LUAU_FASTFLAG(LuauUseCommittingTxnLog)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false)
@ -37,6 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false)
LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false)
LUAU_FASTFLAGVARIABLE(LuauSealExports, false)
LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false)
LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions, false)
LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false)
LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false)
LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false)
@ -46,10 +45,8 @@ LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false)
LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false)
LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false)
LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false)
LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false)
LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false)
LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false)
LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false)
namespace Luau
{
@ -1139,33 +1136,25 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco
}
else
{
auto [leftType, leftTypeBinding] = checkLValueBinding(scope, *function.name);
TypeId leftType = checkLValueBinding(scope, *function.name);
checkFunctionBody(funScope, ty, *function.func);
unify(ty, leftType, function.location);
if (FFlag::LuauUpdateFunctionNameBinding)
{
LUAU_ASSERT(function.name->is<AstExprIndexName>() || function.name->is<AstExprError>());
LUAU_ASSERT(function.name->is<AstExprIndexName>() || function.name->is<AstExprError>());
if (auto exprIndexName = function.name->as<AstExprIndexName>())
if (auto exprIndexName = function.name->as<AstExprIndexName>())
{
if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr))
{
if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr))
if (auto ttv = getMutableTableType(*typeIt))
{
if (auto ttv = getMutableTableType(*typeIt))
{
if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end())
it->second.type = follow(quantify(funScope, leftType, function.name->location));
}
if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end())
it->second.type = follow(quantify(funScope, leftType, function.name->location));
}
}
}
else
{
if (leftTypeBinding)
*leftTypeBinding = follow(quantify(funScope, leftType, function.name->location));
}
}
}
@ -1426,7 +1415,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo
currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location};
}
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType)
ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional<TypeId> expectedType, bool forceSingleton)
{
RecursionCounter _rc(&checkRecursionCount);
if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit)
@ -1443,14 +1432,14 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
result = {nilType};
else if (const AstExprConstantBool* bexpr = expr.as<AstExprConstantBool>())
{
if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType))
if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType))))
result = {singletonType(bexpr->value)};
else
result = {booleanType};
}
else if (const AstExprConstantString* sexpr = expr.as<AstExprConstantString>())
{
if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType))
if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType))))
result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))};
else
result = {stringType};
@ -1488,15 +1477,8 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr&
result.type = follow(result.type);
if (FFlag::LuauStoreMatchingOverloadFnType)
{
if (!currentModule->astTypes.find(&expr))
currentModule->astTypes[&expr] = result.type;
}
else
{
if (!currentModule->astTypes.find(&expr))
currentModule->astTypes[&expr] = result.type;
}
if (expectedType)
currentModule->astExpectedTypes[&expr] = *expectedType;
@ -2242,7 +2224,6 @@ TypeId TypeChecker::checkRelationalOperation(
state.log.commit();
}
bool needsMetamethod = !isEquality;
TypeId leftType = follow(lhsType);
@ -2250,10 +2231,11 @@ TypeId TypeChecker::checkRelationalOperation(
{
reportErrors(state.errors);
const PrimitiveTypeVar* ptv = get<PrimitiveTypeVar>(leftType);
if (!isEquality && state.errors.empty() && (get<UnionTypeVar>(leftType) || (ptv && ptv->type == PrimitiveTypeVar::Boolean)))
if (!isEquality && state.errors.empty() && (get<UnionTypeVar>(leftType) || isBoolean(leftType)))
{
reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(),
toString(expr.op).c_str())});
}
return booleanType;
}
@ -2501,7 +2483,8 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi
ExprResult<TypeId> rhs = checkExpr(innerScope, *expr.right);
return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}};
return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type),
{AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}};
}
else if (expr.op == AstExprBinary::Or)
{
@ -2513,7 +2496,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi
ExprResult<TypeId> rhs = checkExpr(innerScope, *expr.right);
// Because of C++, I'm not sure if lhs.predicates was not moved out by the time we call checkBinaryOperation.
TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates);
TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type, lhs.predicates);
return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}};
}
else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)
@ -2521,8 +2504,8 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi
if (auto predicate = tryGetTypeGuardPredicate(expr))
return {booleanType, {std::move(*predicate)}};
ExprResult<TypeId> lhs = checkExpr(scope, *expr.left);
ExprResult<TypeId> rhs = checkExpr(scope, *expr.right);
ExprResult<TypeId> lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions);
ExprResult<TypeId> rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions);
PredicateVec predicates;
@ -2621,11 +2604,10 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf
TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr)
{
auto [ty, binding] = checkLValueBinding(scope, expr);
return ty;
return checkLValueBinding(scope, expr);
}
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr)
TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr)
{
if (auto a = expr.as<AstExprLocal>())
return checkLValueBinding(scope, *a);
@ -2639,22 +2621,22 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
{
for (AstExpr* expr : a->expressions)
checkExpr(scope, *expr);
return {errorRecoveryType(scope), nullptr};
return errorRecoveryType(scope);
}
else
ice("Unexpected AST node in checkLValue", expr.location);
}
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr)
TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr)
{
if (std::optional<TypeId> ty = scope->lookup(expr.local))
return {*ty, nullptr};
return *ty;
reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding});
return {errorRecoveryType(scope), nullptr};
return errorRecoveryType(scope);
}
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr)
TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr)
{
Name name = expr.name.value;
ScopePtr moduleScope = currentModule->getModuleScope();
@ -2662,7 +2644,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
const auto it = moduleScope->bindings.find(expr.name);
if (it != moduleScope->bindings.end())
return std::pair(it->second.typeId, &it->second.typeId);
return it->second.typeId;
TypeId result = freshType(scope);
Binding& binding = moduleScope->bindings[expr.name];
@ -2673,15 +2655,15 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!isNonstrictMode())
reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}});
return std::pair(result, &binding.typeId);
return result;
}
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr)
TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr)
{
TypeId lhs = checkExpr(scope, *expr.expr).type;
if (get<ErrorTypeVar>(lhs) || get<AnyTypeVar>(lhs))
return std::pair(lhs, nullptr);
return lhs;
tablify(lhs);
@ -2694,7 +2676,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
const auto& it = lhsTable->props.find(name);
if (it != lhsTable->props.end())
{
return std::pair(it->second.type, &it->second.type);
return it->second.type;
}
else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free)
{
@ -2702,7 +2684,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
Property& property = lhsTable->props[name];
property.type = theType;
property.location = expr.indexLocation;
return std::pair(theType, &property.type);
return theType;
}
else if (auto indexer = lhsTable->indexer)
{
@ -2720,17 +2702,17 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
else if (FFlag::LuauUseCommittingTxnLog)
state.log.commit();
return std::pair(retType, nullptr);
return retType;
}
else if (lhsTable->state == TableState::Sealed)
{
reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}});
return std::pair(errorRecoveryType(scope), nullptr);
return errorRecoveryType(scope);
}
else
{
reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}});
return std::pair(errorRecoveryType(scope), nullptr);
return errorRecoveryType(scope);
}
}
else if (const ClassTypeVar* lhsClass = get<ClassTypeVar>(lhs))
@ -2739,29 +2721,29 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!prop)
{
reportError(TypeError{expr.location, UnknownProperty{lhs, name}});
return std::pair(errorRecoveryType(scope), nullptr);
return errorRecoveryType(scope);
}
return std::pair(prop->type, nullptr);
return prop->type;
}
else if (get<IntersectionTypeVar>(lhs))
{
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhs, name, expr.location, false))
return std::pair(*ty, nullptr);
return *ty;
// If intersection has a table part, report that it cannot be extended just as a sealed table
if (isTableIntersection(lhs))
{
reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}});
return std::pair(errorRecoveryType(scope), nullptr);
return errorRecoveryType(scope);
}
}
reportError(TypeError{expr.location, NotATable{lhs}});
return std::pair(errorRecoveryType(scope), nullptr);
return errorRecoveryType(scope);
}
std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr)
TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr)
{
TypeId exprType = checkExpr(scope, *expr.expr).type;
tablify(exprType);
@ -2771,7 +2753,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
TypeId indexType = checkExpr(scope, *expr.index).type;
if (get<AnyTypeVar>(exprType) || get<ErrorTypeVar>(exprType))
return std::pair(exprType, nullptr);
return exprType;
AstExprConstantString* value = expr.index->as<AstExprConstantString>();
@ -2783,9 +2765,9 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!prop)
{
reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}});
return std::pair(errorRecoveryType(scope), nullptr);
return errorRecoveryType(scope);
}
return std::pair(prop->type, nullptr);
return prop->type;
}
}
@ -2794,7 +2776,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!exprTable)
{
reportError(TypeError{expr.expr->location, NotATable{exprType}});
return std::pair(errorRecoveryType(scope), nullptr);
return errorRecoveryType(scope);
}
if (value)
@ -2802,7 +2784,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
const auto& it = exprTable->props.find(value->value.data);
if (it != exprTable->props.end())
{
return std::pair(it->second.type, &it->second.type);
return it->second.type;
}
else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)
{
@ -2810,7 +2792,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
Property& property = exprTable->props[value->value.data];
property.type = resultType;
property.location = expr.index->location;
return std::pair(resultType, &property.type);
return resultType;
}
}
@ -2818,18 +2800,18 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
{
const TableIndexer& indexer = *exprTable->indexer;
unify(indexType, indexer.indexType, expr.index->location);
return std::pair(indexer.indexResultType, nullptr);
return indexer.indexResultType;
}
else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)
{
TypeId resultType = freshType(scope);
exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)};
return std::pair(resultType, nullptr);
return resultType;
}
else
{
TypeId resultType = freshType(scope);
return std::pair(resultType, nullptr);
return resultType;
}
}
@ -3326,7 +3308,7 @@ void TypeChecker::checkArgumentList(
} // ok
else
{
state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}});
state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}});
return;
}
++paramIter;
@ -3348,7 +3330,7 @@ void TypeChecker::checkArgumentList(
Location location = state.location;
if (!argLocations.empty())
location = {state.location.begin, argLocations.back().end};
state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
return;
}
TypePackId tail = state.log.follow(*paramIter.tail());
@ -3405,7 +3387,7 @@ void TypeChecker::checkArgumentList(
if (!argLocations.empty())
location = {state.location.begin, argLocations.back().end};
// TODO: Better error message?
state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
return;
}
}
@ -3520,7 +3502,7 @@ void TypeChecker::checkArgumentList(
} // ok
else
{
state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}});
state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}});
return;
}
++paramIter;
@ -3540,7 +3522,7 @@ void TypeChecker::checkArgumentList(
Location location = state.location;
if (!argLocations.empty())
location = {state.location.begin, argLocations.back().end};
state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
return;
}
TypePackId tail = *paramIter.tail();
@ -3606,7 +3588,7 @@ void TypeChecker::checkArgumentList(
if (!argLocations.empty())
location = {state.location.begin, argLocations.back().end};
// TODO: Better error message?
state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}});
return;
}
}
@ -3825,22 +3807,11 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
metaArgLocations = *argLocations;
metaArgLocations.insert(metaArgLocations.begin(), expr.func->location);
if (FFlag::LuauFixRecursiveMetatableCall)
{
fn = instantiate(scope, *ty, expr.func->location);
fn = instantiate(scope, *ty, expr.func->location);
argPack = metaCallArgPack;
args = metaCallArgs;
argLocations = &metaArgLocations;
}
else
{
TypeId fn = *ty;
fn = instantiate(scope, fn, expr.func->location);
return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, &metaArgLocations, argListResult,
overloadsThatMatchArgCount, overloadsThatDont, errors);
}
argPack = metaCallArgPack;
args = metaCallArgs;
argLocations = &metaArgLocations;
}
}
@ -3932,8 +3903,7 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
}
}
if (FFlag::LuauStoreMatchingOverloadFnType)
currentModule->astOverloadResolvedTypes[&expr] = fn;
currentModule->astOverloadResolvedTypes[&expr] = fn;
// We select this overload
return {{retPack}};
@ -4776,7 +4746,7 @@ TypeId TypeChecker::freshType(TypeLevel level)
TypeId TypeChecker::singletonType(bool value)
{
// TODO: cache singleton types
return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value})));
return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value})));
}
TypeId TypeChecker::singletonType(std::string value)
@ -4813,13 +4783,6 @@ std::optional<TypeId> TypeChecker::filterMap(TypeId type, TypeIdPredicate predic
return std::nullopt;
}
TypeId TypeChecker::addType(const UnionTypeVar& utv)
{
LUAU_ASSERT(utv.options.size() > 1);
return addTV(TypeVar(utv));
}
TypeId TypeChecker::addTV(TypeVar&& tv)
{
return currentModule->internalTypes.addType(std::move(tv));
@ -5347,54 +5310,35 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
TypeId instantiated = *maybeInstantiated;
if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType)
// TODO: CLI-46926 it's not a good idea to rename the type here
TypeId target = follow(instantiated);
bool needsClone = follow(tf.type) == target;
TableTypeVar* ttv = getMutableTableType(target);
if (ttv && needsClone)
{
// TODO: CLI-46926 it's not a good idea to rename the type here
TypeId target = follow(instantiated);
bool needsClone = follow(tf.type) == target;
TableTypeVar* ttv = getMutableTableType(target);
if (ttv && needsClone)
// Substitution::clone is a shallow clone. If this is a metatable type, we
// want to mutate its table, so we need to explicitly clone that table as
// well. If we don't, we will mutate another module's type surface and cause
// a use-after-free.
if (get<MetatableTypeVar>(target))
{
// Substitution::clone is a shallow clone. If this is a metatable type, we
// want to mutate its table, so we need to explicitly clone that table as
// well. If we don't, we will mutate another module's type surface and cause
// a use-after-free.
if (get<MetatableTypeVar>(target))
{
instantiated = applyTypeFunction.clone(tf.type);
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(instantiated);
mtv->table = applyTypeFunction.clone(mtv->table);
ttv = getMutable<TableTypeVar>(mtv->table);
}
if (get<TableTypeVar>(target))
{
instantiated = applyTypeFunction.clone(tf.type);
ttv = getMutable<TableTypeVar>(instantiated);
}
instantiated = applyTypeFunction.clone(tf.type);
MetatableTypeVar* mtv = getMutable<MetatableTypeVar>(instantiated);
mtv->table = applyTypeFunction.clone(mtv->table);
ttv = getMutable<TableTypeVar>(mtv->table);
}
if (ttv)
if (get<TableTypeVar>(target))
{
ttv->instantiatedTypeParams = typeParams;
ttv->instantiatedTypePackParams = typePackParams;
instantiated = applyTypeFunction.clone(tf.type);
ttv = getMutable<TableTypeVar>(instantiated);
}
}
else
{
if (TableTypeVar* ttv = getMutableTableType(instantiated))
{
if (follow(tf.type) == instantiated)
{
// This can happen if a type alias has generics that it does not use at all.
// ex type FooBar<T> = { a: number }
instantiated = applyTypeFunction.clone(tf.type);
ttv = getMutableTableType(instantiated);
}
ttv->instantiatedTypeParams = typeParams;
ttv->instantiatedTypePackParams = typePackParams;
}
if (ttv)
{
ttv->instantiatedTypeParams = typeParams;
ttv->instantiatedTypePackParams = typePackParams;
}
return instantiated;
@ -5482,6 +5426,85 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st
return {generics, genericPacks};
}
void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate)
{
LUAU_ASSERT(FFlag::LuauDiscriminableUnions);
const LValue* target = &lvalue;
std::optional<LValue> key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type.
auto ty = resolveLValue(scope, *target);
if (!ty)
return; // Do nothing. An error was already reported.
// If the provided lvalue is a local or global, then that's without a doubt the target.
// However, if there is a base lvalue, then we'll want that to be the target iff the base is a union type.
if (auto base = baseof(lvalue))
{
std::optional<TypeId> baseTy = resolveLValue(scope, *base);
if (baseTy && get<UnionTypeVar>(follow(*baseTy)))
{
ty = baseTy;
target = base;
key = lvalue;
}
}
// If we do not have a key, it means we're not trying to discriminate anything, so it's a simple matter of just filtering for a subset.
if (!key)
{
if (std::optional<TypeId> result = filterMap(*ty, predicate))
addRefinement(refis, *target, *result);
else
addRefinement(refis, *target, errorRecoveryType(scope));
return;
}
// Otherwise, we'll want to walk each option of ty, get its index type, and filter that.
auto utv = get<UnionTypeVar>(follow(*ty));
LUAU_ASSERT(utv);
std::unordered_set<TypeId> viableTargetOptions;
std::unordered_set<TypeId> viableChildOptions; // There may be additional refinements that apply. We add those here too.
for (TypeId option : utv)
{
std::optional<TypeId> discriminantTy;
if (auto field = Luau::get<Field>(*key)) // need to fully qualify Luau::get because of ADL.
discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), false);
else
LUAU_ASSERT(!"Unhandled LValue alternative?");
if (!discriminantTy)
return; // Do nothing. An error was already reported, as per usual.
if (std::optional<TypeId> result = filterMap(*discriminantTy, predicate))
{
viableTargetOptions.insert(option);
viableChildOptions.insert(*result);
}
}
auto intoType = [this](const std::unordered_set<TypeId>& s) -> std::optional<TypeId> {
if (s.empty())
return std::nullopt;
// TODO: allocate UnionTypeVar and just normalize.
std::vector<TypeId> options(s.begin(), s.end());
if (options.size() == 1)
return options[0];
return addType(UnionTypeVar{std::move(options)});
};
if (std::optional<TypeId> viableTargetType = intoType(viableTargetOptions))
addRefinement(refis, *target, *viableTargetType);
if (std::optional<TypeId> viableChildType = intoType(viableChildOptions))
addRefinement(refis, lvalue, *viableChildType);
}
std::optional<TypeId> TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue)
{
if (!FFlag::LuauLValueAsKey)
@ -5645,18 +5668,29 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi
return std::nullopt;
};
std::optional<TypeId> ty = resolveLValue(refis, scope, truthyP.lvalue);
if (!ty)
return;
if (FFlag::LuauDiscriminableUnions)
{
std::optional<TypeId> ty = resolveLValue(refis, scope, truthyP.lvalue);
if (ty && fromOr)
return addRefinement(refis, truthyP.lvalue, *ty);
// This is a hack. :(
// Without this, the expression 'a or b' might refine 'b' to be falsy.
// I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime.
if (fromOr)
return addRefinement(refis, truthyP.lvalue, *ty);
refineLValue(truthyP.lvalue, refis, scope, predicate);
}
else
{
std::optional<TypeId> ty = resolveLValue(refis, scope, truthyP.lvalue);
if (!ty)
return;
if (std::optional<TypeId> result = filterMap(*ty, predicate))
addRefinement(refis, truthyP.lvalue, *result);
// This is a hack. :(
// Without this, the expression 'a or b' might refine 'b' to be falsy.
// I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime.
if (fromOr)
return addRefinement(refis, truthyP.lvalue, *ty);
if (std::optional<TypeId> result = filterMap(*ty, predicate))
addRefinement(refis, truthyP.lvalue, *result);
}
}
void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense)
@ -5746,16 +5780,23 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement
return res;
};
std::optional<TypeId> ty = resolveLValue(refis, scope, isaP.lvalue);
if (!ty)
return;
if (std::optional<TypeId> result = filterMap(*ty, predicate))
addRefinement(refis, isaP.lvalue, *result);
if (FFlag::LuauDiscriminableUnions)
{
refineLValue(isaP.lvalue, refis, scope, predicate);
}
else
{
addRefinement(refis, isaP.lvalue, errorRecoveryType(scope));
errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}});
std::optional<TypeId> ty = resolveLValue(refis, scope, isaP.lvalue);
if (!ty)
return;
if (std::optional<TypeId> result = filterMap(*ty, predicate))
addRefinement(refis, isaP.lvalue, *result);
else
{
addRefinement(refis, isaP.lvalue, errorRecoveryType(scope));
errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}});
}
}
}
@ -5814,21 +5855,30 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
if (auto it = primitives.find(typeguardP.kind); it != primitives.end())
{
if (std::optional<TypeId> result = filterMap(*ty, it->second(sense)))
addRefinement(refis, typeguardP.lvalue, *result);
if (FFlag::LuauDiscriminableUnions)
{
refineLValue(typeguardP.lvalue, refis, scope, it->second(sense));
return;
}
else
{
addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
if (sense)
errVec.push_back(
TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}});
}
if (std::optional<TypeId> result = filterMap(*ty, it->second(sense)))
addRefinement(refis, typeguardP.lvalue, *result);
else
{
addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
if (sense)
errVec.push_back(
TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}});
}
return;
return;
}
}
auto fail = [&](const TypeErrorData& err) {
errVec.push_back(TypeError{typeguardP.location, err});
if (!FFlag::LuauDiscriminableUnions)
errVec.push_back(TypeError{typeguardP.location, err});
addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope));
};
@ -5853,55 +5903,87 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec
void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense)
{
// This refinement will require success typing to do everything correctly. For now, we can get most of the way there.
auto options = [](TypeId ty) -> std::vector<TypeId> {
if (auto utv = get<UnionTypeVar>(follow(ty)))
return std::vector<TypeId>(begin(utv), end(utv));
return {ty};
};
if (FFlag::LuauWeakEqConstraint)
if (FFlag::LuauDiscriminableUnions)
{
if (!sense && isNil(eqP.type))
resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false);
return;
}
if (FFlag::LuauEqConstraint)
{
std::optional<TypeId> ty = resolveLValue(refis, scope, eqP.lvalue);
if (!ty)
return;
std::vector<TypeId> lhs = options(*ty);
std::vector<TypeId> rhs = options(eqP.type);
if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable))
{
addRefinement(refis, eqP.lvalue, eqP.type);
return;
}
else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable))
if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable))
return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here.
std::unordered_set<TypeId> set;
for (TypeId left : lhs)
{
for (TypeId right : rhs)
auto predicate = [&](TypeId option) -> std::optional<TypeId> {
if (sense && isUndecidable(option))
return FFlag::LuauWeakEqConstraint ? option : eqP.type;
if (!sense && isNil(eqP.type))
return (isUndecidable(option) || !isNil(option)) ? std::optional<TypeId>(option) : std::nullopt;
if (maybeSingleton(eqP.type))
{
// When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`.
if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left)))
set.insert(left);
// Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this.
if (!sense || canUnify(eqP.type, option, eqP.location).empty())
return sense ? eqP.type : option;
// local variable works around an odd gcc 9.3 warning: <anonymous> may be used uninitialized
std::optional<TypeId> res = std::nullopt;
return res;
}
return option;
};
refineLValue(eqP.lvalue, refis, scope, predicate);
}
else
{
if (FFlag::LuauWeakEqConstraint)
{
if (!sense && isNil(eqP.type))
resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false);
return;
}
if (set.empty())
return;
if (FFlag::LuauEqConstraint)
{
std::optional<TypeId> ty = resolveLValue(refis, scope, eqP.lvalue);
if (!ty)
return;
std::vector<TypeId> viable(set.begin(), set.end());
TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)});
addRefinement(refis, eqP.lvalue, result);
std::vector<TypeId> lhs = options(*ty);
std::vector<TypeId> rhs = options(eqP.type);
if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable))
{
addRefinement(refis, eqP.lvalue, eqP.type);
return;
}
else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable))
return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here.
std::unordered_set<TypeId> set;
for (TypeId left : lhs)
{
for (TypeId right : rhs)
{
// When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`.
if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left)))
set.insert(left);
}
}
if (set.empty())
return;
std::vector<TypeId> viable(set.begin(), set.end());
TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)});
addRefinement(refis, eqP.lvalue, result);
}
}
}

View File

@ -18,14 +18,15 @@
#include <unordered_map>
#include <unordered_set>
LUAU_FASTFLAG(DebugLuauFreezeArena)
LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500)
LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauLengthOnCompositeType)
LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false)
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false)
LUAU_FASTFLAG(LuauErrorRecoveryType)
LUAU_FASTFLAG(DebugLuauFreezeArena)
namespace Luau
{
@ -144,7 +145,20 @@ bool isNil(TypeId ty)
bool isBoolean(TypeId ty)
{
return isPrim(ty, PrimitiveTypeVar::Boolean);
if (FFlag::LuauRefactorTypeVarQuestions)
{
if (isPrim(ty, PrimitiveTypeVar::Boolean) || get<BooleanSingleton>(get<SingletonTypeVar>(follow(ty))))
return true;
if (auto utv = get<UnionTypeVar>(follow(ty)))
return std::all_of(begin(utv), end(utv), isBoolean);
return false;
}
else
{
return isPrim(ty, PrimitiveTypeVar::Boolean);
}
}
bool isNumber(TypeId ty)
@ -154,7 +168,20 @@ bool isNumber(TypeId ty)
bool isString(TypeId ty)
{
return isPrim(ty, PrimitiveTypeVar::String);
if (FFlag::LuauRefactorTypeVarQuestions)
{
if (isPrim(ty, PrimitiveTypeVar::String) || get<StringSingleton>(get<SingletonTypeVar>(follow(ty))))
return true;
if (auto utv = get<UnionTypeVar>(follow(ty)))
return std::all_of(begin(utv), end(utv), isString);
return false;
}
else
{
return isPrim(ty, PrimitiveTypeVar::String);
}
}
bool isThread(TypeId ty)
@ -167,37 +194,45 @@ bool isOptional(TypeId ty)
if (isNil(ty))
return true;
if (!get<UnionTypeVar>(follow(ty)))
return false;
std::unordered_set<TypeId> seen;
std::deque<TypeId> queue{ty};
while (!queue.empty())
if (FFlag::LuauRefactorTypeVarQuestions)
{
TypeId current = follow(queue.front());
queue.pop_front();
auto utv = get<UnionTypeVar>(follow(ty));
if (!utv)
return false;
if (seen.count(current))
continue;
seen.insert(current);
if (isNil(current))
return true;
if (auto u = get<UnionTypeVar>(current))
return std::any_of(begin(utv), end(utv), isNil);
}
else
{
std::unordered_set<TypeId> seen;
std::deque<TypeId> queue{ty};
while (!queue.empty())
{
for (TypeId option : u->options)
{
if (isNil(option))
return true;
TypeId current = follow(queue.front());
queue.pop_front();
queue.push_back(option);
if (seen.count(current))
continue;
seen.insert(current);
if (isNil(current))
return true;
if (auto u = get<UnionTypeVar>(current))
{
for (TypeId option : u->options)
{
if (isNil(option))
return true;
queue.push_back(option);
}
}
}
}
return false;
return false;
}
}
bool isTableIntersection(TypeId ty)
@ -228,13 +263,27 @@ std::optional<TypeId> getMetatable(TypeId type)
return mtType->metatable;
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
return classType->metatable;
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type); primitiveType && primitiveType->metatable)
else if (FFlag::LuauRefactorTypeVarQuestions)
{
LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String);
return primitiveType->metatable;
if (isString(type))
{
auto ptv = get<PrimitiveTypeVar>(getSingletonTypes().stringType);
LUAU_ASSERT(ptv && ptv->metatable);
return ptv->metatable;
}
else
return std::nullopt;
}
else
return std::nullopt;
{
if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type); primitiveType && primitiveType->metatable)
{
LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String);
return primitiveType->metatable;
}
else
return std::nullopt;
}
}
const TableTypeVar* getTableType(TypeId type)
@ -696,7 +745,7 @@ TypeId SingletonTypes::makeStringMetatable()
{"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
{"upper", {stringToStringType}},
{"split", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalString}, {},
{"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {},
{arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}})})}},
{"pack", {arena->addType(FunctionTypeVar{
arena->addTypePack(TypePack{{stringType}, anyTypePack}),
@ -1108,30 +1157,14 @@ static Tags* getTags(TypeId ty)
void attachTag(TypeId ty, const std::string& tagName)
{
if (!FFlag::LuauRefactorTagging)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(tagName);
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
if (auto tags = getTags(ty))
tags->push_back(tagName);
else
{
if (auto tags = getTags(ty))
tags->push_back(tagName);
else
LUAU_ASSERT(!"This TypeId does not support tags");
}
LUAU_ASSERT(!"This TypeId does not support tags");
}
void attachTag(Property& prop, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
prop.tags.push_back(tagName);
}
@ -1140,7 +1173,6 @@ void attachTag(Property& prop, const std::string& tagName)
// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it.
bool hasTag(const Tags& tags, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
return std::find(tags.begin(), tags.end(), tagName) != tags.end();
}

View File

@ -17,15 +17,11 @@ LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false)
LUAU_FASTFLAG(LuauUseCommittingTxnLog)
LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000);
LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false);
LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false)
LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false)
LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false)
LUAU_FASTFLAG(LuauSingletonTypes)
LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauProperTypeLevels);
LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false)
LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false)
LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false)
namespace Luau
{
@ -229,8 +225,6 @@ static std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
// Used for tagged union matching heuristic, returns first singleton type field
static std::optional<std::pair<Luau::Name, const SingletonTypeVar*>> getTableMatchTag(TypeId type)
{
LUAU_ASSERT(FFlag::LuauExtendedUnionMismatchError);
type = follow(type);
if (auto ttv = get<TableTypeVar>(type))
@ -291,7 +285,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
errors.push_back(TypeError{location, UnificationTooComplex{}});
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
@ -403,7 +397,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (subGeneric && !subGeneric->level.subsumes(superLevel))
{
// TODO: a more informative error message? CLI-39912
errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}});
reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}});
return;
}
@ -448,7 +442,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (superGeneric && !superGeneric->level.subsumes(subFree->level))
{
// TODO: a more informative error message? CLI-39912
errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}});
reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}});
return;
}
@ -561,13 +555,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
}
if (unificationTooComplex)
errors.push_back(*unificationTooComplex);
reportError(*unificationTooComplex);
else if (failed)
{
if (firstFailedOption)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}});
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}});
else
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
}
}
else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable<UnionTypeVar>(superTy) : get<UnionTypeVar>(superTy))
@ -582,50 +576,44 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
bool foundHeuristic = false;
size_t startIndex = 0;
if (FFlag::LuauUnionHeuristic)
if (const std::string* subName = getName(subTy))
{
if (const std::string* subName = getName(subTy))
for (size_t i = 0; i < uv->options.size(); ++i)
{
for (size_t i = 0; i < uv->options.size(); ++i)
const std::string* optionName = getName(uv->options[i]);
if (optionName && *optionName == *subName)
{
const std::string* optionName = getName(uv->options[i]);
if (optionName && *optionName == *subName)
{
foundHeuristic = true;
startIndex = i;
break;
}
foundHeuristic = true;
startIndex = i;
break;
}
}
}
if (FFlag::LuauExtendedUnionMismatchError)
if (auto subMatchTag = getTableMatchTag(subTy))
{
for (size_t i = 0; i < uv->options.size(); ++i)
{
if (auto subMatchTag = getTableMatchTag(subTy))
auto optionMatchTag = getTableMatchTag(uv->options[i]);
if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second)
{
for (size_t i = 0; i < uv->options.size(); ++i)
{
auto optionMatchTag = getTableMatchTag(uv->options[i]);
if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second)
{
foundHeuristic = true;
startIndex = i;
break;
}
}
foundHeuristic = true;
startIndex = i;
break;
}
}
}
if (!foundHeuristic && cacheEnabled)
if (!foundHeuristic && cacheEnabled)
{
for (size_t i = 0; i < uv->options.size(); ++i)
{
for (size_t i = 0; i < uv->options.size(); ++i)
{
TypeId type = uv->options[i];
TypeId type = uv->options[i];
if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type})))
{
startIndex = i;
break;
}
if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type})))
{
startIndex = i;
break;
}
}
}
@ -650,7 +638,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
{
unificationTooComplex = e;
}
else if (FFlag::LuauExtendedUnionMismatchError && !isNil(type))
else if (!isNil(type))
{
failedOptionCount++;
@ -664,15 +652,15 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
if (unificationTooComplex)
{
errors.push_back(*unificationTooComplex);
reportError(*unificationTooComplex);
}
else if (!found)
{
if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption)
errors.push_back(
if ((failedOptionCount == 1 || foundHeuristic) && failedOption)
reportError(
TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}});
else
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}});
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}});
}
}
else if (const IntersectionTypeVar* uv =
@ -702,9 +690,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
}
if (unificationTooComplex)
errors.push_back(*unificationTooComplex);
reportError(*unificationTooComplex);
else if (firstFailedOption)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}});
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}});
}
else if (const IntersectionTypeVar* uv =
FFlag::LuauUseCommittingTxnLog ? log.getMutable<IntersectionTypeVar>(subTy) : get<IntersectionTypeVar>(subTy))
@ -754,10 +742,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
}
if (unificationTooComplex)
errors.push_back(*unificationTooComplex);
reportError(*unificationTooComplex);
else if (!found)
{
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}});
reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}});
}
}
else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable<PrimitiveTypeVar>(superTy) && log.getMutable<PrimitiveTypeVar>(subTy)) ||
@ -801,7 +789,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
tryUnifyWithClass(subTy, superTy, /*reversed*/ true);
else
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
if (FFlag::LuauUseCommittingTxnLog)
log.popSeen(superTy, subTy);
@ -1067,7 +1055,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount)
{
errors.push_back(TypeError{location, UnificationTooComplex{}});
reportError(TypeError{location, UnificationTooComplex{}});
return;
}
@ -1166,7 +1154,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
{
tryUnify_(*subIter, *superIter);
if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos)
if (!errors.empty() && !firstPackErrorPos)
firstPackErrorPos = loopCount;
superIter.advance();
@ -1251,7 +1239,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
size_t actualSize = size(subTp);
if (ctx == CountMismatch::Result)
std::swap(expectedSize, actualSize);
errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}});
reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}});
while (superIter.good())
{
@ -1272,7 +1260,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
}
else
{
errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}});
reportError(TypeError{location, GenericError{"Failed to unify type packs"}});
}
}
else
@ -1372,7 +1360,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
{
tryUnify_(*subIter, *superIter);
if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos)
if (!errors.empty() && !firstPackErrorPos)
firstPackErrorPos = loopCount;
superIter.advance();
@ -1459,7 +1447,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
size_t actualSize = size(subTp);
if (ctx == CountMismatch::Result)
std::swap(expectedSize, actualSize);
errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}});
reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}});
while (superIter.good())
{
@ -1480,7 +1468,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal
}
else
{
errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}});
reportError(TypeError{location, GenericError{"Failed to unify type packs"}});
}
}
}
@ -1493,7 +1481,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy)
ice("passed non primitive types to unifyPrimitives");
if (superPrim->type != subPrim->type)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
}
void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
@ -1508,13 +1496,13 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy)
if (superSingleton && *superSingleton == *subSingleton)
return;
if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get<BoolSingleton>(subSingleton) && variance == Covariant)
if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get<BooleanSingleton>(subSingleton) && variance == Covariant)
return;
if (superPrim && superPrim->type == PrimitiveTypeVar::String && get<StringSingleton>(subSingleton) && variance == Covariant)
return;
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
}
void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall)
@ -1536,10 +1524,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
{
numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size());
if (FFlag::LuauExtendedFunctionMismatchError)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}});
else
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}});
}
size_t numGenericPacks = superFunction->genericPacks.size();
@ -1547,10 +1532,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
{
numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size());
if (FFlag::LuauExtendedFunctionMismatchError)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}});
else
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}});
}
for (size_t i = 0; i < numGenerics; i++)
@ -1567,48 +1549,35 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal
{
Unifier innerState = makeChildUnifier();
if (FFlag::LuauExtendedFunctionMismatchError)
innerState.ctx = CountMismatch::Arg;
innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall);
bool reported = !innerState.errors.empty();
if (auto e = hasUnificationTooComplex(innerState.errors))
reportError(*e);
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
reportError(
TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()}});
else if (!innerState.errors.empty())
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}});
innerState.ctx = CountMismatch::Result;
innerState.tryUnify_(subFunction->retType, superFunction->retType);
if (!reported)
{
innerState.ctx = CountMismatch::Arg;
innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall);
bool reported = !innerState.errors.empty();
if (auto e = hasUnificationTooComplex(innerState.errors))
errors.push_back(*e);
reportError(*e);
else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType))
reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
errors.push_back(
TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos),
reportError(
TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()}});
else if (!innerState.errors.empty())
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}});
innerState.ctx = CountMismatch::Result;
innerState.tryUnify_(subFunction->retType, superFunction->retType);
if (!reported)
{
if (auto e = hasUnificationTooComplex(innerState.errors))
errors.push_back(*e);
else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType))
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}});
else if (!innerState.errors.empty() && innerState.firstPackErrorPos)
errors.push_back(
TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos),
innerState.errors.front()}});
else if (!innerState.errors.empty())
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}});
}
}
else
{
ctx = CountMismatch::Arg;
innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall);
ctx = CountMismatch::Result;
innerState.tryUnify_(subFunction->retType, superFunction->retType);
checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy);
reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}});
}
if (FFlag::LuauUseCommittingTxnLog)
@ -1716,7 +1685,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}});
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}});
return;
}
}
@ -1734,7 +1703,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!extraProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}});
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}});
return;
}
}
@ -1957,13 +1926,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
if (!missingProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}});
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}});
return;
}
if (!extraProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}});
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}});
return;
}
@ -2051,7 +2020,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt
return tryUnifySealedTables(subTy, superTy, isIntersection);
else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) ||
(superTable->state == TableState::Generic && subTable->state == TableState::Sealed))
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not
{
TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy;
@ -2090,7 +2059,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt
{
const auto& r = subTable->props.find(name);
if (r == subTable->props.end())
errors.push_back(TypeError{location, UnknownProperty{subTy, name}});
reportError(TypeError{location, UnknownProperty{subTy, name}});
else
tryUnify_(r->second.type, prop.type);
}
@ -2113,7 +2082,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt
}
}
else
errors.push_back(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}});
reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}});
}
}
else if (superTable->state == TableState::Sealed)
@ -2194,7 +2163,7 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy)
}
}
else
errors.push_back(TypeError{location, UnknownProperty{subTy, freeName}});
reportError(TypeError{location, UnknownProperty{subTy, freeName}});
}
}
@ -2268,7 +2237,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
if (!missingPropertiesInSuper.empty())
{
errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}});
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}});
return;
}
}
@ -2284,7 +2253,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
missingPropertiesInSuper.push_back(it.first);
innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}});
}
else
{
@ -2299,7 +2268,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
if (oldErrorSize != innerState.errors.size() && !errorReported)
{
errorReported = true;
errors.push_back(innerState.errors.back());
reportError(innerState.errors.back());
}
}
else
@ -2340,7 +2309,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
}
}
else
innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}});
}
else
{
@ -2369,7 +2338,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
}
}
else
innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}});
}
}
@ -2386,7 +2355,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
if (!missingPropertiesInSuper.empty())
{
errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}});
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}});
return;
}
@ -2413,7 +2382,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec
if (!extraPropertiesInSub.empty())
{
errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}});
reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}});
return;
}
}
@ -2437,9 +2406,9 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable);
if (auto e = hasUnificationTooComplex(innerState.errors))
errors.push_back(*e);
reportError(*e);
else if (!innerState.errors.empty())
errors.push_back(
reportError(
TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}});
if (FFlag::LuauUseCommittingTxnLog)
@ -2470,7 +2439,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
case TableState::Sealed:
case TableState::Unsealed:
case TableState::Generic:
errors.push_back(mismatchError);
reportError(mismatchError);
}
}
else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable<AnyTypeVar>(subTy) || log.getMutable<ErrorTypeVar>(subTy))
@ -2479,7 +2448,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed)
}
else
{
errors.push_back(mismatchError);
reportError(mismatchError);
}
}
@ -2491,9 +2460,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
auto fail = [&]() {
if (!reversed)
errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}});
reportError(TypeError{location, TypeMismatch{superTy, subTy}});
else
errors.push_back(TypeError{location, TypeMismatch{subTy, superTy}});
reportError(TypeError{location, TypeMismatch{subTy, superTy}});
};
const ClassTypeVar* superClass = get<ClassTypeVar>(superTy);
@ -2538,7 +2507,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
if (!classProp)
{
ok = false;
errors.push_back(TypeError{location, UnknownProperty{superTy, propName}});
reportError(TypeError{location, UnknownProperty{superTy, propName}});
}
else
{
@ -2577,7 +2546,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed)
{
ok = false;
std::string msg = "Class " + superClass->name + " does not have an indexer";
errors.push_back(TypeError{location, GenericError{msg}});
reportError(TypeError{location, GenericError{msg}});
}
if (!ok)
@ -2695,7 +2664,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
}
else if (get<Unifiable::Generic>(tail))
{
errors.push_back(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}});
reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}});
}
else if (get<Unifiable::Error>(tail))
{
@ -2709,7 +2678,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever
}
else
{
errors.push_back(TypeError{location, GenericError{"Failed to unify variadic packs"}});
reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}});
}
}
@ -2886,7 +2855,7 @@ void Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
if (needle == haystack)
{
errors.push_back(TypeError{location, OccursCheckFailed{}});
reportError(TypeError{location, OccursCheckFailed{}});
log.replace(needle, *getSingletonTypes().errorRecoveryType());
return;
@ -2894,17 +2863,6 @@ void Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
if (log.getMutable<FreeTypeVar>(haystack))
return;
else if (auto a = log.getMutable<FunctionTypeVar>(haystack))
{
if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions)
{
for (TypePackIterator it(a->argTypes, &log); it != end(a->argTypes); ++it)
check(*it);
for (TypePackIterator it(a->retType, &log); it != end(a->retType); ++it)
check(*it);
}
}
else if (auto a = log.getMutable<UnionTypeVar>(haystack))
{
for (TypeId ty : a->options)
@ -2934,7 +2892,7 @@ void Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
if (needle == haystack)
{
errors.push_back(TypeError{location, OccursCheckFailed{}});
reportError(TypeError{location, OccursCheckFailed{}});
DEPRECATED_log(needle);
*asMutable(needle) = *getSingletonTypes().errorRecoveryType();
return;
@ -2942,17 +2900,6 @@ void Unifier::occursCheck(DenseHashSet<TypeId>& seen, TypeId needle, TypeId hays
if (get<FreeTypeVar>(haystack))
return;
else if (auto a = get<FunctionTypeVar>(haystack))
{
if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions)
{
for (TypeId ty : a->argTypes)
check(ty);
for (TypeId ty : a->retType)
check(ty);
}
}
else if (auto a = get<UnionTypeVar>(haystack))
{
for (TypeId ty : a->options)
@ -2988,7 +2935,7 @@ void Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
if (log.getMutable<Unifiable::Error>(needle))
return;
if (!get<Unifiable::Free>(needle))
if (!log.getMutable<Unifiable::Free>(needle))
ice("Expected needle pack to be free");
RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit);
@ -2997,32 +2944,18 @@ void Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
{
if (needle == haystack)
{
errors.push_back(TypeError{location, OccursCheckFailed{}});
reportError(TypeError{location, OccursCheckFailed{}});
log.replace(needle, *getSingletonTypes().errorRecoveryTypePack());
return;
}
if (auto a = get<TypePack>(haystack))
if (auto a = get<TypePack>(haystack); a && a->tail)
{
for (const auto& ty : a->head)
{
if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions)
{
if (auto f = log.getMutable<FunctionTypeVar>(log.follow(ty)))
{
occursCheck(seen, needle, f->argTypes);
occursCheck(seen, needle, f->retType);
}
}
}
if (a->tail)
{
haystack = follow(*a->tail);
continue;
}
haystack = log.follow(*a->tail);
continue;
}
break;
}
}
@ -3048,31 +2981,17 @@ void Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
{
if (needle == haystack)
{
errors.push_back(TypeError{location, OccursCheckFailed{}});
reportError(TypeError{location, OccursCheckFailed{}});
DEPRECATED_log(needle);
*asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack();
}
if (auto a = get<TypePack>(haystack))
if (auto a = get<TypePack>(haystack); a && a->tail)
{
if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions)
{
for (const auto& ty : a->head)
{
if (auto f = get<FunctionTypeVar>(follow(ty)))
{
occursCheck(seen, needle, f->argTypes);
occursCheck(seen, needle, f->retType);
}
}
}
if (a->tail)
{
haystack = follow(*a->tail);
continue;
}
haystack = follow(*a->tail);
continue;
}
break;
}
}
@ -3094,17 +3013,17 @@ bool Unifier::isNonstrictMode() const
void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType)
{
if (auto e = hasUnificationTooComplex(innerErrors))
errors.push_back(*e);
reportError(*e);
else if (!innerErrors.empty())
errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}});
reportError(TypeError{location, TypeMismatch{wantedType, givenType}});
}
void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType)
{
if (auto e = hasUnificationTooComplex(innerErrors))
errors.push_back(*e);
reportError(*e);
else if (!innerErrors.empty())
errors.push_back(
reportError(
TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}});
}

View File

@ -43,7 +43,7 @@ static void report(ReportFormat format, const char* name, const Luau::Location&
}
}
static void reportError(Luau::Frontend& frontend, ReportFormat format, const Luau::TypeError& error)
static void reportError(const Luau::Frontend& frontend, ReportFormat format, const Luau::TypeError& error)
{
std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(error.moduleName);

View File

@ -15,8 +15,6 @@
#include <string.h>
#define READ_BUFFER_SIZE 4096
#ifdef _WIN32
static std::wstring fromUtf8(const std::string& path)
{
@ -79,9 +77,9 @@ std::optional<std::string> readFile(const std::string& name)
std::optional<std::string> readStdin()
{
std::string result;
char buffer[READ_BUFFER_SIZE] = { };
char buffer[4096] = { };
while (fgets(buffer, READ_BUFFER_SIZE, stdin) != nullptr)
while (fgets(buffer, sizeof(buffer), stdin) != nullptr)
result.append(buffer);
// If eof was not reached for stdin, then a read error occurred

View File

@ -158,7 +158,7 @@ static int lua_collectgarbage(lua_State* L)
luaL_error(L, "collectgarbage must be called with 'count' or 'collect'");
}
static void setupState(lua_State* L)
void setupState(lua_State* L)
{
luaL_openlibs(L);
@ -176,7 +176,7 @@ static void setupState(lua_State* L)
luaL_sandbox(L);
}
static std::string runCode(lua_State* L, const std::string& source)
std::string runCode(lua_State* L, const std::string& source)
{
std::string bytecode = Luau::compile(source, copts());
@ -206,7 +206,13 @@ static std::string runCode(lua_State* L, const std::string& source)
if (n)
{
luaL_checkstack(T, LUA_MINSTACK, "too many results to print");
lua_getglobal(T, "print");
lua_getglobal(T, "_PRETTYPRINT");
// If _PRETTYPRINT is nil, then use the standard print function instead
if (lua_isnil(T, -1))
{
lua_pop(T, 1);
lua_getglobal(T, "print");
}
lua_insert(T, 1);
lua_pcall(T, n, 0, 0);
}
@ -545,7 +551,7 @@ static int assertionHandler(const char* expr, const char* file, int line, const
return 1;
}
int main(int argc, char** argv)
int replMain(int argc, char** argv)
{
Luau::assertHandler() = assertionHandler;
@ -696,5 +702,6 @@ int main(int argc, char** argv)
case CliMode::Unknown:
default:
LUAU_ASSERT(!"Unhandled cli mode.");
return 1;
}
}

12
CLI/Repl.h Normal file
View File

@ -0,0 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "lua.h"
#include <string>
// Note: These are internal functions which are being exposed in a header
// so they can be included by unit tests.
int replMain(int argc, char** argv);
void setupState(lua_State* L);
std::string runCode(lua_State* L, const std::string& source);

10
CLI/ReplEntry.cpp Normal file
View File

@ -0,0 +1,10 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Repl.h"
int main(int argc, char** argv)
{
return replMain(argc, argv);
}

View File

@ -29,6 +29,7 @@ endif()
if(LUAU_BUILD_TESTS)
add_executable(Luau.UnitTest)
add_executable(Luau.Conformance)
add_executable(Luau.CLI.Test)
endif()
if(LUAU_BUILD_WEB)
@ -109,6 +110,17 @@ if(LUAU_BUILD_TESTS)
target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS})
target_include_directories(Luau.Conformance PRIVATE extern)
target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM)
target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS})
target_include_directories(Luau.CLI.Test PRIVATE extern CLI)
target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM)
if(UNIX)
find_library(LIBPTHREAD pthread)
if (LIBPTHREAD)
target_link_libraries(Luau.CLI.Test PRIVATE pthread)
endif()
endif()
endif()
if(LUAU_BUILD_WEB)

View File

@ -472,6 +472,9 @@ enum LuauBuiltinFunction
// bit32.count
LBF_BIT32_COUNTLZ,
LBF_BIT32_COUNTRZ,
// select(_, ...)
LBF_SELECT_VARARG,
};
// Capture type, used in LOP_CAPTURE

View File

@ -4,6 +4,8 @@
#include "Luau/Bytecode.h"
#include "Luau/Compiler.h"
LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin, false)
namespace Luau
{
namespace Compile
@ -62,6 +64,9 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
if (builtin.isGlobal("unpack"))
return LBF_TABLE_UNPACK;
if (FFlag::LuauCompileSelectBuiltin && builtin.isGlobal("select"))
return LBF_SELECT_VARARG;
if (builtin.object == "math")
{
if (builtin.method == "abs")

View File

@ -15,6 +15,9 @@
#include <bitset>
#include <math.h>
LUAU_FASTFLAGVARIABLE(LuauCompileTableIndexOpt, false)
LUAU_FASTFLAG(LuauCompileSelectBuiltin)
namespace Luau
{
@ -261,6 +264,122 @@ struct Compiler
bytecode.emitABC(LOP_GETVARARGS, target, multRet ? 0 : uint8_t(targetCount + 1), 0);
}
void compileExprSelectVararg(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs)
{
LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin);
LUAU_ASSERT(targetCount == 1);
LUAU_ASSERT(!expr->self);
LUAU_ASSERT(expr->args.size == 2 && expr->args.data[1]->is<AstExprVarargs>());
AstExpr* arg = expr->args.data[0];
uint8_t argreg;
if (isExprLocalReg(arg))
argreg = getLocal(arg->as<AstExprLocal>()->local);
else
{
argreg = uint8_t(regs + 1);
compileExprTempTop(arg, argreg);
}
size_t fastcallLabel = bytecode.emitLabel();
bytecode.emitABC(LOP_FASTCALL1, LBF_SELECT_VARARG, argreg, 0);
// note, these instructions are normally not executed and are used as a fallback for FASTCALL
// we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten
compileExprTemp(expr->func, regs);
bytecode.emitABC(LOP_GETVARARGS, uint8_t(regs + 2), 0, 0);
size_t callLabel = bytecode.emitLabel();
if (!bytecode.patchSkipC(fastcallLabel, callLabel))
CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile");
// note, this is always multCall (last argument is variadic)
bytecode.emitABC(LOP_CALL, regs, 0, multRet ? 0 : uint8_t(targetCount + 1));
// if we didn't output results directly to target, we need to move them
if (!targetTop)
{
for (size_t i = 0; i < targetCount; ++i)
bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0);
}
}
void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid)
{
LUAU_ASSERT(!expr->self);
LUAU_ASSERT(expr->args.size <= 2);
LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2;
uint32_t args[2] = {};
for (size_t i = 0; i < expr->args.size; ++i)
{
if (i > 0)
{
if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0)
{
opc = LOP_FASTCALL2K;
args[i] = cid;
break;
}
}
if (isExprLocalReg(expr->args.data[i]))
args[i] = getLocal(expr->args.data[i]->as<AstExprLocal>()->local);
else
{
args[i] = uint8_t(regs + 1 + i);
compileExprTempTop(expr->args.data[i], uint8_t(args[i]));
}
}
size_t fastcallLabel = bytecode.emitLabel();
bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0);
if (opc != LOP_FASTCALL1)
bytecode.emitAux(args[1]);
// Set up a traditional Lua stack for the subsequent LOP_CALL.
// Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for
// these FASTCALL variants.
for (size_t i = 0; i < expr->args.size; ++i)
{
if (i > 0 && opc == LOP_FASTCALL2K)
{
emitLoadK(uint8_t(regs + 1 + i), args[i]);
break;
}
if (args[i] != regs + 1 + i)
bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0);
}
// note, these instructions are normally not executed and are used as a fallback for FASTCALL
// we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten
compileExprTemp(expr->func, regs);
size_t callLabel = bytecode.emitLabel();
// FASTCALL will skip over the instructions needed to compute function and jump over CALL which must immediately follow the instruction
// sequence after FASTCALL
if (!bytecode.patchSkipC(fastcallLabel, callLabel))
CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile");
bytecode.emitABC(LOP_CALL, regs, uint8_t(expr->args.size + 1), multRet ? 0 : uint8_t(targetCount + 1));
// if we didn't output results directly to target, we need to move them
if (!targetTop)
{
for (size_t i = 0; i < targetCount; ++i)
bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0);
}
}
void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false)
{
LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop);
@ -284,6 +403,25 @@ struct Compiler
bfid = getBuiltinFunctionId(builtin, options);
}
if (bfid == LBF_SELECT_VARARG)
{
LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin);
// Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly
// note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases
if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is<AstExprVarargs>())
return compileExprSelectVararg(expr, target, targetCount, targetTop, multRet, regs);
else
bfid = -1;
}
// Optimization: for 1/2 argument fast calls use specialized opcodes
if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2)
{
AstExpr* last = expr->args.data[expr->args.size - 1];
if (!last->is<AstExprCall>() && !last->is<AstExprVarargs>())
return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid);
}
if (expr->self)
{
AstExprIndexName* fi = expr->func->as<AstExprIndexName>();
@ -309,24 +447,13 @@ struct Compiler
compileExprTempTop(expr->func, regs);
}
// Note: if the last argument is ExprVararg or ExprCall, we need to route that directly to the called function preserving the # of args
bool multCall = false;
bool skipArgs = false;
if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2)
{
AstExpr* last = expr->args.data[expr->args.size - 1];
skipArgs = !(last->is<AstExprCall>() || last->is<AstExprVarargs>());
}
if (!skipArgs)
{
for (size_t i = 0; i < expr->args.size; ++i)
if (i + 1 == expr->args.size)
multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i));
else
compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i));
}
for (size_t i = 0; i < expr->args.size; ++i)
if (i + 1 == expr->args.size)
multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i));
else
compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i));
setDebugLineEnd(expr->func);
@ -347,59 +474,8 @@ struct Compiler
}
else if (bfid >= 0)
{
size_t fastcallLabel;
if (skipArgs)
{
LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2;
uint32_t args[2] = {};
for (size_t i = 0; i < expr->args.size; ++i)
{
if (i > 0)
{
if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0)
{
opc = LOP_FASTCALL2K;
args[i] = cid;
break;
}
}
if (isExprLocalReg(expr->args.data[i]))
args[i] = getLocal(expr->args.data[i]->as<AstExprLocal>()->local);
else
{
args[i] = uint8_t(regs + 1 + i);
compileExprTempTop(expr->args.data[i], uint8_t(args[i]));
}
}
fastcallLabel = bytecode.emitLabel();
bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0);
if (opc != LOP_FASTCALL1)
bytecode.emitAux(args[1]);
// Set up a traditional Lua stack for the subsequent LOP_CALL.
// Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for
// these FASTCALL variants.
for (size_t i = 0; i < expr->args.size; ++i)
{
if (i > 0 && opc == LOP_FASTCALL2K)
{
emitLoadK(uint8_t(regs + 1 + i), args[i]);
break;
}
if (args[i] != regs + 1 + i)
bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0);
}
}
else
{
fastcallLabel = bytecode.emitLabel();
bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0);
}
size_t fastcallLabel = bytecode.emitLabel();
bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0);
// note, these instructions are normally not executed and are used as a fallback for FASTCALL
// we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten
@ -1101,9 +1177,20 @@ struct Compiler
for (size_t i = 0; i < expr->items.size; ++i)
{
const AstExprTable::Item& item = expr->items.data[i];
AstExprConstantNumber* ckey = item.key->as<AstExprConstantNumber>();
LUAU_ASSERT(item.key); // no list portion => all items have keys
indexSize += (ckey && ckey->value == double(indexSize + 1));
if (FFlag::LuauCompileTableIndexOpt)
{
const Constant* ckey = constants.find(item.key);
indexSize += (ckey && ckey->type == Constant::Type_Number && ckey->valueNumber == double(indexSize + 1));
}
else
{
AstExprConstantNumber* ckey = item.key->as<AstExprConstantNumber>();
indexSize += (ckey && ckey->value == double(indexSize + 1));
}
}
// we only perform the optimization if we don't have any other []-keys
@ -1200,37 +1287,47 @@ struct Compiler
arrayChunkCurrent = 0;
}
// items with a key are set one by one via SETTABLE/SETTABLEKS
// items with a key are set one by one via SETTABLE/SETTABLEKS/SETTABLEN
if (key)
{
RegScope rsi(this);
// Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax
if (AstExprConstantString* ckey = key->as<AstExprConstantString>())
if (FFlag::LuauCompileTableIndexOpt)
{
BytecodeBuilder::StringRef cname = sref(ckey->value);
int32_t cid = bytecode.addConstantString(cname);
if (cid < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
LValue lv = compileLValueIndex(reg, key, rsi);
uint8_t rv = compileExprAuto(value, rsi);
bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname)));
bytecode.emitAux(cid);
}
else if (AstExprConstantNumber* ckey = key->as<AstExprConstantNumber>();
ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value)
{
uint8_t rv = compileExprAuto(value, rsi);
bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1));
compileAssign(lv, rv);
}
else
{
uint8_t rk = compileExprAuto(key, rsi);
uint8_t rv = compileExprAuto(value, rsi);
// Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax
if (AstExprConstantString* ckey = key->as<AstExprConstantString>())
{
BytecodeBuilder::StringRef cname = sref(ckey->value);
int32_t cid = bytecode.addConstantString(cname);
if (cid < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
bytecode.emitABC(LOP_SETTABLE, rv, reg, rk);
uint8_t rv = compileExprAuto(value, rsi);
bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname)));
bytecode.emitAux(cid);
}
else if (AstExprConstantNumber* ckey = key->as<AstExprConstantNumber>();
ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value)
{
uint8_t rv = compileExprAuto(value, rsi);
bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1));
}
else
{
uint8_t rk = compileExprAuto(key, rsi);
uint8_t rv = compileExprAuto(value, rsi);
bytecode.emitABC(LOP_SETTABLE, rv, reg, rk);
}
}
}
// items without a key are set using SETLIST so that we can initialize large arrays quickly
@ -1339,6 +1436,9 @@ struct Compiler
uint8_t rt = compileExprAuto(expr->expr, rs);
uint8_t i = uint8_t(int(cv->valueNumber) - 1);
if (FFlag::LuauCompileTableIndexOpt)
setDebugLine(expr->index);
bytecode.emitABC(LOP_GETTABLEN, target, rt, i);
}
else if (cv && cv->type == Constant::Type_String)
@ -1350,6 +1450,9 @@ struct Compiler
if (cid < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
if (FFlag::LuauCompileTableIndexOpt)
setDebugLine(expr->index);
bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname)));
bytecode.emitAux(cid);
}
@ -1657,6 +1760,40 @@ struct Compiler
Location location;
};
LValue compileLValueIndex(uint8_t reg, AstExpr* index, RegScope& rs)
{
const Constant* cv = constants.find(index);
if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 &&
double(int(cv->valueNumber)) == cv->valueNumber)
{
LValue result = {LValue::Kind_IndexNumber};
result.reg = reg;
result.number = uint8_t(int(cv->valueNumber) - 1);
result.location = index->location;
return result;
}
else if (cv && cv->type == Constant::Type_String)
{
LValue result = {LValue::Kind_IndexName};
result.reg = reg;
result.name = sref(cv->getString());
result.location = index->location;
return result;
}
else
{
LValue result = {LValue::Kind_IndexExpr};
result.reg = reg;
result.index = compileExprAuto(index, rs);
result.location = index->location;
return result;
}
}
LValue compileLValue(AstExpr* node, RegScope& rs)
{
setDebugLine(node);
@ -1699,36 +1836,9 @@ struct Compiler
}
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
{
const Constant* cv = constants.find(expr->index);
uint8_t reg = compileExprAuto(expr->expr, rs);
if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 &&
double(int(cv->valueNumber)) == cv->valueNumber)
{
LValue result = {LValue::Kind_IndexNumber};
result.reg = compileExprAuto(expr->expr, rs);
result.number = uint8_t(int(cv->valueNumber) - 1);
result.location = node->location;
return result;
}
else if (cv && cv->type == Constant::Type_String)
{
LValue result = {LValue::Kind_IndexName};
result.reg = compileExprAuto(expr->expr, rs);
result.name = sref(cv->getString());
result.location = node->location;
return result;
}
else
{
LValue result = {LValue::Kind_IndexExpr};
result.reg = compileExprAuto(expr->expr, rs);
result.index = compileExprAuto(expr->index, rs);
result.location = node->location;
return result;
}
return compileLValueIndex(reg, expr->index, rs);
}
else
{
@ -1740,6 +1850,9 @@ struct Compiler
void compileLValueUse(const LValue& lv, uint8_t reg, bool set)
{
if (FFlag::LuauCompileTableIndexOpt)
setDebugLine(lv.location);
switch (lv.kind)
{
case LValue::Kind_Local:

View File

@ -23,11 +23,11 @@ VM_SOURCES=$(wildcard VM/src/*.cpp)
VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o)
VM_TARGET=$(BUILD)/libluauvm.a
TESTS_SOURCES=$(wildcard tests/*.cpp)
TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp
TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o)
TESTS_TARGET=$(BUILD)/luau-tests
REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp
REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp
REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o)
REPL_CLI_TARGET=$(BUILD)/luau
@ -90,11 +90,12 @@ $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include
$(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -IAst/include
$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include
$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -IVM/include
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -Iextern
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICLI -Iextern
$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern
$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include -Iextern
$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include
$(TESTS_TARGET): LDFLAGS+=-lpthread
$(REPL_CLI_TARGET): LDFLAGS+=-lpthread
fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a

View File

@ -176,7 +176,8 @@ if(TARGET Luau.Repl.CLI)
CLI/FileUtils.cpp
CLI/Profiler.h
CLI/Profiler.cpp
CLI/Repl.cpp)
CLI/Repl.cpp
CLI/ReplEntry.cpp)
endif()
if(TARGET Luau.Analyze.CLI)
@ -243,6 +244,21 @@ if(TARGET Luau.Conformance)
tests/main.cpp)
endif()
if(TARGET Luau.CLI.Test)
# Luau.CLI.Test Sources
target_sources(Luau.CLI.Test PRIVATE
CLI/Coverage.h
CLI/Coverage.cpp
CLI/FileUtils.h
CLI/FileUtils.cpp
CLI/Profiler.h
CLI/Profiler.cpp
CLI/Repl.cpp
tests/Repl.test.cpp
tests/main.cpp)
endif()
if(TARGET Luau.Web)
# Luau.Web Sources
target_sources(Luau.Web PRIVATE

View File

@ -14,6 +14,8 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauGcForwardMetatableBarrier, false)
const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n"
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n";
@ -869,7 +871,16 @@ int lua_setmetatable(lua_State* L, int objindex)
luaG_runerror(L, "Attempt to modify a readonly table");
hvalue(obj)->metatable = mt;
if (mt)
luaC_objbarriert(L, hvalue(obj), mt);
{
if (FFlag::LuauGcForwardMetatableBarrier)
{
luaC_objbarrier(L, hvalue(obj), mt);
}
else
{
luaC_objbarriert(L, hvalue(obj), mt);
}
}
break;
}
case LUA_TUSERDATA:

View File

@ -1087,6 +1087,34 @@ static int luauF_countrz(lua_State* L, StkId res, TValue* arg0, int nresults, St
return -1;
}
static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams == 1 && nresults == 1)
{
int n = cast_int(L->base - L->ci->func) - clvalue(L->ci->func)->l.p->numparams - 1;
if (ttisnumber(arg0))
{
int i = int(nvalue(arg0));
// i >= 1 && i <= n
if (unsigned(i - 1) <= unsigned(n))
{
setobj2s(L, res, L->base - n + (i - 1));
return 1;
}
// note: for now we don't handle negative case (wrap around) and defer to fallback
}
else if (ttisstring(arg0) && *svalue(arg0) == '#')
{
setnvalue(res, double(n));
return 1;
}
}
return -1;
}
luau_FastFunction luauF_table[256] = {
NULL,
luauF_assert,
@ -1156,4 +1184,6 @@ luau_FastFunction luauF_table[256] = {
luauF_countlz,
luauF_countrz,
luauF_select,
};

View File

@ -5,8 +5,6 @@
#include "lstate.h"
#include "lvm.h"
LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false)
#define CO_RUN 0 /* running */
#define CO_SUS 1 /* suspended */
#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */
@ -235,9 +233,6 @@ static int coyieldable(lua_State* L)
static int coclose(lua_State* L)
{
if (!FFlag::LuauCoroutineClose)
luaL_error(L, "coroutine.close is not enabled");
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");

View File

@ -17,8 +17,6 @@
#include <string.h>
LUAU_FASTFLAG(LuauCoroutineClose)
/*
** {======================================================
** Error-recovery functions
@ -300,7 +298,7 @@ static void resume(lua_State* L, void* ud)
{
// start coroutine
LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base);
if (FFlag::LuauCoroutineClose && firstArg == L->base)
if (firstArg == L->base)
luaG_runerror(L, "cannot resume dead coroutine");
if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA)

View File

@ -93,10 +93,8 @@ static void finishGcCycleStats(global_State* g)
g->gcstats.lastcycle = g->gcstats.currcycle;
g->gcstats.currcycle = GCCycleStats();
g->gcstats.cyclestatsacc.markitems += g->gcstats.lastcycle.markitems;
g->gcstats.cyclestatsacc.marktime += g->gcstats.lastcycle.marktime;
g->gcstats.cyclestatsacc.atomictime += g->gcstats.lastcycle.atomictime;
g->gcstats.cyclestatsacc.sweepitems += g->gcstats.lastcycle.sweepitems;
g->gcstats.cyclestatsacc.sweeptime += g->gcstats.lastcycle.sweeptime;
}
@ -492,23 +490,22 @@ static void freeobj(lua_State* L, GCObject* o, lua_Page* page)
}
}
#define sweepwholelist(L, p, tc) sweeplist(L, p, SIZE_MAX, tc)
#define sweepwholelist(L, p) sweeplist(L, p, SIZE_MAX)
static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* traversedcount)
static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count)
{
LUAU_ASSERT(!FFlag::LuauGcPagedSweep);
GCObject* curr;
global_State* g = L->global;
int deadmask = otherwhite(g);
size_t startcount = count;
LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); /* make sure we never sweep fixed objects */
while ((curr = *p) != NULL && count-- > 0)
{
int alive = (curr->gch.marked ^ WHITEBITS) & deadmask;
if (curr->gch.tt == LUA_TTHREAD)
{
sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */
sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval); /* sweep open upvalues */
lua_State* th = gco2th(curr);
@ -534,10 +531,6 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr
}
}
// if we didn't reach the end of the list it means that we've stopped because the count dropped below zero
if (traversedcount)
*traversedcount += startcount - (curr ? count + 1 : count);
return p;
}
@ -721,8 +714,6 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco)
int alive = (gco->gch.marked ^ WHITEBITS) & deadmask;
g->gcstats.currcycle.sweepitems++;
if (gco->gch.tt == LUA_TTHREAD)
{
lua_State* th = gco2th(gco);
@ -793,8 +784,6 @@ static size_t gcstep(lua_State* L, size_t limit)
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
@ -812,8 +801,6 @@ static size_t gcstep(lua_State* L, size_t limit)
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
@ -842,10 +829,8 @@ static size_t gcstep(lua_State* L, size_t limit)
while (g->sweepstrgc < g->strt.size && cost < limit)
{
size_t traversedcount = 0;
sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++], &traversedcount);
sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++]);
g->gcstats.currcycle.sweepitems += traversedcount;
cost += GC_SWEEPCOST;
}
@ -855,12 +840,10 @@ static size_t gcstep(lua_State* L, size_t limit)
// sweep string buffer list and preserve used string count
uint32_t nuse = L->global->strt.nuse;
size_t traversedcount = 0;
sweepwholelist(L, (GCObject**)&g->strbufgc, &traversedcount);
sweepwholelist(L, (GCObject**)&g->strbufgc);
L->global->strt.nuse = nuse;
g->gcstats.currcycle.sweepitems += traversedcount;
g->gcstate = GCSsweep; // end sweep-string phase
}
break;
@ -893,10 +876,8 @@ static size_t gcstep(lua_State* L, size_t limit)
{
while (*g->sweepgc && cost < limit)
{
size_t traversedcount = 0;
g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount);
g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX);
g->gcstats.currcycle.sweepitems += traversedcount;
cost += GC_SWEEPMAX * GC_SWEEPCOST;
}

View File

@ -113,6 +113,7 @@
luaC_barrierf(L, obj2gco(p), obj2gco(o)); \
}
// TODO: remove with FFlagLuauGcForwardMetatableBarrier
#define luaC_objbarriert(L, t, o) \
{ \
if (isblack(obj2gco(t)) && iswhite(obj2gco(o))) \

View File

@ -96,9 +96,6 @@ struct GCCycleStats
double sweeptime = 0.0;
size_t markitems = 0;
size_t sweepitems = 0;
size_t assistwork = 0;
size_t explicitwork = 0;

View File

@ -44,10 +44,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop")
TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg")
{
ScopedFastFlag sffs[] = {
{"LuauPersistDefinitionFileTypes", true},
};
loadDefinition(R"(
declare function Connect(fn: (string) -> ())
)");
@ -63,8 +59,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg")
TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn")
{
ScopedFastFlag sffs{"LuauStoreMatchingOverloadFnType", true};
loadDefinition(R"(
declare foo: ((string) -> number) & ((number) -> string)
)");

View File

@ -2626,7 +2626,6 @@ local a: A<(number, s@1>
TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type")
{
ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true);
ScopedFastFlag luauAutocompleteFirstArg("LuauAutocompleteFirstArg", true);
check(R"(
local function foo1() return 1 end
@ -2728,4 +2727,39 @@ end
CHECK(ac.entryMap.count("getx"));
}
TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons")
{
ScopedFastFlag sffs[] = {
{"LuauParseSingletonTypes", true},
{"LuauSingletonTypes", true},
{"LuauRefactorTypeVarQuestions", true},
};
check(R"(
--!strict
local foo: "hello" | "bye" = "hello"
foo:@1
)");
auto ac = autocomplete('1');
CHECK(ac.entryMap.count("format"));
}
TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2")
{
ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true);
ScopedFastFlag preferToCallFunctionsForIntersects("PreferToCallFunctionsForIntersects", true);
check(R"(
local bar: ((number) -> number) & (number, number) -> number)
local abc = b@1
)");
auto ac = autocomplete('1');
CHECK(ac.entryMap.count("bar"));
CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside);
}
TEST_SUITE_END();

View File

@ -603,6 +603,37 @@ RETURN R0 1
)");
}
TEST_CASE("TableLiteralsIndexConstant")
{
ScopedFastFlag sff("LuauCompileTableIndexOpt", true);
// validate that we use SETTTABLEKS for constant variable keys
CHECK_EQ("\n" + compileFunction0(R"(
local a, b = "key", "value"
return {[a] = 42, [b] = 0}
)"), R"(
NEWTABLE R0 2 0
LOADN R1 42
SETTABLEKS R1 R0 K0
LOADN R1 0
SETTABLEKS R1 R0 K1
RETURN R0 1
)");
// validate that we use SETTABLEN for constant variable keys *and* that we predict array size
CHECK_EQ("\n" + compileFunction0(R"(
local a, b = 1, 2
return {[a] = 42, [b] = 0}
)"), R"(
NEWTABLE R0 0 2
LOADN R1 42
SETTABLEN R1 R0 1
LOADN R1 0
SETTABLEN R1 R0 2
RETURN R0 1
)");
}
TEST_CASE("TableSizePredictionBasic")
{
CHECK_EQ("\n" + compileFunction0(R"(
@ -2450,6 +2481,37 @@ return
)");
}
TEST_CASE("DebugLineInfoAssignment")
{
ScopedFastFlag sff("LuauCompileTableIndexOpt", true);
Luau::BytecodeBuilder bcb;
bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines);
Luau::compileOrThrow(bcb, R"(
local a = { b = { c = { d = 3 } } }
a
["b"]
["c"]
["d"] = 4
)");
CHECK_EQ("\n" + bcb.dumpFunction(0), R"(
2: DUPTABLE R0 1
2: DUPTABLE R1 3
2: DUPTABLE R2 5
2: LOADN R3 3
2: SETTABLEKS R3 R2 K4
2: SETTABLEKS R2 R1 K2
2: SETTABLEKS R1 R0 K0
5: GETTABLEKS R2 R0 K0
6: GETTABLEKS R1 R2 K2
7: LOADN R2 4
7: SETTABLEKS R2 R1 K4
8: RETURN R0 0
)");
}
TEST_CASE("DebugSource")
{
const char* source = R"(
@ -2763,6 +2825,75 @@ RETURN R1 -1
)");
}
TEST_CASE("FastcallSelect")
{
ScopedFastFlag sff("LuauCompileSelectBuiltin", true);
// select(_, ...) compiles to a builtin call
CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"(
LOADK R1 K0
FASTCALL1 57 R1 +3
GETIMPORT R0 2
GETVARARGS R2 -1
CALL R0 -1 1
RETURN R0 1
)");
// more complex example: select inside a for loop bound + select from a iterator
CHECK_EQ("\n" + compileFunction0(R"(
local sum = 0
for i=1, select('#', ...) do
sum += select(i, ...)
end
return sum
)"), R"(
LOADN R0 0
LOADN R3 1
LOADK R5 K0
FASTCALL1 57 R5 +3
GETIMPORT R4 2
GETVARARGS R6 -1
CALL R4 -1 1
MOVE R1 R4
LOADN R2 1
FORNPREP R1 +7
FASTCALL1 57 R3 +3
GETIMPORT R4 2
GETVARARGS R6 -1
CALL R4 -1 1
ADD R0 R0 R4
FORNLOOP R1 -7
RETURN R0 1
)");
// currently we assume a single value return to avoid dealing with stack resizing
CHECK_EQ("\n" + compileFunction0("return select('#', ...)"), R"(
GETIMPORT R0 1
LOADK R1 K2
GETVARARGS R2 -1
CALL R0 -1 -1
RETURN R0 -1
)");
// note that select with a non-variadic second argument doesn't get optimized
CHECK_EQ("\n" + compileFunction0("return select('#')"), R"(
GETIMPORT R0 1
LOADK R1 K2
CALL R0 1 -1
RETURN R0 -1
)");
// note that select with a non-variadic second argument doesn't get optimized
CHECK_EQ("\n" + compileFunction0("return select('#', foo())"), R"(
GETIMPORT R0 1
LOADK R1 K2
GETIMPORT R2 4
CALL R2 0 -1
CALL R0 -1 -1
RETURN R0 -1
)");
}
TEST_CASE("LotsOfParameters")
{
const char* source = R"(

View File

@ -331,8 +331,6 @@ TEST_CASE("UTF8")
TEST_CASE("Coroutine")
{
ScopedFastFlag sff("LuauCoroutineClose", true);
runConformance("coroutine.lua");
}

View File

@ -956,7 +956,6 @@ TEST_CASE("no_use_after_free_with_type_fun_instantiation")
{
// This flag forces this test to crash if there's a UAF in this code.
ScopedFastFlag sff_DebugLuauFreezeArena("DebugLuauFreezeArena", true);
ScopedFastFlag sff_LuauCloneCorrectlyBeforeMutatingTableType("LuauCloneCorrectlyBeforeMutatingTableType", true);
FrontendFixture fix;

View File

@ -2000,6 +2000,73 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors")
matchParseError("type Y<T... = (string) -> number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}});
}
TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression")
{
{
AstStat* stat = parse("return if true then 1 else 2");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
REQUIRE(str != nullptr);
CHECK(str->list.size == 1);
auto* ifElseExpr = str->list.data[0]->as<AstExprIfElse>();
REQUIRE(ifElseExpr != nullptr);
}
{
AstStat* stat = parse("return if true then 1 elseif true then 2 else 3");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
REQUIRE(str != nullptr);
CHECK(str->list.size == 1);
auto* ifElseExpr1 = str->list.data[0]->as<AstExprIfElse>();
REQUIRE(ifElseExpr1 != nullptr);
auto* ifElseExpr2 = ifElseExpr1->falseExpr->as<AstExprIfElse>();
REQUIRE(ifElseExpr2 != nullptr);
}
// Use "else if" as opposed to elseif
{
AstStat* stat = parse("return if true then 1 else if true then 2 else 3");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
REQUIRE(str != nullptr);
CHECK(str->list.size == 1);
auto* ifElseExpr1 = str->list.data[0]->as<AstExprIfElse>();
REQUIRE(ifElseExpr1 != nullptr);
auto* ifElseExpr2 = ifElseExpr1->falseExpr->as<AstExprIfElse>();
REQUIRE(ifElseExpr2 != nullptr);
}
// Use an if-else expression as the conditional expression of an if-else expression
{
AstStat* stat = parse("return if if true then false else true then 1 else 2");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
REQUIRE(str != nullptr);
CHECK(str->list.size == 1);
auto* ifElseExpr = str->list.data[0]->as<AstExprIfElse>();
REQUIRE(ifElseExpr != nullptr);
auto* nestedIfElseExpr = ifElseExpr->condition->as<AstExprIfElse>();
REQUIRE(nestedIfElseExpr != nullptr);
}
}
TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters")
{
AstStat* stat = parse(R"(
type Packed<T...> = () -> T...
type A<X...> = Packed<X...>
type B<X...> = Packed<...number>
type C<X...> = Packed<(number, X...)>
)");
REQUIRE(stat != nullptr);
}
TEST_SUITE_END();
TEST_SUITE_BEGIN("ParseErrorRecovery");
@ -2504,71 +2571,4 @@ type Y<T..., U = T...> = (T...) -> U...
CHECK_EQ(1, result.errors.size());
}
TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression")
{
{
AstStat* stat = parse("return if true then 1 else 2");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
REQUIRE(str != nullptr);
CHECK(str->list.size == 1);
auto* ifElseExpr = str->list.data[0]->as<AstExprIfElse>();
REQUIRE(ifElseExpr != nullptr);
}
{
AstStat* stat = parse("return if true then 1 elseif true then 2 else 3");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
REQUIRE(str != nullptr);
CHECK(str->list.size == 1);
auto* ifElseExpr1 = str->list.data[0]->as<AstExprIfElse>();
REQUIRE(ifElseExpr1 != nullptr);
auto* ifElseExpr2 = ifElseExpr1->falseExpr->as<AstExprIfElse>();
REQUIRE(ifElseExpr2 != nullptr);
}
// Use "else if" as opposed to elseif
{
AstStat* stat = parse("return if true then 1 else if true then 2 else 3");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
REQUIRE(str != nullptr);
CHECK(str->list.size == 1);
auto* ifElseExpr1 = str->list.data[0]->as<AstExprIfElse>();
REQUIRE(ifElseExpr1 != nullptr);
auto* ifElseExpr2 = ifElseExpr1->falseExpr->as<AstExprIfElse>();
REQUIRE(ifElseExpr2 != nullptr);
}
// Use an if-else expression as the conditional expression of an if-else expression
{
AstStat* stat = parse("return if if true then false else true then 1 else 2");
REQUIRE(stat != nullptr);
AstStatReturn* str = stat->as<AstStatBlock>()->body.data[0]->as<AstStatReturn>();
REQUIRE(str != nullptr);
CHECK(str->list.size == 1);
auto* ifElseExpr = str->list.data[0]->as<AstExprIfElse>();
REQUIRE(ifElseExpr != nullptr);
auto* nestedIfElseExpr = ifElseExpr->condition->as<AstExprIfElse>();
REQUIRE(nestedIfElseExpr != nullptr);
}
}
TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters")
{
AstStat* stat = parse(R"(
type Packed<T...> = () -> T...
type A<X...> = Packed<X...>
type B<X...> = Packed<...number>
type C<X...> = Packed<(number, X...)>
)");
REQUIRE(stat != nullptr);
}
TEST_SUITE_END();

117
tests/Repl.test.cpp Normal file
View File

@ -0,0 +1,117 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "lua.h"
#include "lualib.h"
#include "Repl.h"
#include "doctest.h"
#include <iostream>
#include <memory>
#include <string>
#include <vector>
class ReplFixture
{
public:
ReplFixture()
: luaState(luaL_newstate(), lua_close)
{
L = luaState.get();
setupState(L);
luaL_sandboxthread(L);
std::string result = runCode(L, prettyPrintSource);
}
// Returns all of the output captured from the pretty printer
std::string getCapturedOutput()
{
lua_getglobal(L, "capturedoutput");
const char* str = lua_tolstring(L, -1, nullptr);
std::string result(str);
lua_pop(L, 1);
return result;
}
lua_State* L;
private:
std::unique_ptr<lua_State, void (*)(lua_State*)> luaState;
// This is a simplicitic and incomplete pretty printer.
// It is included here to test that the pretty printer hook is being called.
// More elaborate tests to ensure correct output can be added if we introduce
// a more feature rich pretty printer.
std::string prettyPrintSource = R"(
-- Accumulate pretty printer output in `capturedoutput`
capturedoutput = ""
function arraytostring(arr)
local strings = {}
table.foreachi(arr, function(k,v) table.insert(strings, pptostring(v)) end )
return "{" .. table.concat(strings, ", ") .. "}"
end
function pptostring(x)
if type(x) == "table" then
-- Just assume array-like tables for now.
return arraytostring(x)
elseif type(x) == "string" then
return '"' .. x .. '"'
else
return tostring(x)
end
end
-- Note: Instead of calling print, the pretty printer just stores the output
-- in `capturedoutput` so we can check for the correct results.
function _PRETTYPRINT(...)
local args = table.pack(...)
local strings = {}
for i=1, args.n do
local item = args[i]
local str = pptostring(item, customoptions)
if i == 1 then
capturedoutput = capturedoutput .. str
else
capturedoutput = capturedoutput .. "\t" .. str
end
end
end
)";
};
TEST_SUITE_BEGIN("ReplPrettyPrint");
TEST_CASE_FIXTURE(ReplFixture, "AdditionStatement")
{
runCode(L, "return 30 + 12");
CHECK(getCapturedOutput() == "42");
}
TEST_CASE_FIXTURE(ReplFixture, "TableLiteral")
{
runCode(L, "return {1, 2, 3, 4}");
CHECK(getCapturedOutput() == "{1, 2, 3, 4}");
}
TEST_CASE_FIXTURE(ReplFixture, "StringLiteral")
{
runCode(L, "return 'str'");
CHECK(getCapturedOutput() == "\"str\"");
}
TEST_CASE_FIXTURE(ReplFixture, "TableWithStringLiterals")
{
runCode(L, "return {1, 'two', 3, 'four'}");
CHECK(getCapturedOutput() == "{1, \"two\", 3, \"four\"}");
}
TEST_CASE_FIXTURE(ReplFixture, "MultipleArguments")
{
runCode(L, "return 3, 'three'");
CHECK(getCapturedOutput() == "3\t\"three\"");
}
TEST_SUITE_END();

View File

@ -435,8 +435,6 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T
TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type F = ((() -> number)?) -> F?
local function f(p) return f end
@ -450,8 +448,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union"
TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_intersection")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
function f() return f end
local a: ((number) -> ()) & typeof(f)

View File

@ -11,8 +11,6 @@ TEST_SUITE_BEGIN("TypeAliases");
TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type F = () -> F?
local function f()
@ -194,8 +192,6 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic")
TEST_CASE_FIXTURE(Fixture, "corecursive_function_types")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
type A = () -> (number, B)
type B = () -> (string, A)

View File

@ -9,8 +9,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauExtendedFunctionMismatchError)
TEST_SUITE_BEGIN("GenericsTests");
TEST_CASE_FIXTURE(Fixture, "check_generic_function")
@ -656,11 +654,7 @@ local d: D = c
LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauExtendedFunctionMismatchError)
CHECK_EQ(
toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '<T>() -> ()'; different number of generic type parameters)");
else
CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '<T>() -> ()')");
CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '<T>() -> ()'; different number of generic type parameters)");
}
TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack")
@ -675,11 +669,8 @@ local d: D = c
LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauExtendedFunctionMismatchError)
CHECK_EQ(toString(result.errors[0]),
R"(Type '() -> ()' could not be converted into '<T...>() -> ()'; different number of generic type pack parameters)");
else
CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '<T...>() -> ()')");
CHECK_EQ(toString(result.errors[0]),
R"(Type '() -> ()' could not be converted into '<T...>() -> ()'; different number of generic type pack parameters)");
}
TEST_SUITE_END();

View File

@ -271,6 +271,32 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap")
CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b
}
// Also belongs in TypeInfer.refinements.test.cpp.
// Just needs to fully support equality refinement. Which is annoying without type states.
TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil")
{
ScopedFastFlag sff{"LuauDiscriminableUnions", true};
CheckResult result = check(R"(
type T = {x: string, y: number} | {x: nil, y: nil}
local function f(t: T)
if t.x ~= nil then
local foo = t
else
local bar = t
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("{| x: string, y: number |}", toString(requireTypeAtPosition({5, 28})));
// Should be {| x: nil, y: nil |}
CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28})));
}
TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5))
{
ScopedFastInt sffi{"LuauTarjanChildLimit", 1};
@ -590,8 +616,6 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table
TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param")
{
ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true};
// Mutability in type function application right now can create strange recursive types
// TODO: instantiation right now is problematic, in this example should either leave the Table type alone
// or it should rename the type to 'Self' so that the result will be 'Self<Table>'

View File

@ -6,11 +6,77 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauDiscriminableUnions)
LUAU_FASTFLAG(LuauWeakEqConstraint)
LUAU_FASTFLAG(LuauQuantifyInPlace2)
using namespace Luau;
namespace
{
std::optional<ExprResult<TypePackId>> magicFunctionInstanceIsA(
TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult)
{
if (expr.args.size != 1)
return std::nullopt;
auto index = expr.func->as<Luau::AstExprIndexName>();
auto str = expr.args.data[0]->as<Luau::AstExprConstantString>();
if (!index || !str)
return std::nullopt;
std::optional<LValue> lvalue = tryGetLValue(*index->expr);
std::optional<TypeFun> tfun = scope->lookupType(std::string(str->value.data, str->value.size));
if (!lvalue || !tfun)
return std::nullopt;
unfreeze(typeChecker.globalTypes);
TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType});
freeze(typeChecker.globalTypes);
return ExprResult<TypePackId>{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}};
}
struct RefinementClassFixture : Fixture
{
RefinementClassFixture()
{
TypeArena& arena = typeChecker.globalTypes;
unfreeze(arena);
TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr});
getMutable<ClassTypeVar>(vec3)->props = {
{"X", Property{typeChecker.numberType}},
{"Y", Property{typeChecker.numberType}},
{"Z", Property{typeChecker.numberType}},
};
TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr});
TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType});
TypePackId isARets = arena.addTypePack({typeChecker.booleanType});
TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets});
getMutable<FunctionTypeVar>(isA)->magicFunction = magicFunctionInstanceIsA;
getMutable<ClassTypeVar>(inst)->props = {
{"Name", Property{typeChecker.stringType}},
{"IsA", Property{isA}},
};
TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr});
TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr});
getMutable<ClassTypeVar>(part)->props = {
{"Position", Property{vec3}},
};
typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3};
typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst};
typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder};
typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part};
freeze(typeChecker.globalTypes);
}
};
} // namespace
TEST_SUITE_BEGIN("RefinementTest");
TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint")
@ -196,8 +262,18 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope")
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0]));
if (FFlag::LuauDiscriminableUnions)
{
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44})));
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38})));
}
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0]));
}
}
TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard")
@ -237,7 +313,6 @@ TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error")
TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties")
{
CheckResult result = check(R"(
local t: {x: number?} = {x = 1}
@ -254,7 +329,6 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties")
TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property")
{
CheckResult result = check(R"(
local t: {x: {y: string}?} = {x = {y = "hello!"}}
@ -360,7 +434,10 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term")
TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue")
{
ScopedFastFlag sff1{"LuauEqConstraint", true};
ScopedFastFlag sff[] = {
{"LuauDiscriminableUnions", true},
{"LuauSingletonTypes", true},
};
CheckResult result = check(R"(
local function f(a: (string | number)?)
@ -374,16 +451,8 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::LuauWeakEqConstraint)
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == "hello"
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello"
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "string"); // a == "hello"
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello"
}
CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello"
CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello"
}
TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil")
@ -416,7 +485,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil")
TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue")
{
ScopedFastFlag sff1{"LuauEqConstraint", true};
ScopedFastFlag sff{"LuauDiscriminableUnions", true};
ScopedFastFlag sff2{"LuauWeakEqConstraint", true};
CheckResult result = check(R"(
local function f(a, b: string?)
@ -428,16 +498,8 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue")
LUAU_REQUIRE_NO_ERRORS(result);
if (FFlag::LuauWeakEqConstraint)
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b
}
else
{
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string?"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b
}
CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b
CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b
}
TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal")
@ -527,9 +589,17 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector")
end
)");
// This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type.
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0]));
if (FFlag::LuauDiscriminableUnions)
{
LUAU_REQUIRE_NO_ERRORS(result);
}
else
{
// This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type.
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0]));
}
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28})));
}
@ -614,214 +684,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions")
CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function"
}
namespace
{
std::optional<ExprResult<TypePackId>> magicFunctionInstanceIsA(
TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult<TypePackId> exprResult)
{
if (expr.args.size != 1)
return std::nullopt;
auto index = expr.func->as<Luau::AstExprIndexName>();
auto str = expr.args.data[0]->as<Luau::AstExprConstantString>();
if (!index || !str)
return std::nullopt;
std::optional<LValue> lvalue = tryGetLValue(*index->expr);
std::optional<TypeFun> tfun = scope->lookupType(std::string(str->value.data, str->value.size));
if (!lvalue || !tfun)
return std::nullopt;
unfreeze(typeChecker.globalTypes);
TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType});
freeze(typeChecker.globalTypes);
return ExprResult<TypePackId>{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}};
}
struct RefinementClassFixture : Fixture
{
RefinementClassFixture()
{
TypeArena& arena = typeChecker.globalTypes;
unfreeze(arena);
TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr});
getMutable<ClassTypeVar>(vec3)->props = {
{"X", Property{typeChecker.numberType}},
{"Y", Property{typeChecker.numberType}},
{"Z", Property{typeChecker.numberType}},
};
TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr});
TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType});
TypePackId isARets = arena.addTypePack({typeChecker.booleanType});
TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets});
getMutable<FunctionTypeVar>(isA)->magicFunction = magicFunctionInstanceIsA;
getMutable<ClassTypeVar>(inst)->props = {
{"Name", Property{typeChecker.stringType}},
{"IsA", Property{isA}},
};
TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr});
TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr});
getMutable<ClassTypeVar>(part)->props = {
{"Position", Property{vec3}},
};
typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3};
typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst};
typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder};
typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part};
freeze(typeChecker.globalTypes);
}
};
} // namespace
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
{
CheckResult result = check(R"(
local function f(vec)
local X, Y, Z = vec.X, vec.Y, vec.Z
if type(vec) == "vector" then
local foo = vec
elseif typeof(vec) == "Instance" then
local foo = vec
else
local foo = vec
end
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector"
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0]));
else
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance"
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
else
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
}
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector")
{
CheckResult result = check(R"(
local function f(x: Instance | Vector3)
if typeof(x) == "Vector3" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata")
{
CheckResult result = check(R"(
local function f(x: string | number | Instance | Vector3)
if type(x) == "userdata" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance")
{
CheckResult result = check(R"(
local function f(x: Part | Folder | string)
if typeof(x) == "Instance" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("string", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union")
{
CheckResult result = check(R"(
local function f(x: Part | Folder | Instance | string | Vector3 | any)
if typeof(x) == "Instance" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table")
{
CheckResult result = check(R"(
--!nonstrict
local function f(x)
if typeof(x) == "Instance" and x:IsA("Folder") then
local foo = x
elseif typeof(x) == "table" then
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28})));
CHECK_EQ("any", toString(requireTypeAtPosition({7, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part")
{
CheckResult result = check(R"(
local function f(x: Part | Folder | string)
if typeof(x) ~= "Instance" or not x:IsA("Part") then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables")
{
CheckResult result = check(R"(
@ -1145,4 +1007,259 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x")
{
ScopedFastFlag sff[] = {
{"LuauDiscriminableUnions", true},
{"LuauParseSingletonTypes", true},
{"LuauSingletonTypes", true},
};
CheckResult result = check(R"(
type T = {tag: "missing", x: nil} | {tag: "exists", x: string}
local function f(t: T)
if t.x then
local foo = t
else
local bar = t
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28})));
CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28})));
}
TEST_CASE_FIXTURE(Fixture, "discriminate_tag")
{
ScopedFastFlag sff[] = {
{"LuauDiscriminableUnions", true},
{"LuauParseSingletonTypes", true},
{"LuauSingletonTypes", true},
};
CheckResult result = check(R"(
type Cat = {tag: "Cat", name: string, catfood: string}
type Dog = {tag: "Dog", name: string, dogfood: string}
type Animal = Cat | Dog
local function f(animal: Animal)
if animal.tag == "Cat" then
local cat: Cat = animal
elseif animal.tag == "Dog" then
local dog: Dog = animal
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33})));
CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33})));
}
TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string")
{
ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true};
CheckResult result = check(R"(
type T = { [string]: { prop: number }? }
local t: T = {}
if t["hello"] then
local foo = t["hello"].prop
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement")
{
CheckResult result = check(R"(
local function len(a: {any})
return a and #a or nil
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x")
{
ScopedFastFlag sff[] = {
{"LuauDiscriminableUnions", true},
{"LuauParseSingletonTypes", true},
{"LuauSingletonTypes", true},
};
CheckResult result = check(R"(
type T = {tag: "Part", x: Part} | {tag: "Folder", x: Folder}
local function f(t: T)
if t.x:IsA("Part") then
local foo = t
else
local bar = t
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28})));
CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
{
CheckResult result = check(R"(
local function f(vec)
local X, Y, Z = vec.X, vec.Y, vec.Z
if type(vec) == "vector" then
local foo = vec
elseif typeof(vec) == "Instance" then
local foo = vec
else
local foo = vec
end
end
)");
if (FFlag::LuauDiscriminableUnions)
LUAU_REQUIRE_NO_ERRORS(result);
else
{
LUAU_REQUIRE_ERROR_COUNT(1, result);
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0]));
else
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
}
CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector"
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance"
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
else
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
}
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector")
{
CheckResult result = check(R"(
local function f(x: Instance | Vector3)
if typeof(x) == "Vector3" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata")
{
CheckResult result = check(R"(
local function f(x: string | number | Instance | Vector3)
if type(x) == "userdata" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance")
{
CheckResult result = check(R"(
local function f(x: Part | Folder | string)
if typeof(x) == "Instance" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("string", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union")
{
CheckResult result = check(R"(
local function f(x: Part | Folder | Instance | string | Vector3 | any)
if typeof(x) == "Instance" then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table")
{
CheckResult result = check(R"(
--!nonstrict
local function f(x)
if typeof(x) == "Instance" and x:IsA("Folder") then
local foo = x
elseif typeof(x) == "table" then
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28})));
CHECK_EQ("any", toString(requireTypeAtPosition({7, 28})));
}
TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part")
{
CheckResult result = check(R"(
local function f(x: Part | Folder | string)
if typeof(x) ~= "Instance" or not x:IsA("Part") then
local foo = x
else
local foo = x
end
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28})));
CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28})));
}
TEST_SUITE_END();

View File

@ -379,9 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string")
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauUnionHeuristic", true},
{"LuauExpectedTypesOfProperties", true},
{"LuauExtendedUnionMismatchError", true},
};
CheckResult result = check(R"(
@ -404,9 +402,7 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool")
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauUnionHeuristic", true},
{"LuauExpectedTypesOfProperties", true},
{"LuauExtendedUnionMismatchError", true},
};
CheckResult result = check(R"(
@ -429,9 +425,7 @@ TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options")
ScopedFastFlag sffs[] = {
{"LuauSingletonTypes", true},
{"LuauParseSingletonTypes", true},
{"LuauUnionHeuristic", true},
{"LuauExpectedTypesOfProperties", true},
{"LuauExtendedUnionMismatchError", true},
{"LuauIfElseExpectedType2", true},
{"LuauIfElseBranchTypeUnion", true},
};

View File

@ -12,8 +12,6 @@
using namespace Luau;
LUAU_FASTFLAG(LuauExtendedFunctionMismatchError)
TEST_SUITE_BEGIN("TableTests");
TEST_CASE_FIXTURE(Fixture, "basic")
@ -2075,22 +2073,11 @@ caused by:
caused by:
Property 'y' is not compatible. Type 'string' could not be converted into 'number')");
if (FFlag::LuauExtendedFunctionMismatchError)
{
CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2'
CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2'
caused by:
Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: <a>(a) -> () }'
caused by:
Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '<a>(a) -> ()'; different number of generic type parameters)");
}
else
{
CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2'
caused by:
Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: <a>(a) -> () }'
caused by:
Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '<a>(a) -> ()')");
}
}
TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table")
@ -2166,7 +2153,6 @@ a.p = { x = 9 }
TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call")
{
ScopedFastFlag sff[]{
{"LuauFixRecursiveMetatableCall", true},
{"LuauUnsealedTableLiteral", true},
};

View File

@ -16,7 +16,6 @@
LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr)
LUAU_FASTFLAG(LuauEqConstraint)
LUAU_FASTFLAG(LuauExtendedFunctionMismatchError)
using namespace Luau;
@ -959,8 +958,6 @@ TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function")
TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
function f()
return f
@ -973,8 +970,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets")
TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
function f(g)
return f(f)
@ -1699,8 +1694,6 @@ TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional")
TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
--!strict
local s
@ -1711,8 +1704,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check")
TEST_CASE_FIXTURE(Fixture, "occurs_check_does_not_recurse_forever_if_asked_to_traverse_a_cyclic_type")
{
ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true};
CheckResult result = check(R"(
--!strict
function u(t, w)
@ -3326,11 +3317,12 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison")
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable")
TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs")
{
CheckResult result = check(R"(
local x
print((x == true and (x .. "y")) .. 1)
local function f(x)
return x .. "y"
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
@ -3340,13 +3332,14 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable
TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs")
{
CheckResult result = check(R"(
local x
print("foo" .. x)
local function f(x)
return "foo" .. x
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("string", toString(requireType("x")));
CHECK_EQ("(string) -> string", toString(requireType("f")));
}
TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown")
@ -4374,8 +4367,6 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok")
TEST_CASE_FIXTURE(Fixture, "record_matching_overload")
{
ScopedFastFlag sffs("LuauStoreMatchingOverloadFnType", true);
CheckResult result = check(R"(
type Overload = ((string) -> string) & ((number) -> number)
local abc: Overload
@ -4475,17 +4466,10 @@ f(function(a, b, c, ...) return a + b end)
LUAU_REQUIRE_ERRORS(result);
if (FFlag::LuauExtendedFunctionMismatchError)
{
CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'
CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'
caused by:
Argument count mismatch. Function expects 3 arguments, but only 2 are specified)",
toString(result.errors[0]));
}
else
{
CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number')", toString(result.errors[0]));
}
toString(result.errors[0]));
// Infer from variadic packs into elements
result = check(R"(
@ -4618,17 +4602,9 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i
)");
LUAU_REQUIRE_ERRORS(result);
if (FFlag::LuauExtendedFunctionMismatchError)
{
CHECK_EQ(
"Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '<a>(a, a, (a, a) -> a) -> a'; different number of generic type "
"parameters",
toString(result.errors[0]));
}
else
{
CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '<a>(a, a, (a, a) -> a) -> a'", toString(result.errors[0]));
}
CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '<a>(a, a, (a, a) -> a) -> a'; different number of generic type "
"parameters",
toString(result.errors[0]));
}
TEST_CASE_FIXTURE(Fixture, "infer_return_value_type")
@ -4741,8 +4717,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch")
TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table")
{
ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true};
CheckResult result = check(R"(
type A = { x: number }
local a: A = { x = 1 }
@ -4965,8 +4939,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_
TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count")
{
ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true};
CheckResult result = check(R"(
type A = (number, number) -> string
type B = (number) -> string
@ -4983,8 +4955,6 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg")
{
ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true};
CheckResult result = check(R"(
type A = (number, number) -> string
type B = (number, string) -> string
@ -5001,8 +4971,6 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count")
{
ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true};
CheckResult result = check(R"(
type A = (number, number) -> (number)
type B = (number, number) -> (number, boolean)
@ -5019,8 +4987,6 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret")
{
ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true};
CheckResult result = check(R"(
type A = (number, number) -> string
type B = (number, number) -> number
@ -5037,8 +5003,6 @@ caused by:
TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult")
{
ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true};
CheckResult result = check(R"(
type A = (number, number) -> (number, string)
type B = (number, number) -> (number, boolean)
@ -5069,8 +5033,6 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options")
TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free")
{
ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true};
CheckResult result = check(R"(
local t = {}

View File

@ -931,7 +931,6 @@ type R = { m: F<R> }
TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check")
{
ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true};
ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true};
CheckResult result = check(R"(
local a: () -> (number, ...string)

View File

@ -464,8 +464,6 @@ local a: XYZ = { w = 4 }
TEST_CASE_FIXTURE(Fixture, "error_detailed_optional")
{
ScopedFastFlag luauExtendedUnionMismatchError{"LuauExtendedUnionMismatchError", true};
CheckResult result = check(R"(
type X = { x: number }

View File

@ -268,8 +268,6 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure")
TEST_CASE("tagging_tables")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar ttv{TableTypeVar{}};
CHECK(!Luau::hasTag(&ttv, "foo"));
Luau::attachTag(&ttv, "foo");
@ -278,8 +276,6 @@ TEST_CASE("tagging_tables")
TEST_CASE("tagging_classes")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
CHECK(!Luau::hasTag(&base, "foo"));
Luau::attachTag(&base, "foo");
@ -288,8 +284,6 @@ TEST_CASE("tagging_classes")
TEST_CASE("tagging_subclasses")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}};
@ -307,8 +301,6 @@ TEST_CASE("tagging_subclasses")
TEST_CASE("tagging_functions")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypePackVar empty{TypePack{}};
TypeVar ftv{FunctionTypeVar{&empty, &empty}};
CHECK(!Luau::hasTag(&ftv, "foo"));
@ -318,8 +310,6 @@ TEST_CASE("tagging_functions")
TEST_CASE("tagging_props")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
Property prop{};
CHECK(!Luau::hasTag(prop, "foo"));
Luau::attachTag(prop, "foo");
@ -370,4 +360,66 @@ local b: (T, T, T) -> T
CHECK_EQ(count, 1);
}
TEST_CASE("isString_on_string_singletons")
{
ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true};
TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}};
CHECK(isString(&helloString));
}
TEST_CASE("isString_on_unions_of_various_string_singletons")
{
ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true};
TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}};
TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}};
TypeVar union_{UnionTypeVar{{&helloString, &byeString}}};
CHECK(isString(&union_));
}
TEST_CASE("proof_that_isString_uses_all_of")
{
ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true};
TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}};
TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}};
TypeVar booleanType{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}};
TypeVar union_{UnionTypeVar{{&helloString, &byeString, &booleanType}}};
CHECK(!isString(&union_));
}
TEST_CASE("isBoolean_on_boolean_singletons")
{
ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true};
TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}};
CHECK(isBoolean(&trueBool));
}
TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons")
{
ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true};
TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}};
TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}};
TypeVar union_{UnionTypeVar{{&trueBool, &falseBool}}};
CHECK(isBoolean(&union_));
}
TEST_CASE("proof_that_isBoolean_uses_all_of")
{
ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true};
TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}};
TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}};
TypeVar stringType{PrimitiveTypeVar{PrimitiveTypeVar::String}};
TypeVar union_{UnionTypeVar{{&trueBool, &falseBool, &stringType}}};
CHECK(!isBoolean(&union_));
}
TEST_SUITE_END();