diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 7b609b4..68ff3a7 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -12,10 +12,17 @@ namespace Luau struct FunctionDocumentation; struct TableDocumentation; struct OverloadedFunctionDocumentation; +struct BasicDocumentation; -using Documentation = Luau::Variant; +using Documentation = Luau::Variant; using DocumentationSymbol = std::string; +struct BasicDocumentation +{ + std::string documentation; + std::string learnMoreLink; +}; + struct FunctionParameterDocumentation { std::string name; @@ -29,6 +36,7 @@ struct FunctionDocumentation std::string documentation; std::vector parameters; std::vector returns; + std::string learnMoreLink; }; struct OverloadedFunctionDocumentation @@ -43,6 +51,7 @@ struct TableDocumentation { std::string documentation; Luau::DenseHashMap keys; + std::string learnMoreLink; }; using DocumentationDatabase = Luau::DenseHashMap; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index e5683fc..50379c1 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -23,10 +23,11 @@ struct ToStringNameMap struct ToStringOptions { - bool exhaustive = false; // If true, we produce complete output rather than comprehensible output - bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. - bool functionTypeArguments = false; // If true, output function type argument names when they are available - bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output + bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. + bool functionTypeArguments = false; // If true, output function type argument names when they are available + bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -64,6 +65,8 @@ inline std::string toString(TypePackId ty) std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {}); + // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression void dump(TypeId ty); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 306ac77..78d642c 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -175,10 +175,10 @@ struct TypeChecker std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors); + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); - ExprResult reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); @@ -282,6 +282,14 @@ public: // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); + // Produce an "emergency backup type" for recovery from type errors. + // This comes in two flavours, depening on whether or not we can make a good guess + // for an error recovery type. + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(const ScopePtr& scope); + TypePackId errorRecoveryTypePack(const ScopePtr& scope); + private: void prepareErrorsForDisplay(ErrorVec& errVec); void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); @@ -297,6 +305,10 @@ private: TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); + // Produce a new singleton type var. + TypeId singletonType(bool value); + TypeId singletonType(std::string value); + // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); @@ -330,8 +342,8 @@ private: const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -347,7 +359,6 @@ private: void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; @@ -387,12 +398,9 @@ public: const TypeId booleanType; const TypeId threadType; const TypeId anyType; - - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; private: int checkRecursionCount = 0; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 6bd7932..093ea43 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -108,6 +108,79 @@ struct PrimitiveTypeVar } }; +// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md +// Types for true and false +struct BoolSingleton +{ + bool value; + + bool operator==(const BoolSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const BoolSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// Types for "foo", "bar" etc. +struct StringSingleton +{ + std::string value; + + bool operator==(const StringSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const StringSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// No type for float singletons, partly because === isn't any equalivalence on floats +// (NaN != NaN). + +using SingletonVariant = Luau::Variant; + +struct SingletonTypeVar +{ + explicit SingletonTypeVar(const SingletonVariant& variant) + : variant(variant) + { + } + + explicit SingletonTypeVar(SingletonVariant&& variant) + : variant(std::move(variant)) + { + } + + // Default operator== is C++20. + bool operator==(const SingletonTypeVar& rhs) const + { + return variant == rhs.variant; + } + + bool operator!=(const SingletonTypeVar& rhs) const + { + return !(*this == rhs); + } + + SingletonVariant variant; +}; + +template +const T* get(const SingletonTypeVar* stv) +{ + if (stv) + return get_if(&stv->variant); + else + return nullptr; +} + struct FunctionArgument { Name name; @@ -332,8 +405,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -410,6 +483,9 @@ bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders bool maybeGeneric(const TypeId ty); +// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton +bool maybeSingleton(TypeId ty); + struct SingletonTypes { const TypeId nilType; @@ -418,16 +494,19 @@ struct SingletonTypes const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(); + TypePackId errorRecoveryTypePack(); + private: std::unique_ptr arena; TypeId makeStringMetatable(); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index c2e07e4..b47610f 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -105,6 +105,8 @@ private: struct Error { + // This constructor has to be public, since it's used in TypeVar and TypePack, + // but shouldn't be called directly. Please use errorRecoveryType() instead. Error(); int index; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index be0aadd..503034a 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -65,6 +65,7 @@ struct Unifier private: void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnifyPrimitives(TypeId superTy, TypeId subTy); + void tryUnifySingletons(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1c94bb6..6fc0b3f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -198,11 +199,24 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); - unifier.tryUnify(expectedType, actualType); + if (FFlag::LuauAutocompleteAvoidMutation) + { + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr); + actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr); - bool ok = unifier.errors.empty(); - unifier.log.rollback(); - return ok; + auto errors = unifier.canUnify(expectedType, actualType); + return errors.empty(); + } + else + { + unifier.tryUnify(expectedType, actualType); + + bool ok = unifier.errors.empty(); + unifier.log.rollback(); + return ok; + } }; auto expr = node->asExpr(); @@ -1496,11 +1510,9 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!sourceModule) return {}; - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = - (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1527,8 +1539,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 96703ef..9f5c825 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -153,6 +153,7 @@ declare function gcinfo(): number wrap: ((A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, + close: (thread) -> (boolean, any?) } declare table: { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 46ff2c7..f80d50a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -180,13 +180,13 @@ struct ErrorConverter switch (e.context) { case CountMismatch::Return: - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + - std::to_string(e.actual) + " " + actualVerb + " returned here"; + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " returned here"; case CountMismatch::Result: // It is alright if right hand side produces more values than the // left hand side accepts. In this context consider only the opposite case. - return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + - std::to_string(e.actual) + " are required here"; + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + + " are required here"; case CountMismatch::Arg: if (FFlag::LuauTypeAliasPacks) return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 2f41127..1e97705 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -22,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace2) -LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -458,8 +457,7 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); - if (FFlag::LuauClearScopes) - module->scopes.resize(1); + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 880ffd2..32a0646 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -161,6 +161,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const PrimitiveTypeVar& t); + void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); void operator()(const MetatableTypeVar& t); @@ -199,7 +200,9 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; } void operator()(const Unifiable::Generic& t) @@ -251,8 +254,9 @@ void TypeCloner::operator()(const Unifiable::Free& t) { if (encounteredFreeType) *encounteredFreeType = true; - - seenTypes[typeId] = dest.addType(ErrorTypeVar{}); + TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -270,11 +274,17 @@ void TypeCloner::operator()(const Unifiable::Error& t) { defaultClone(t); } + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); } +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + void TypeCloner::operator()(const FunctionTypeVar& t) { TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 885fd48..735bfa5 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -350,6 +350,23 @@ struct TypeVarStringifier } } + void operator()(TypeId, const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = Luau::get(&stv)) + state.emit(bs->value ? "true" : "false"); + else if (const StringSingleton* ss = Luau::get(&stv)) + { + state.emit("\""); + state.emit(escape(ss->value)); + state.emit("\""); + } + else + { + LUAU_ASSERT(!"Unknown singleton type"); + throw std::runtime_error("Unknown singleton type"); + } + } + void operator()(TypeId, const FunctionTypeVar& ftv) { if (state.hasSeen(&ftv)) @@ -359,6 +376,7 @@ struct TypeVarStringifier return; } + // We should not be respecting opts.hideNamedFunctionTypeParameters here. if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0) { state.emit("<"); @@ -514,7 +532,14 @@ struct TypeVarStringifier break; } - state.emit(name); + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } state.emit(": "); stringify(prop.type); comma = true; @@ -1084,6 +1109,94 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + std::string s = prefix; + + auto toString_ = [&opts](TypeId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + auto toStringPack_ = [&opts](TypePackId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty())) + { + s += "<"; + + bool first = true; + for (TypeId g : ftv.generics) + { + if (!first) + s += ", "; + first = false; + s += toString_(g); + } + + for (TypePackId gp : ftv.genericPacks) + { + if (!first) + s += ", "; + first = false; + s += toStringPack_(gp); + } + + s += ">"; + } + + s += "("; + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + s += ", "; + first = false; + + // argNames is guaranteed to be equal to argTypes iff argNames is not empty. + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (!ftv.argNames.empty()) + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + s += toString_(*argPackIter); + + ++argPackIter; + if (!ftv.argNames.empty()) + { + LUAU_ASSERT(argNameIter != ftv.argNames.end()); + ++argNameIter; + } + } + + if (argPackIter.tail()) + { + if (auto vtp = get(*argPackIter.tail())) + s += ", ...: " + toString_(vtp->ty); + else + s += ", ...: " + toStringPack_(*argPackIter.tail()); + } + + s += "): "; + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + if (retSize == 0 && !hasTail) + s += "()"; + else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail)) + s += toStringPack_(ftv.retType); + else + s += "(" + toStringPack_(ftv.retType) + ")"; + + return s; +} + void dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 7d880af..6627fbe 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -14,61 +14,6 @@ LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { - -std::string escape(std::string_view s) -{ - std::string r; - r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting - - for (uint8_t c : s) - { - if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') - r += c; - else - { - r += '\\'; - - switch (c) - { - case '\a': - r += 'a'; - break; - case '\b': - r += 'b'; - break; - case '\f': - r += 'f'; - break; - case '\n': - r += 'n'; - break; - case '\r': - r += 'r'; - break; - case '\t': - r += 't'; - break; - case '\v': - r += 'v'; - break; - case '\'': - r += '\''; - break; - case '\"': - r += '\"'; - break; - case '\\': - r += '\\'; - break; - default: - Luau::formatAppend(r, "%03u", c); - } - } - } - - return r; -} - bool isIdentifierStartChar(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 11aa7b3..af6d254 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -96,6 +96,22 @@ public: return nullptr; } } + + AstType* operator()(const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = get(&stv)) + return allocator->alloc(Location(), bs->value); + else if (const StringSingleton* ss = get(&stv)) + { + AstArray value; + value.data = const_cast(ss->value.c_str()); + value.size = strlen(value.data); + return allocator->alloc(Location(), value); + } + else + return nullptr; + } + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 96b7992..b2ae94c 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,9 @@ LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) +LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) namespace Luau { @@ -211,10 +214,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(singletonTypes.booleanType) , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) - , errorType(singletonTypes.errorType) , optionalNumberType(singletonTypes.optionalNumberType) , anyTypePack(singletonTypes.anyTypePack) - , errorTypePack(singletonTypes.errorTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -484,7 +485,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = ErrorTypeVar{}; + *asMutable(type) = *errorRecoveryType(anyType); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -719,7 +720,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) else if (auto tail = valueIter.tail()) { if (get(*tail)) - right = errorType; + right = errorRecoveryType(scope); else if (auto vtp = get(*tail)) right = vtp->ty; else if (get(*tail)) @@ -961,7 +962,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorType, forin.location); + unify(var, errorRecoveryType(scope), forin.location); return check(loopScope, *forin.body); } @@ -979,7 +980,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) const FunctionTypeVar* iterFunc = get(iterTy); if (!iterFunc) { - TypeId varTy = get(iterTy) ? anyType : errorType; + TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) unify(var, varTy, forin.location); @@ -1152,9 +1153,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); if (FFlag::LuauTypeAliasPacks) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; else - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)}; } else { @@ -1398,7 +1399,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult result; @@ -1407,12 +1408,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = checkExpr(scope, *a->expr); else if (expr.is()) result = {nilType}; - else if (expr.is()) - result = {booleanType}; + else if (const AstExprConstantBool* bexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(bexpr->value)}; + else + result = {booleanType}; + } + else if (const AstExprConstantString* sexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + else + result = {stringType}; + } else if (expr.is()) result = {numberType}; - else if (expr.is()) - result = {stringType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1485,7 +1496,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) @@ -1497,7 +1508,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) @@ -1517,14 +1528,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa return {head}; } if (get(varargPack)) - return {errorType}; + return {errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return {vtp->ty}; else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorType}; + return {errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); @@ -1547,7 +1558,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa return {head, std::move(result.predicates)}; } if (get(retPack)) - return {errorType, std::move(result.predicates)}; + return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) @@ -1572,7 +1583,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - return {errorType}; + return {errorRecoveryType(scope)}; } std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) @@ -1876,6 +1887,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; + const UnionTypeVar* expectedUnion = nullptr; std::optional expectedIndexType; std::optional expectedIndexResultType; @@ -1894,6 +1906,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa } } } + else if (FFlag::LuauExpectedTypesOfProperties) + if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -1916,6 +1931,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; } + else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion) + { + std::vector expectedResultTypes; + for (TypeId expectedOption : expectedUnion) + if (const TableTypeVar* ttv = get(follow(expectedOption))) + if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) + expectedResultTypes.push_back(prop->second.type); + if (expectedResultTypes.size() == 1) + expectedResultType = expectedResultTypes[0]; + else if (expectedResultTypes.size() > 1) + expectedResultType = addType(UnionTypeVar{expectedResultTypes}); + } } else { @@ -1958,21 +1985,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) - return {errorType}; + retType = errorRecoveryType(retType); - return {first(retType).value_or(nilType)}; + return {retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorType}; + return {errorRecoveryType(scope)}; } reportErrors(tryUnify(numberType, operandType, expr.location)); @@ -1984,7 +2012,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) - return {errorType}; + return {errorRecoveryType(scope)}; if (get(operandType)) return {numberType}; // Not strictly correct: metatables permit overriding this @@ -2044,7 +2072,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b if (unify(a, b, location)) return a; - return errorType; + return errorRecoveryType(anyType); } if (*a == *b) @@ -2166,11 +2194,13 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); + // TODO: this check seems odd, the second part is redundant + // is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable) if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } if (leftMetatable) @@ -2188,7 +2218,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } } @@ -2206,7 +2236,7 @@ TypeId TypeChecker::checkRelationalOperation( { reportError( expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } @@ -2214,14 +2244,14 @@ TypeId TypeChecker::checkRelationalOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); - return errorType; + return errorRecoveryType(booleanType); } if (needsMetamethod) { reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } return booleanType; @@ -2266,7 +2296,8 @@ TypeId TypeChecker::checkBinaryOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - return errorType; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -2296,18 +2327,33 @@ TypeId TypeChecker::checkBinaryOperation( auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); + bool hasErrors = !state.errors.empty(); - if (!state.errors.empty()) - return errorType; + if (FFlag::LuauErrorRecoveryType && hasErrors) + { + // If there are unification errors, the return type may still be unknown + // so we loosen the argument types to see if that helps. + TypePackId fallbackArguments = freshTypePack(scope); + TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); + state.log.rollback(); + state.errors.clear(); + state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true); + if (!state.errors.empty()) + state.log.rollback(); + } - return first(retType).value_or(nilType); + TypeId retType = first(retTypePack).value_or(nilType); + if (hasErrors) + retType = errorRecoveryType(retType); + + return retType; }; std::string op = opToMetaTableEntry(expr.op); @@ -2321,7 +2367,8 @@ TypeId TypeChecker::checkBinaryOperation( reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), toString(lhsType).c_str(), toString(rhsType).c_str())}); - return errorType; + + return errorRecoveryType(scope); } switch (expr.op) @@ -2414,11 +2461,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy ExprResult result = checkExpr(scope, *expr.expr, annotationType); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); if (!errorVec.empty()) - { - reportErrors(errorVec); - return {errorType, std::move(result.predicates)}; - } + annotationType = errorRecoveryType(annotationType); return {annotationType, std::move(result.predicates)}; } @@ -2434,7 +2479,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) @@ -2476,7 +2521,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return std::pair(errorType, nullptr); + return {errorRecoveryType(scope), nullptr}; } else ice("Unexpected AST node in checkLValue", expr.location); @@ -2488,7 +2533,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope return {*ty, nullptr}; reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorType, nullptr}; + return {errorRecoveryType(scope), nullptr}; } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) @@ -2545,24 +2590,25 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { Unifier state = mkUnifier(expr.location); state.tryUnify(indexer->indexType, stringType); + TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { state.log.rollback(); reportError(expr.location, UnknownProperty{lhs, name}); - return std::pair(errorType, nullptr); + retType = errorRecoveryType(retType); } - return std::pair(indexer->indexResultType, nullptr); + return std::pair(retType, nullptr); } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2571,7 +2617,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); @@ -2585,12 +2631,12 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) @@ -2615,7 +2661,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); } @@ -2626,7 +2672,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } if (value) @@ -2678,7 +2724,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) if (isNonstrictMode()) return globalScope->bindings[name].typeId; - return errorType; + return errorRecoveryType(scope); } else { @@ -2705,20 +2751,21 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) TableTypeVar* ttv = getMutableTableType(lhsType); if (!ttv) { - if (!isTableIntersection(lhsType)) + if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) + // This error now gets reported when we check the function body. reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - return errorType; + return errorRecoveryType(scope); } // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorType; + return errorRecoveryType(scope); Name name = indexName->index.value; if (ttv->props.count(name)) - return errorType; + return errorRecoveryType(scope); Property& property = ttv->props[name]; @@ -2728,9 +2775,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) return property.type; } else if (funName.is()) - { - return errorType; - } + return errorRecoveryType(scope); else { ice("Unexpected AST node type", funName.location); @@ -2991,7 +3036,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A else if (expr.is()) { if (!scope->varargPack) - return {addTypePack({addType(ErrorTypeVar())})}; + return {errorRecoveryTypePack(scope)}; return {*scope->varargPack}; } @@ -3095,10 +3140,9 @@ void TypeChecker::checkArgumentList( if (get(tail)) { // Unify remaining parameters so we don't leave any free-types hanging around. - TypeId argTy = errorType; while (paramIter != endIter) { - state.tryUnify(*paramIter, argTy); + state.tryUnify(*paramIter, errorRecoveryType(anyType)); ++paramIter; } return; @@ -3157,7 +3201,7 @@ void TypeChecker::checkArgumentList( { while (argIter != endIter) { - unify(*argIter, errorType, state.location); + unify(*argIter, errorRecoveryType(scope), state.location); ++argIter; } // For this case, we want the error span to cover every errant extra parameter @@ -3246,7 +3290,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A // For each overload // Compare parameter and argument types // Report any errors (also speculate dot vs colon warnings!) - // If there are no errors, return the resulting return type + // Return the resulting return type (even if there are errors) + // If there are no matching overloads, unify with (a...) -> (b...) and return b... TypeId selfType = nullptr; TypeId functionType = nullptr; @@ -3268,8 +3313,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - functionType = errorType; - actualFunctionType = errorType; + functionType = errorRecoveryType(scope); + actualFunctionType = functionType; } } else @@ -3296,7 +3341,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A TypePackId argPack = argListResult.type; if (get(argPack)) - return ExprResult{errorTypePack}; + return {errorRecoveryTypePack(scope)}; TypePack* args = getMutable(argPack); LUAU_ASSERT(args != nullptr); @@ -3314,19 +3359,34 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector errors; // errors encountered for each overload std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; for (TypeId fn : overloads) { fn = follow(fn); - if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors)) + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) return {retPack}; - return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + + if (FFlag::LuauErrorRecoveryType) + { + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retType)}; + } + + return {errorRecoveryTypePack(retPack)}; } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) @@ -3382,7 +3442,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors) + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { fn = stripFromNilAndReport(fn, expr.func->location); @@ -3394,7 +3454,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (get(fn)) { - return {{addTypePack(TypePackVar{Unifiable::Error{}})}}; + return {{errorRecoveryTypePack(scope)}}; } if (get(fn)) @@ -3427,14 +3487,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload( - scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + overloadsThatMatchArgCount, overloadsThatDont, errors); } } reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorTypePack, expr.func->location); - return {{errorTypePack}}; + unify(retPack, errorRecoveryTypePack(scope), expr.func->location); + return {{errorRecoveryTypePack(retPack)}}; } // When this function type has magic functions and did return something, we select that overload instead. @@ -3476,6 +3536,8 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); + else if (FFlag::LuauErrorRecoveryType) + overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); state.log.rollback(); @@ -3586,14 +3648,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return false; } -ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, - TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, - const std::vector& overloadsThatMatchArgCount, const std::vector& errors) +void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, + const std::vector& errors) { if (overloads.size() == 1) { reportErrors(std::get<0>(errors.front())); - return {errorTypePack}; + return; } std::vector overloadTypes = overloadsThatMatchArgCount; @@ -3622,7 +3684,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) - return {errorTypePack}; + return; } std::string s; @@ -3655,7 +3717,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s}); // No viable overload - return {errorTypePack}; + return; } ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, @@ -3740,7 +3802,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); - return errorType; + return errorRecoveryType(anyType); } return anyType; @@ -3758,14 +3820,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module reportError(TypeError{location, UnknownRequire{reportedModulePath}}); } - return errorType; + return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } std::optional moduleType = first(module->getModuleScope()->returnType); @@ -3773,7 +3835,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } SeenTypes seenTypes; @@ -4078,7 +4140,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!qty.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (ty == *qty) @@ -4101,7 +4163,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } } @@ -4116,7 +4178,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(anyType); } } @@ -4131,7 +4193,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo else { reportError(location, UnificationTooComplex{}); - return errorTypePack; + return errorRecoveryTypePack(anyTypePack); } } @@ -4279,6 +4341,38 @@ TypeId TypeChecker::freshType(TypeLevel level) return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } +TypeId TypeChecker::singletonType(bool value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value}))); +} + +TypeId TypeChecker::singletonType(std::string value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)}))); +} + +TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryType(); +} + +TypeId TypeChecker::errorRecoveryType(TypeId guess) +{ + return singletonTypes.errorRecoveryType(guess); +} + +TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryTypePack(); +} + +TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) +{ + return singletonTypes.errorRecoveryTypePack(guess); +} + std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4350,7 +4444,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(anyType); } ToStringOptions opts; @@ -4368,7 +4462,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (!tf) { if (lit->name == Parser::errorName) - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); std::string typeName; if (lit->hasPrefix) @@ -4380,7 +4474,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) @@ -4390,14 +4484,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) { reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } - else if (FFlag::LuauTypeAliasPacks) + + if (FFlag::LuauTypeAliasPacks) { if (!lit->hasParameterList && !tf->typePackParams.empty()) { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } std::vector typeParams; @@ -4445,7 +4542,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { reportError( TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - return addType(ErrorTypeVar{}); + + if (FFlag::LuauErrorRecoveryType) + { + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); + } + else + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) @@ -4464,6 +4571,14 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation for (const auto& param : lit->parameters) typeParams.push_back(resolveType(scope, *param.type)); + if (FFlag::LuauErrorRecoveryType) + { + // If there aren't enough type parameters, pad them out with error recovery types + // (we've already reported the error) + while (typeParams.size() < lit->parameters.size) + typeParams.push_back(errorRecoveryType(scope)); + } + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { // If the generic parameters and the type arguments are the same, we are about to @@ -4483,8 +4598,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) - tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); return addType(TableTypeVar{ props, tableIndexer, scope->level, @@ -4536,14 +4650,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(IntersectionTypeVar{types}); } - else if (annotation.is()) + else if (const auto& tsb = annotation.as()) { - return addType(ErrorTypeVar{}); + return singletonType(tsb->value); } + else if (const auto& tss = annotation.as()) + { + return singletonType(std::string(tss->value.data, tss->value.size)); + } + else if (annotation.is()) + return errorRecoveryType(scope); else { reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } } @@ -4584,7 +4704,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return addTypePack(TypePackVar{Unifiable::Error{}}); + return errorRecoveryTypePack(scope); } return *genericTy; @@ -4706,12 +4826,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (!maybeInstantiated.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); - return errorType; + return errorRecoveryType(scope); } TypeId instantiated = *maybeInstantiated; @@ -4773,8 +4893,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); @@ -5043,7 +5163,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement addRefinement(refis, isaP.lvalue, *result); else { - addRefinement(refis, isaP.lvalue, errorType); + addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); } } @@ -5107,7 +5227,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec addRefinement(refis, typeguardP.lvalue, *result); else { - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); if (sense) errVec.push_back( TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); @@ -5118,7 +5238,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec auto fail = [&](const TypeErrorData& err) { errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; if (!typeguardP.isTypeof) @@ -5139,28 +5259,6 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); } -void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) -{ - if (!sense) - return; - - static std::vector primitives{ - "string", "number", "boolean", "nil", "thread", - "table", // no op. Requires special handling. - "function", // no op. Requires special handling. - "userdata", // no op. Requires special handling. - }; - - if (auto typeFun = globalScope->lookupType(typeguardP.kind); - typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) - { - if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - else if (typeguardP.isTypeof) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - } -} - void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 228b192..d3221c7 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -286,5 +286,4 @@ TypePack* asMutable(const TypePack* tp) { return const_cast(tp); } - } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4ee5cb8..924bf08 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) +LUAU_FASTFLAG(LuauErrorRecoveryType) namespace Luau { @@ -305,6 +306,18 @@ bool maybeGeneric(TypeId ty) return isGeneric(ty); } +bool maybeSingleton(TypeId ty) +{ + ty = follow(ty); + if (get(ty)) + return true; + if (const UnionTypeVar* utv = get(ty)) + for (TypeId option : utv) + if (get(follow(option))) + return true; + return false; +} + FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) : argTypes(argTypes) , retType(retType) @@ -562,10 +575,8 @@ SingletonTypes::SingletonTypes() , booleanType(&booleanType_) , threadType(&threadType_) , anyType(&anyType_) - , errorType(&errorType_) , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) - , errorTypePack(&errorTypePack_) , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); @@ -634,6 +645,32 @@ TypeId SingletonTypes::makeStringMetatable() return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } +TypeId SingletonTypes::errorRecoveryType() +{ + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack() +{ + return &errorTypePack_; +} + +TypeId SingletonTypes::errorRecoveryType(TypeId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorTypePack_; +} + SingletonTypes singletonTypes; void persist(TypeId ty) @@ -1141,6 +1178,11 @@ struct QVarFinder return false; } + bool operator()(const SingletonTypeVar&) const + { + return false; + } + bool operator()(const FunctionTypeVar& ftv) const { if (hasGeneric(ftv.argTypes)) @@ -1412,7 +1454,7 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha else if (strchr(options, data[i])) result.push_back(typechecker.numberType); else - result.push_back(typechecker.errorType); + result.push_back(typechecker.errorRecoveryType(typechecker.anyType)); } } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 82f621b..e1a52be 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,7 +22,9 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) +LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) +LUAU_FASTFLAG(LuauErrorRecoveryType); namespace Luau { @@ -211,6 +213,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(subTy, superTy); + // The occurrence check might have caused superTy no longer to be a free type if (!get(subTy)) { log(subTy); @@ -221,10 +224,20 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (l && r) { - log(superTy); + if (!FFlag::LuauErrorRecoveryType) + log(superTy); occursCheck(superTy, subTy); r->level = min(r->level, l->level); - *asMutable(superTy) = BoundTypeVar(subTy); + + // The occurrence check might have caused superTy no longer to be a free type + if (!FFlag::LuauErrorRecoveryType) + *asMutable(superTy) = BoundTypeVar(subTy); + else if (!get(superTy)) + { + log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } + return; } else if (l) @@ -240,6 +253,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool return; } + // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { if (auto rightLevel = getMutableLevel(subTy)) @@ -251,6 +265,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log(superTy); *asMutable(superTy) = BoundTypeVar(subTy); } + return; } else if (r) @@ -512,6 +527,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (get(superTy) && get(subTy)) tryUnifyPrimitives(superTy, subTy); + else if (FFlag::LuauSingletonTypes && (get(superTy) || get(superTy)) && get(subTy)) + tryUnifySingletons(superTy, subTy); + else if (get(superTy) && get(subTy)) tryUnifyFunctions(superTy, subTy, isFunctionCall); @@ -723,17 +741,18 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal { occursCheck(superTp, subTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(superTp)) { log(superTp); *asMutable(superTp) = Unifiable::Bound(subTp); } } - else if (get(subTp)) { occursCheck(subTp, superTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(subTp)) { log(subTp); @@ -874,13 +893,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorType, *superIter); + tryUnify_(singletonTypes.errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorType, *subIter); + tryUnify_(singletonTypes.errorRecoveryType(), *subIter); subIter.advance(); } @@ -906,6 +925,27 @@ void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } +void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy) +{ + const PrimitiveTypeVar* lp = get(superTy); + const SingletonTypeVar* ls = get(superTy); + const SingletonTypeVar* rs = get(subTy); + + if ((!lp && !ls) || !rs) + ice("passed non singleton/primitive types to unifySingletons"); + + if (ls && *ls == *rs) + return; + + if (lp && lp->type == PrimitiveTypeVar::Boolean && get(rs) && variance == Covariant) + return; + + if (lp && lp->type == PrimitiveTypeVar::String && get(rs) && variance == Covariant) + return; + + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); +} + void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) { FunctionTypeVar* lf = getMutable(superTy); @@ -1023,7 +1063,8 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && + lt->state != TableState::Free) { for (const auto& [propName, subProp] : rt->props) { @@ -1038,7 +1079,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) return; } } - + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -1634,9 +1675,8 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorType); + tryUnify_(prop.type, singletonTypes.errorRecoveryType()); } else { @@ -1952,7 +1992,7 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorType; + const TypeId anyTy = singletonTypes.errorRecoveryType(); if (FFlag::LuauTypecheckOpts) { @@ -2046,7 +2086,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryType(); return; } @@ -2134,7 +2174,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); return; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index a2189f7..5b4bfa0 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -255,6 +255,14 @@ public: { return visit((class AstType*)node); } + virtual bool visit(class AstTypeSingletonBool* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeSingletonString* node) + { + return visit((class AstType*)node); + } virtual bool visit(class AstTypeError* node) { return visit((class AstType*)node); @@ -1158,6 +1166,30 @@ public: unsigned messageIndex; }; +class AstTypeSingletonBool : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonBool) + + AstTypeSingletonBool(const Location& location, bool value); + + void visit(AstVisitor* visitor) override; + + bool value; +}; + +class AstTypeSingletonString : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonString) + + AstTypeSingletonString(const Location& location, const AstArray& value); + + void visit(AstVisitor* visitor) override; + + const AstArray value; +}; + class AstTypePack : public AstNode { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 39c7d92..87ebc48 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -286,6 +286,7 @@ private: // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); + std::optional> parseCharArray(); AstExpr* parseString(); AstLocal* pushLocal(const Binding& binding); diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 4f7673f..6ecf060 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -34,4 +34,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs); size_t hashRange(const char* data, size_t size); +std::string escape(std::string_view s); +bool isIdentifier(std::string_view s); } // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index b1209fa..e709894 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -841,6 +841,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor) } } +AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonBool::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray& value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonString::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index a1bad65..bc63e37 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) +LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau @@ -1278,7 +1279,27 @@ AstType* Parser::parseTableTypeAnnotation() while (lexer.current().type != '}') { - if (lexer.current().type == '[') + if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' && + (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) + { + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "table field"); + + AstType* type = parseTypeAnnotation(); + + // TODO: since AstName conains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back({AstName(chars->data), begin.location, type}); + else + report(begin.location, "String literal contains malformed escape sequence"); + } + else if (lexer.current().type == '[') { if (indexer) { @@ -1528,6 +1549,32 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) nextLexeme(); return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue) + { + nextLexeme(); + return {allocator.alloc(begin, true)}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse) + { + nextLexeme(); + return {allocator.alloc(begin, false)}; + } + else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)) + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(begin, svalue)}; + } + else + return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString) + { + Location location = lexer.current().location; + nextLexeme(); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")}; + } else if (lexer.current().type == Lexeme::Name) { std::optional prefix; @@ -2416,7 +2463,7 @@ AstArray Parser::parseTypeParams() return copy(parameters); } -AstExpr* Parser::parseString() +std::optional> Parser::parseCharArray() { LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); @@ -2426,11 +2473,8 @@ AstExpr* Parser::parseString() { if (!Lexer::fixupQuotedString(scratchData)) { - Location location = lexer.current().location; - nextLexeme(); - - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + return std::nullopt; } } else @@ -2438,12 +2482,18 @@ AstExpr* Parser::parseString() Lexer::fixupMultilineString(scratchData); } - Location start = lexer.current().location; AstArray value = copy(scratchData); - nextLexeme(); + return value; +} - return allocator.alloc(start, value); +AstExpr* Parser::parseString() +{ + Location location = lexer.current().location; + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); } AstLocal* Parser::pushLocal(const Binding& binding) diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 24b2283..9c7fed3 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -225,4 +225,62 @@ size_t hashRange(const char* data, size_t size) return hash; } +bool isIdentifier(std::string_view s) +{ + return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos); +} + +std::string escape(std::string_view s) +{ + std::string r; + r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting + + for (uint8_t c : s) + { + if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') + r += c; + else + { + r += '\\'; + + switch (c) + { + case '\a': + r += 'a'; + break; + case '\b': + r += 'b'; + break; + case '\f': + r += 'f'; + break; + case '\n': + r += 'n'; + break; + case '\r': + r += 'r'; + break; + case '\t': + r += 't'; + break; + case '\v': + r += 'v'; + break; + case '\'': + r += '\''; + break; + case '\"': + r += '\"'; + break; + case '\\': + r += '\\'; + break; + default: + Luau::formatAppend(r, "%03u", c); + } + } + } + + return r; +} } // namespace Luau diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 798ead0..4194d5a 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -236,32 +236,12 @@ int main(int argc, char** argv) Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - // Look for .luau first and if absent, fall back to .lua - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - }); - } - else - { - failed += !analyzeFile(frontend, argv[i], format, annotate); - } - } + for (const std::string& path : files) + failed += !analyzeFile(frontend, path.c_str(), format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index fff5e42..b3c9557 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -223,3 +223,40 @@ std::optional getParentPath(const std::string& path) return ""; } + +static std::string getExtension(const std::string& path) +{ + std::string::size_type dot = path.find_last_of(".\\/"); + + if (dot == std::string::npos || path[dot] != '.') + return ""; + + return path.substr(dot); +} + +std::vector getSourceFiles(int argc, char** argv) +{ + std::vector files; + + for (int i = 1; i < argc; ++i) + { + if (argv[i][0] == '-') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + std::string ext = getExtension(name); + + if (ext == ".lua" || ext == ".luau") + files.push_back(name); + }); + } + else + { + files.push_back(argv[i]); + } + } + + return files; +} diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index f7fbe8a..da11f51 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -4,6 +4,7 @@ #include #include #include +#include std::optional readFile(const std::string& name); @@ -12,3 +13,5 @@ bool traverseDirectory(const std::string& path, const std::function getParentPath(const std::string& path); + +std::vector getSourceFiles(int argc, char** argv); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 0d804be..410674f 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -20,7 +20,7 @@ enum class CompileFormat { - Default, + Text, Binary }; @@ -33,7 +33,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; lua_pushnil(L); @@ -80,7 +80,7 @@ static int lua_require(lua_State* L) // now we can compile & run module on the new thread std::string bytecode = Luau::compile(*source); - if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { int status = lua_resume(ML, L, 0); @@ -151,7 +151,7 @@ static std::string runCode(lua_State* L, const std::string& source) { std::string bytecode = Luau::compile(source); - if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0) + if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { size_t len; const char* msg = lua_tolstring(L, -1, &len); @@ -370,7 +370,7 @@ static bool runFile(const char* name, lua_State* GL) std::string bytecode = Luau::compile(*source); int status = 0; - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { status = lua_resume(L, NULL, 0); } @@ -379,11 +379,7 @@ static bool runFile(const char* name, lua_State* GL) status = LUA_ERRSYNTAX; } - if (status == 0) - { - return true; - } - else + if (status != 0) { std::string error; @@ -400,8 +396,10 @@ static bool runFile(const char* name, lua_State* GL) error += lua_debugtrace(L); fprintf(stderr, "%s", error.c_str()); - return false; } + + lua_pop(GL, 1); + return status == 0; } static void report(const char* name, const Luau::Location& location, const char* type, const char* message) @@ -431,14 +429,18 @@ static bool compileFile(const char* name, CompileFormat format) try { Luau::BytecodeBuilder bcb; - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); - bcb.setDumpSource(*source); + + if (format == CompileFormat::Text) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpSource(*source); + } Luau::compileOrThrow(bcb, *source); switch (format) { - case CompileFormat::Default: + case CompileFormat::Text: printf("%s", bcb.dumpEverything().c_str()); break; case CompileFormat::Binary: @@ -504,7 +506,7 @@ int main(int argc, char** argv) if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - CompileFormat format = CompileFormat::Default; + CompileFormat format = CompileFormat::Text; if (strcmp(argv[1], "--compile=binary") == 0) format = CompileFormat::Binary; @@ -514,27 +516,12 @@ int main(int argc, char** argv) _setmode(_fileno(stdout), _O_BINARY); #endif + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 2; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - failed += !compileFile(name.c_str(), format); - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str(), format); - }); - } - else - { - failed += !compileFile(argv[i], format); - } - } + for (const std::string& path : files) + failed += !compileFile(path.c_str(), format); return failed; } @@ -548,33 +535,25 @@ int main(int argc, char** argv) int profile = 0; for (int i = 1; i < argc; ++i) + { + if (argv[i][0] != '-') + continue; + if (strcmp(argv[i], "--profile") == 0) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + } if (profile) profilerStart(L, profile); + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !runFile(name.c_str(), L); - }); - } - else - { - failed += !runFile(argv[i], L); - } - } + for (const std::string& path : files) + failed += !runFile(path.c_str(), L); if (profile) { diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 4f88e60..65e962d 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -13,11 +13,9 @@ class AstNameTable; class BytecodeBuilder; class BytecodeEncoder; +// Note: this structure is duplicated in luacode.h, don't forget to change these in sync! struct CompileOptions { - // default bytecode version target; can be used to compile code for older clients - int bytecodeVersion = 1; - // 0 - no optimization // 1 - baseline optimization level that doesn't prevent debuggability // 2 - includes optimizations that harm debuggability such as inlining diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h new file mode 100644 index 0000000..e235a2e --- /dev/null +++ b/Compiler/include/luacode.h @@ -0,0 +1,39 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +/* Can be used to reconfigure visibility/exports for public APIs */ +#ifndef LUACODE_API +#define LUACODE_API extern +#endif + +typedef struct lua_CompileOptions lua_CompileOptions; + +struct lua_CompileOptions +{ + // 0 - no optimization + // 1 - baseline optimization level that doesn't prevent debuggability + // 2 - includes optimizations that harm debuggability such as inlining + int optimizationLevel; // default=1 + + // 0 - no debugging support + // 1 - line info & function names only; sufficient for backtraces + // 2 - full debug info with local & upvalue names; necessary for debugger + int debugLevel; // default=1 + + // 0 - no code coverage support + // 1 - statement coverage + // 2 - statement and expression coverage (verbose) + int coverageLevel; // default=0 + + // global builtin to construct vectors; disabled by default + const char* vectorLib; + const char* vectorCtor; + + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these + const char** mutableGlobals; +}; + +/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */ +LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9712f02..5b93c1d 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -11,9 +11,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) -LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -24,9 +21,6 @@ static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; -// TODO: Remove with LuauGenericSpecialGlobals -static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; - CompileError::CompileError(const Location& location, const std::string& message) : location(location) , message(message) @@ -466,7 +460,7 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it @@ -482,18 +476,6 @@ struct Compiler } } } - // Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects - // (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used) - else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - return; - } - } if (!shared) bytecode.emitAD(LOP_NEWCLOSURE, target, pid); @@ -3298,8 +3280,7 @@ struct Compiler bool visit(AstStatLocalFunction* node) override { // record local->function association for some optimizations - if (FFlag::LuauPreloadClosuresUpval) - self->locals[node->name].func = node->func; + self->locals[node->name].func = node->func; return true; } @@ -3711,24 +3692,13 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (FFlag::LuauGenericSpecialGlobals) - { - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + if (AstName name = names.get("_G"); name.value) + compiler.globals[name].writable = true; - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; - } - else - { - for (const char* global : kSpecialGlobals) - { - if (AstName name = names.get(global); name.value) + if (options.mutableGlobals) + for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) compiler.globals[name].writable = true; - } - } // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written Compiler::AssignmentVisitor assignmentVisitor(&compiler); @@ -3742,7 +3712,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found - if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) + if (options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) { Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); root->visit(&fenvVisitor); diff --git a/Compiler/src/lcode.cpp b/Compiler/src/lcode.cpp new file mode 100644 index 0000000..ee150b1 --- /dev/null +++ b/Compiler/src/lcode.cpp @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "luacode.h" + +#include "Luau/Compiler.h" + +#include + +char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize) +{ + LUAU_ASSERT(outsize); + + Luau::CompileOptions opts; + + if (options) + { + static_assert(sizeof(lua_CompileOptions) == sizeof(Luau::CompileOptions), "C and C++ interface must match"); + memcpy(static_cast(&opts), options, sizeof(opts)); + } + + std::string result = compile(std::string(source, size), opts); + + char* copy = static_cast(malloc(result.size())); + if (!copy) + return nullptr; + + memcpy(copy, result.data(), result.size()); + *outsize = result.size(); + return copy; +} diff --git a/Sources.cmake b/Sources.cmake index c30cf77..23b931c 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -25,9 +25,11 @@ target_sources(Luau.Compiler PRIVATE Compiler/include/Luau/Bytecode.h Compiler/include/Luau/BytecodeBuilder.h Compiler/include/Luau/Compiler.h + Compiler/include/luacode.h Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/lcode.cpp ) # Luau.Analysis Sources @@ -204,6 +206,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.intersectionTypes.test.cpp tests/TypeInfer.provisional.test.cpp tests/TypeInfer.refinements.test.cpp + tests/TypeInfer.singletons.test.cpp tests/TypeInfer.tables.test.cpp tests/TypeInfer.test.cpp tests/TypeInfer.tryUnify.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index a9d3e87..1568d19 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -102,6 +102,8 @@ LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud); LUA_API void lua_close(lua_State* L); LUA_API lua_State* lua_newthread(lua_State* L); LUA_API lua_State* lua_mainthread(lua_State* L); +LUA_API void lua_resetthread(lua_State* L); +LUA_API int lua_isthreadreset(lua_State* L); /* ** basic stack manipulation @@ -162,8 +164,7 @@ LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); -LUA_API void lua_pushcfunction( - lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL); +LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont); LUA_API void lua_pushboolean(lua_State* L, int b); LUA_API void lua_pushlightuserdata(lua_State* L, void* p); LUA_API int lua_pushthread(lua_State* L); @@ -178,9 +179,9 @@ LUA_API void lua_rawget(lua_State* L, int idx); LUA_API void lua_rawgeti(lua_State* L, int idx, int n); LUA_API void lua_createtable(lua_State* L, int narr, int nrec); -LUA_API void lua_setreadonly(lua_State* L, int idx, bool value); +LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); -LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value); +LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); @@ -200,7 +201,7 @@ LUA_API int lua_setfenv(lua_State* L, int idx); /* ** `load' and `call' functions (load and run Luau bytecode) */ -LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0); +LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env); LUA_API void lua_call(lua_State* L, int nargs, int nresults); LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc); @@ -293,6 +294,8 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) #define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) +#define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL) +#define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL) #define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) #define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) @@ -319,8 +322,8 @@ LUA_API const char* lua_setlocal(lua_State* L, int level, int n); LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); -LUA_API void lua_singlestep(lua_State* L, bool singlestep); -LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable); +LUA_API void lua_singlestep(lua_State* L, int enabled); +LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); @@ -361,6 +364,7 @@ struct lua_Callbacks void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */ void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */ }; +typedef struct lua_Callbacks lua_Callbacks; LUA_API lua_Callbacks* lua_callbacks(lua_State* L); diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 30cffaf..fa83695 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -8,11 +8,12 @@ #define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname) #define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg) -typedef struct luaL_Reg +struct luaL_Reg { const char* name; lua_CFunction func; -} luaL_Reg; +}; +typedef struct luaL_Reg luaL_Reg; LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l); LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e); @@ -75,6 +76,7 @@ struct luaL_Buffer struct TString* storage; char buffer[LUA_BUFFERSIZE]; }; +typedef struct luaL_Buffer luaL_Buffer; // when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 7e74264..a79ba0d 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -593,7 +593,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...) return ret; } -void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) +void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) { luaC_checkGC(L); luaC_checkthreadsleep(L); @@ -698,13 +698,13 @@ void lua_createtable(lua_State* L, int narray, int nrec) return; } -void lua_setreadonly(lua_State* L, int objindex, bool value) +void lua_setreadonly(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); - t->readonly = value; + t->readonly = bool(enabled); return; } @@ -717,12 +717,12 @@ int lua_getreadonly(lua_State* L, int objindex) return res; } -void lua_setsafeenv(lua_State* L, int objindex, bool value) +void lua_setsafeenv(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); - t->safeenv = value; + t->safeenv = bool(enabled); return; } diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 87fc163..61798e2 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -436,8 +436,8 @@ static const luaL_Reg base_funcs[] = { static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u) { - lua_pushcfunction(L, u); - lua_pushcfunction(L, f, name, 1); + lua_pushcfunction(L, u, NULL); + lua_pushcclosure(L, f, name, 1); lua_setfield(L, -2, name); } @@ -456,10 +456,10 @@ LUALIB_API int luaopen_base(lua_State* L) auxopen(L, "ipairs", luaB_ipairs, luaB_inext); auxopen(L, "pairs", luaB_pairs, luaB_next); - lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont); + lua_pushcclosurek(L, luaB_pcally, "pcall", 0, luaB_pcallcont); lua_setfield(L, -2, "pcall"); - lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); + lua_pushcclosurek(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); lua_setfield(L, -2, "xpcall"); return 1; diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 9724c0e..0178fae 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,6 +5,8 @@ #include "lstate.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false) + #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -208,8 +210,7 @@ static int cowrap(lua_State* L) { cocreate(L); - lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); - + lua_pushcclosurek(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } @@ -232,6 +233,34 @@ static int coyieldable(lua_State* L) return 1; } +static int coclose(lua_State* L) +{ + if (!FFlag::LuauCoroutineClose) + luaL_error(L, "coroutine.close is not enabled"); + + lua_State* co = lua_tothread(L, 1); + luaL_argexpected(L, co, 1, "thread"); + + int status = auxstatus(L, co); + if (status != CO_DEAD && status != CO_SUS) + luaL_error(L, "cannot close %s coroutine", statnames[status]); + + if (co->status == LUA_OK || co->status == LUA_YIELD) + { + lua_pushboolean(L, true); + lua_resetthread(co); + return 1; + } + else + { + lua_pushboolean(L, false); + if (lua_gettop(co)) + lua_xmove(co, L, 1); /* move error message */ + lua_resetthread(co); + return 2; + } +} + static const luaL_Reg co_funcs[] = { {"create", cocreate}, {"running", corunning}, @@ -239,6 +268,7 @@ static const luaL_Reg co_funcs[] = { {"wrap", cowrap}, {"yield", coyield}, {"isyieldable", coyieldable}, + {"close", coclose}, {NULL, NULL}, }; @@ -246,7 +276,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); + lua_pushcclosurek(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 1890e68..d77f84e 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -316,7 +316,7 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) p->debuginsn[j] = LUAU_INSN_OP(p->code[j]); } - uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]); + uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->debuginsn[i]); // patch just the opcode byte, leave arguments alone p->code[i] &= ~0xff; @@ -357,17 +357,17 @@ int luaG_getline(Proto* p, int pc) return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc]; } -void lua_singlestep(lua_State* L, bool singlestep) +void lua_singlestep(lua_State* L, int enabled) { - L->singlestep = singlestep; + L->singlestep = bool(enabled); } -void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable) +void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) { const TValue* func = luaA_toobject(L, funcindex); api_check(L, ttisfunction(func) && !clvalue(func)->isC); - luaG_breakpoint(L, clvalue(func)->l.p, line, enable); + luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 328b47e..1259d46 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) +LUAU_FASTFLAG(LuauCoroutineClose) /* ** {====================================================== @@ -300,7 +301,10 @@ static void resume(lua_State* L, void* ud) if (L->status == 0) { // start coroutine - LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base); + LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base); + if (FFlag::LuauCoroutineClose && firstArg == L->base) + luaG_runerror(L, "cannot resume dead coroutine"); + if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) return; diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index bf5e738..4e40165 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -22,7 +22,7 @@ LUALIB_API void luaL_openlibs(lua_State* L) const luaL_Reg* lib = lualibs; for (; lib->func; lib++) { - lua_pushcfunction(L, lib->func); + lua_pushcfunction(L, lib->func, NULL); lua_pushstring(L, lib->name); lua_call(L, 1, 0); } diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 0b2dfb6..24e9706 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -124,6 +124,34 @@ void luaE_freethread(lua_State* L, lua_State* L1) luaM_free(L, L1, sizeof(lua_State), L1->memcat); } +void lua_resetthread(lua_State* L) +{ + /* close upvalues before clearing anything */ + luaF_close(L, L->stack); + /* clear call frames */ + CallInfo* ci = L->base_ci; + ci->func = L->stack; + ci->base = ci->func + 1; + ci->top = ci->base + LUA_MINSTACK; + setnilvalue(ci->func); + L->ci = ci; + luaD_reallocCI(L, BASIC_CI_SIZE); + /* clear thread state */ + L->status = LUA_OK; + L->base = L->ci->base; + L->top = L->ci->base; + L->nCcalls = L->baseCcalls = 0; + /* clear thread stack */ + luaD_reallocstack(L, BASIC_STACK_SIZE); + for (int i = 0; i < L->stacksize; i++) + setnilvalue(L->stack + i); +} + +int lua_isthreadreset(lua_State* L) +{ + return L->ci == L->base_ci && L->base == L->top && L->status == LUA_OK; +} + lua_State* lua_newstate(lua_Alloc f, void* ud) { int i; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 80a3448..b576f80 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -748,7 +748,7 @@ static int gmatch(lua_State* L) luaL_checkstring(L, 2); lua_settop(L, 2); lua_pushinteger(L, 0); - lua_pushcfunction(L, gmatch_aux, NULL, 3); + lua_pushcclosure(L, gmatch_aux, NULL, 3); return 1; } diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 6a02629..378de3d 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -265,7 +265,7 @@ static int iter_aux(lua_State* L) static int iter_codes(lua_State* L) { luaL_checkstring(L, 1); - lua_pushcfunction(L, iter_aux); + lua_pushcfunction(L, iter_aux, NULL); lua_pushvalue(L, 1); lua_pushinteger(L, 0); return 3; diff --git a/bench/bench.py b/bench/bench.py index b23ca89..39f219f 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -25,8 +25,8 @@ try: import scipy from scipy import stats except ModuleNotFoundError: - print("scipy package is required") - exit(1) + print("Warning: scipy package is not installed, confidence values will not be available") + stats = None scriptdir = os.path.dirname(os.path.realpath(__file__)) defaultVm = 'luau.exe' if os.name == "nt" else './luau' @@ -200,11 +200,14 @@ def finalizeResult(result): result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1)) result.unbiasedEst = result.sampleStdDev * result.sampleStdDev - # Two-tailed distribution with 95% conf. - tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) + if stats: + # Two-tailed distribution with 95% conf. + tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) - # Compute confidence interval - result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + # Compute confidence interval + result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + else: + result.sampleConfidenceInterval = result.sampleStdDev else: result.sampleStdDev = 0 result.unbiasedEst = 0 @@ -377,14 +380,19 @@ def analyzeResult(subdir, main, comparisons): tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) degreesOfFreedom = 2 * main.count - 2 - # Two-tailed distribution with 95% conf. - tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) + if stats: + # Two-tailed distribution with 95% conf. + tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) - noSignificantDifference = tStat < tCritical + noSignificantDifference = tStat < tCritical + pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) + else: + noSignificantDifference = None + pValue = -1 - pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) - - if noSignificantDifference: + if noSignificantDifference is None: + verdict = "" + elif noSignificantDifference: verdict = "likely same" elif main.avg < compare.avg: verdict = "likely worse" diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index c85fac7..ae2399e 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -257,7 +257,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) lua_State* L = lua_newthread(globalState); luaL_sandboxthread(L); - if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) { interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 8a7798f..5a7c860 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -91,10 +91,6 @@ struct ACFixture : ACFixtureImpl { }; -struct UnfrozenACFixture : ACFixtureImpl -{ -}; - TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") @@ -1919,9 +1915,10 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end @@ -1950,9 +1947,10 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") +TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function a(x: boolean) end local function b(x: number?) end @@ -2484,7 +2482,7 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 7f03019..4ce8d08 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -11,8 +11,6 @@ #include LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauPreloadClosuresFenv) -LUAU_FASTFLAG(LuauPreloadClosuresUpval) LUAU_FASTFLAG(LuauGenericSpecialGlobals) using namespace Luau; @@ -2797,7 +2795,7 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // recursive capture CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( @@ -3479,15 +3477,13 @@ CAPTURE VAL R0 RETURN R1 1 )"); - if (FFlag::LuauPreloadClosuresFenv) - { - // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion - CHECK_EQ("\n" + compileFunction(R"( + // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion + CHECK_EQ("\n" + compileFunction(R"( setfenv(1, {}) return function() print("hi") end )", - 1), - R"( + 1), + R"( GETIMPORT R0 1 LOADN R1 1 NEWTABLE R2 0 0 @@ -3496,23 +3492,21 @@ NEWCLOSURE R0 P0 RETURN R0 1 )"); - // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature - CHECK_EQ("\n" + compileFunction(R"( + // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature + CHECK_EQ("\n" + compileFunction(R"( if false then setfenv(1, {}) end return function() print("hi") end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 RETURN R0 1 )"); - } } TEST_CASE("SharedClosure") { ScopedFastFlag sff1("LuauPreloadClosures", true); - ScopedFastFlag sff2("LuauPreloadClosuresUpval", true); // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( @@ -3671,7 +3665,7 @@ RETURN R0 0 )"); } -TEST_CASE("LuauGenericSpecialGlobals") +TEST_CASE("MutableGlobals") { const char* source = R"( print() @@ -3685,43 +3679,6 @@ shared.print() workspace.print() )"; - { - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false}; - - // Check Roblox globals are here - CHECK_EQ("\n" + compileFunction0(source), R"( -GETIMPORT R0 1 -CALL R0 0 0 -GETIMPORT R1 3 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 5 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 7 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 9 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 11 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 13 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 15 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 17 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -RETURN R0 0 -)"); - } - - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true}; - // Check Roblox globals are no longer here CHECK_EQ("\n" + compileFunction0(source), R"( GETIMPORT R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c1b790b..e495a21 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Compiler.h" +#include "lua.h" +#include "lualib.h" +#include "luacode.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/ModuleResolver.h" @@ -10,9 +12,6 @@ #include "doctest.h" #include "ScopedFlags.h" -#include "lua.h" -#include "lualib.h" - #include #include @@ -49,8 +48,12 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + size_t bytecodeSize = 0; + char* bytecode = luau_compile(s, l, nullptr, &bytecodeSize); + int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0); + free(bytecode); + + if (result == 0) return 1; lua_pushnil(L); @@ -179,21 +182,17 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - Luau::CompileOptions copts; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; // default copts.debugLevel = 2; // for debugger tests copts.vectorCtor = "vector"; // for vector tests - std::string bytecode = Luau::compile(source, copts); - int status = 0; + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); + free(bytecode); - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) - { - status = lua_resume(L, nullptr, 0); - } - else - { - status = LUA_ERRSYNTAX; - } + int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; while (yield && (status == LUA_YIELD || status == LUA_BREAK)) { @@ -332,53 +331,61 @@ TEST_CASE("UTF8") TEST_CASE("Coroutine") { + ScopedFastFlag sff("LuauCoroutineClose", true); + runConformance("coroutine.lua"); } +static int cxxthrow(lua_State* L) +{ +#if LUA_USE_LONGJMP + luaL_error(L, "oops"); +#else + throw std::runtime_error("oops"); +#endif +} + TEST_CASE("PCall") { runConformance("pcall.lua", [](lua_State* L) { - lua_pushcfunction(L, [](lua_State* L) -> int { -#if LUA_USE_LONGJMP - luaL_error(L, "oops"); -#else - throw std::runtime_error("oops"); -#endif - }); + lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); - lua_pushcfunction(L, [](lua_State* L) -> int { - lua_State* co = lua_tothread(L, 1); - lua_xmove(L, co, 1); - lua_resumeerror(co, L); - return 0; - }); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }, + "resumeerror"); lua_setglobal(L, "resumeerror"); }); } TEST_CASE("Pack") { - ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; - + ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; + runConformance("tpack.lua"); } TEST_CASE("Vector") { runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector); + lua_pushcfunction(L, lua_vector, "vector"); lua_setglobal(L, "vector"); lua_pushvector(L, 0.0f, 0.0f, 0.0f); luaL_newmetatable(L, "vector"); lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index); + lua_pushcfunction(L, lua_vector_index, nullptr); lua_settable(L, -3); lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall); + lua_pushcfunction(L, lua_vector_namecall, nullptr); lua_settable(L, -3); lua_setreadonly(L, -1, true); @@ -513,15 +520,19 @@ TEST_CASE("Debugger") }; // add breakpoint() function - lua_pushcfunction(L, [](lua_State* L) -> int { - int line = luaL_checkinteger(L, 1); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + int line = luaL_checkinteger(L, 1); + bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true; - lua_Debug ar = {}; - lua_getinfo(L, 1, "f", &ar); + lua_Debug ar = {}; + lua_getinfo(L, 1, "f", &ar); - lua_breakpoint(L, -1, line, true); - return 0; - }); + lua_breakpoint(L, -1, line, enabled); + return 0; + }, + "breakpoint"); lua_setglobal(L, "breakpoint"); }, [](lua_State* L) { @@ -744,7 +755,7 @@ TEST_CASE("ExceptionObject") if (nsize == 0) { free(ptr); - return NULL; + return nullptr; } else if (nsize > 512 * 1024) { diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index e55b5b0..e0756ba 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -4,7 +4,8 @@ #include #include -namespace std { +namespace std +{ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 18f55d2..7a3543c 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -203,6 +203,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -212,12 +214,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") bool encounteredFreeType = false; TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTy)); + CHECK_EQ("any", toString(clonedTy)); CHECK(encounteredFreeType); encounteredFreeType = false; TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTp)); + CHECK_EQ("...any", toString(clonedTp)); CHECK(encounteredFreeType); } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b076e9a..80a258f 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -198,7 +198,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TypeVar tv{ttv}; - ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40}; + ToStringOptions o; + o.maxTableLength = 40; CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); } @@ -395,7 +396,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)"); + CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") @@ -469,4 +470,110 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") CHECK_EQ(toString(tableTy), "Table"); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") +{ + CheckResult result = check(R"( + local function id(x) return x end + )"); + + TypeId ty = requireType("id"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("id(x: a): a", toStringNamedFunction("id", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") +{ + CheckResult result = check(R"( + local function map(arr, fn) + local t = {} + for i = 0, #arr do + t[i] = fn(arr[i]) + end + return t + end + )"); + + TypeId ty = requireType("map"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); +} + +TEST_CASE("toStringNamedFunction_unit_f") +{ + TypePackVar empty{TypePack{}}; + FunctionTypeVar ftv{&empty, &empty, {}, false}; + CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") +{ + CheckResult result = check(R"( + local function f(x: a, ...): (a, a, b...) + return x, x, ... + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") +{ + CheckResult result = check(R"( + local function f(): ...number + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") +{ + CheckResult result = check(R"( + local function f(): (string, ...number) + return 'a', 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") +{ + CheckResult result = check(R"( + local f: (number, y: number) -> number + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") +{ + CheckResult result = check(R"( + local function f(x: T, g: (T) -> U)): () + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + ToStringOptions opts; + opts.hideNamedFunctionTypeParameters = true; + CHECK_EQ("f(x: T, g: (T) -> U): ()", toStringNamedFunction("f", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index c27f808..74ce155 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_typ CheckResult result = check(R"( type A = number type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + local foo: string = 1 -- "Type 'number' could not be converted into 'string'" )"); LUAU_REQUIRE_ERROR_COUNT(2, result); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2e40016..091c2f0 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -381,6 +381,8 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( type A = B type B = A @@ -390,7 +392,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") )"); TypeId fType = requireType("aa"); - const ErrorTypeVar* ftv = get(follow(fType)); + const AnyTypeVar* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); REQUIRE(!result.errors.empty()); } diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index de2f015..88c2dc8 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") end )"); // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") @@ -352,7 +352,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") -- so this assignment should fail local b: boolean = f(true) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") @@ -368,7 +368,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") local y: number = id(37) end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 733fc39..fe8e7ff 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -704,9 +704,10 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); else CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - if (FFlag::LuauQuantifyInPlace2) + if (FFlag::LuauQuantifyInPlace2) CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" else CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp new file mode 100644 index 0000000..5f95efd --- /dev/null +++ b/tests/TypeInfer.singletons.test.cpp @@ -0,0 +1,377 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeSingletons"); + +TEST_CASE_FIXTURE(Fixture, "bool_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: false = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: "bar" = "bar" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = false + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'false' could not be converted into 'true'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "bar" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "\n" = "\000\r" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '"\000\r"' could not be converted into '"\n"')", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: boolean = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: string = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, "foo") + g(false, 37) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, 37) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "foo" + local b : MyEnum = "bar" + local c : MyEnum = "baz" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExtendedTypeMismatchError", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "bang" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum1 = "foo" | "bar" + type MyEnum2 = MyEnum1 | "baz" + local a : MyEnum1 = "foo" + local b : MyEnum2 = a + local c : string = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExpectedTypesOfProperties", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Dog = { tag = "Dog", howls = true } + local b : Animal = { tag = "Cat", meows = true } + local c : Animal = a + c = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", howls = true } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", meows = true } + a.tag = "Dog" + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["foo"] : number, + ["$$bar"] : string, + baz : boolean + } + local t: T = { + ["foo"] = 37, + ["$$bar"] = "hi", + baz = true + } + local a: number = t.foo + local b: string = t["$$bar"] + local c: boolean = t.baz + t.foo = 5 + t["$$bar"] = "lo" + t.baz = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["$$bar"] : string, + } + local t: T = { + ["$$bar"] = "hi", + } + t["$$bar"] = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type S = "bar" + type T = { + [("foo")] : number, + [S] : string, + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + local x: { ["<>"] : number } + x = { ["\n"] = 5 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + toString(result.errors[0])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 30d9130..99fd833 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -362,7 +362,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") CHECK_EQ(2, result.errors.size()); TypeId p = requireType("p"); - CHECK_EQ(*p, *typeChecker.errorType); + CHECK_EQ("*unknown*", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") @@ -480,7 +480,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") @@ -496,7 +496,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") @@ -512,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") @@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -542,7 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") @@ -673,7 +673,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - CHECK(get(requireType("t"))); + CHECK_EQ("*unknown*", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -1456,7 +1456,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - CHECK_MESSAGE(get(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType)); + CHECK_EQ("*unknown*", toString(hootyType)); } TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") @@ -2032,7 +2032,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); } -TEST_CASE_FIXTURE(Fixture, "error_types_propagate") +TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") { CheckResult result = check(R"( local err = (true).x @@ -2049,10 +2049,10 @@ TEST_CASE_FIXTURE(Fixture, "error_types_propagate") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK(nullptr != get(requireType("c"))); - CHECK(nullptr != get(requireType("d"))); - CHECK(nullptr != get(requireType("e"))); - CHECK(nullptr != get(requireType("f"))); + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") @@ -2068,7 +2068,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK(nullptr != get(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -2077,9 +2077,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - TypeId aType = requireType("a"); - - REQUIRE_MESSAGE(nullptr != get(aType), "Not an error: " << toString(aType)); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") @@ -2146,6 +2144,8 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2175,11 +2175,13 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2209,7 +2211,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "compare_numbers") @@ -2901,6 +2903,8 @@ end TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( local x: number = 9999 function x:y(z: number) @@ -2908,7 +2912,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") @@ -2920,7 +2924,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") @@ -3799,7 +3803,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") print(a) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); } @@ -4215,7 +4219,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4238,7 +4242,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4394,6 +4398,25 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") CHECK_EQ(toString(*it), "(number) -> number"); } +TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number, number) -> number) + local abc: Overload + local x = abc(true) + local y = abc(true,true) + local z = abc(true,true,true) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("number", toString(requireType("y"))); + // Should this be string|number? + CHECK_EQ("string", toString(requireType("z"))); +} + TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation @@ -4740,4 +4763,20 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") } } +TEST_CASE_FIXTURE(Fixture, "type_error_addition") +{ + CheckResult result = check(R"( +--!strict +local foo = makesandwich() +local bar = foo.nutrition + 100 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // We should definitely get this error + CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); + // We get this error if makesandwich() returns a free type + // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 2d697fc..9f9a007 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -121,9 +121,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId bType = requireType("b"); + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); +} - CHECK_MESSAGE(get(bType), "Should be an error: " << toString(bType)); +TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + function f(arg: number) return arg end + local a + local b + local c = f(a, b) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("number", toString(requireType("c"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") @@ -167,15 +184,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") CHECK(state.errors.empty()); } -TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work") -{ - TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})}); - - state.tryUnify(variadicPack, errorPack); - REQUIRE_EQ(0, state.errors.size()); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 9f29b64..48496b8 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -200,8 +200,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - TypeId r = requireType("r"); - CHECK_MESSAGE(get(r), "Expected error, got " << toString(r)); + CHECK_EQ("*unknown*", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") @@ -283,7 +282,7 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "optional_union_follow") +TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow") { CheckResult result = check(R"( local y: number? = 2 diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 7532964..4d9b129 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -319,4 +319,58 @@ for i=0,30 do assert(#T2 == 1 or T2[#T2] == 42) end +-- test coroutine.close +do + -- ok to close a dead coroutine + local co = coroutine.create(type) + assert(coroutine.resume(co, "testing 'coroutine.close'")) + assert(coroutine.status(co) == "dead") + local st, msg = coroutine.close(co) + assert(st and msg == nil) + -- also ok to close it again + st, msg = coroutine.close(co) + assert(st and msg == nil) + + + -- cannot close the running coroutine + coroutine.wrap(function() + local st, msg = pcall(coroutine.close, coroutine.running()) + assert(not st and string.find(msg, "running")) + end)() + + -- cannot close a "normal" coroutine + coroutine.wrap(function() + local co = coroutine.running() + coroutine.wrap(function () + local st, msg = pcall(coroutine.close, co) + assert(not st and string.find(msg, "normal")) + end)() + end)() + + -- closing a coroutine after an error + local co = coroutine.create(error) + local obj = {42} + local st, msg = coroutine.resume(co, obj) + assert(not st and msg == obj) + st, msg = coroutine.close(co) + assert(not st and msg == obj) + -- after closing, no more errors + st, msg = coroutine.close(co) + assert(st and msg == nil) + + -- closing a coroutine that has outstanding upvalues + local f + local co = coroutine.create(function() + local a = 42 + f = function() return a end + coroutine.yield() + a = 20 + end) + coroutine.resume(co) + assert(f() == 42) + st, msg = coroutine.close(co) + assert(st and msg == nil) + assert(f() == 42) +end + return'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 5e69fc6..6ba99fb 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -45,4 +45,13 @@ breakpoint(38) -- break inside corobad() local co = coroutine.create(corobad) assert(coroutine.resume(co) == false) -- this breaks, resumes and dies! +function bar() + print("in bar") +end + +breakpoint(49) +breakpoint(49, false) -- validate that disabling breakpoints works + +bar() + return 'OK'