From 2daa6497a17348090f6af6716e4f856c44594279 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 30 Jun 2022 16:52:43 -0700 Subject: [PATCH 01/10] Sync to upstream/release/534 (#569) --- Analysis/include/Luau/Constraint.h | 19 +- .../include/Luau/ConstraintGraphBuilder.h | 109 +++- Analysis/include/Luau/ConstraintSolver.h | 12 +- Analysis/include/Luau/Error.h | 33 +- Analysis/include/Luau/Frontend.h | 5 + Analysis/include/Luau/Module.h | 2 +- Analysis/include/Luau/NotNull.h | 14 +- Analysis/include/Luau/Scope.h | 6 +- Analysis/include/Luau/TypeArena.h | 6 + Analysis/include/Luau/TypeInfer.h | 4 +- Analysis/include/Luau/TypeVar.h | 3 + Analysis/src/ConstraintGraphBuilder.cpp | 557 ++++++++++++++---- Analysis/src/ConstraintSolver.cpp | 71 ++- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 10 +- Analysis/src/Frontend.cpp | 22 +- Analysis/src/Normalize.cpp | 235 +------- Analysis/src/Quantify.cpp | 73 ++- Analysis/src/Scope.cpp | 15 + Analysis/src/ToString.cpp | 111 ++-- Analysis/src/TypeChecker2.cpp | 106 +++- Analysis/src/TypeInfer.cpp | 144 ++++- Analysis/src/Unifier.cpp | 4 +- Ast/src/Parser.cpp | 122 +++- Common/include/Luau/Bytecode.h | 5 +- Compiler/src/Builtins.cpp | 4 + Compiler/src/BytecodeBuilder.cpp | 15 +- Compiler/src/Compiler.cpp | 16 +- VM/src/lbaselib.cpp | 15 + VM/src/lbuiltins.cpp | 23 + VM/src/ldebug.h | 1 + VM/src/ltm.cpp | 2 +- VM/src/ltm.h | 2 +- VM/src/lvmexecute.cpp | 54 +- VM/src/lvmutils.cpp | 54 +- fuzz/protoprint.cpp | 11 + tests/Compiler.test.cpp | 16 +- tests/Conformance.test.cpp | 6 + tests/ConstraintGraphBuilder.test.cpp | 30 +- tests/ConstraintSolver.test.cpp | 17 +- tests/Fixture.cpp | 23 +- tests/Fixture.h | 3 +- tests/NotNull.test.cpp | 12 +- tests/Parser.test.cpp | 32 +- tests/ToString.test.cpp | 18 +- tests/TypeInfer.aliases.test.cpp | 35 +- tests/TypeInfer.annotations.test.cpp | 14 +- tests/TypeInfer.functions.test.cpp | 21 + tests/TypeInfer.generics.test.cpp | 5 +- tests/TypeInfer.operators.test.cpp | 55 ++ tests/TypeInfer.provisional.test.cpp | 22 + tests/TypeInfer.tables.test.cpp | 121 +++- tests/TypeInfer.test.cpp | 4 +- tests/conformance/basic.lua | 10 +- tests/conformance/events.lua | 38 ++ 54 files changed, 1714 insertions(+), 653 deletions(-) diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 8a41c9e..dcfb14b 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" // Used for some of the enumerations #include "Luau/NotNull.h" #include "Luau/Variant.h" @@ -47,6 +48,21 @@ struct InstantiationConstraint TypeId superType; }; +struct UnaryConstraint +{ + AstExprUnary::Op op; + TypeId operandType; + TypeId resultType; +}; + +struct BinaryConstraint +{ + AstExprBinary::Op op; + TypeId leftType; + TypeId rightType; + TypeId resultType; +}; + // name(namedType) = name struct NameConstraint { @@ -54,7 +70,8 @@ struct NameConstraint std::string name; }; -using ConstraintV = Variant; +using ConstraintV = Variant; using ConstraintPtr = std::unique_ptr; struct Constraint diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 9b11869..a49e859 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -25,9 +25,12 @@ struct ConstraintGraphBuilder // scope pointers; the scopes themselves borrow pointers to other scopes to // define the scope hierarchy. std::vector>> scopes; + + ModuleName moduleName; SingletonTypes& singletonTypes; - TypeArena* const arena; + const NotNull arena; // The root scope of the module we're generating constraints for. + // This is null when the CGB is initially constructed. Scope2* rootScope; // A mapping of AST node to TypeId. DenseHashMap astTypes{nullptr}; @@ -39,40 +42,50 @@ struct ConstraintGraphBuilder // Type packs resolved from type annotations. Analogous to astTypePacks. DenseHashMap astResolvedTypePacks{nullptr}; - explicit ConstraintGraphBuilder(TypeArena* arena); + int recursionCount = 0; + + // It is pretty uncommon for constraint generation to itself produce errors, but it can happen. + std::vector errors; + + // Occasionally constraint generation needs to produce an ICE. + const NotNull ice; + + NotNull globalScope; + + ConstraintGraphBuilder(const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope); /** * Fabricates a new free type belonging to a given scope. - * @param scope the scope the free type belongs to. Must not be null. + * @param scope the scope the free type belongs to. */ - TypeId freshType(Scope2* scope); + TypeId freshType(NotNull scope); /** * Fabricates a new free type pack belonging to a given scope. - * @param scope the scope the free type pack belongs to. Must not be null. + * @param scope the scope the free type pack belongs to. */ - TypePackId freshTypePack(Scope2* scope); + TypePackId freshTypePack(NotNull scope); /** * Fabricates a scope that is a child of another scope. * @param location the lexical extent of the scope in the source code. * @param parent the parent scope of the new scope. Must not be null. */ - Scope2* childScope(Location location, Scope2* parent); + NotNull childScope(Location location, NotNull parent); /** * Adds a new constraint with no dependencies to a given scope. - * @param scope the scope to add the constraint to. Must not be null. + * @param scope the scope to add the constraint to. * @param cv the constraint variant to add. */ - void addConstraint(Scope2* scope, ConstraintV cv); + void addConstraint(NotNull scope, ConstraintV cv); /** * Adds a constraint to a given scope. * @param scope the scope to add the constraint to. Must not be null. * @param c the constraint to add. */ - void addConstraint(Scope2* scope, std::unique_ptr c); + void addConstraint(NotNull scope, std::unique_ptr c); /** * The entry point to the ConstraintGraphBuilder. This will construct a set @@ -81,20 +94,22 @@ struct ConstraintGraphBuilder */ void visit(AstStatBlock* block); - void visit(Scope2* scope, AstStat* stat); - void visit(Scope2* scope, AstStatBlock* block); - void visit(Scope2* scope, AstStatLocal* local); - void visit(Scope2* scope, AstStatLocalFunction* function); - void visit(Scope2* scope, AstStatFunction* function); - void visit(Scope2* scope, AstStatReturn* ret); - void visit(Scope2* scope, AstStatAssign* assign); - void visit(Scope2* scope, AstStatIf* ifStatement); - void visit(Scope2* scope, AstStatTypeAlias* alias); + void visitBlockWithoutChildScope(NotNull scope, AstStatBlock* block); - TypePackId checkExprList(Scope2* scope, const AstArray& exprs); + void visit(NotNull scope, AstStat* stat); + void visit(NotNull scope, AstStatBlock* block); + void visit(NotNull scope, AstStatLocal* local); + void visit(NotNull scope, AstStatLocalFunction* function); + void visit(NotNull scope, AstStatFunction* function); + void visit(NotNull scope, AstStatReturn* ret); + void visit(NotNull scope, AstStatAssign* assign); + void visit(NotNull scope, AstStatIf* ifStatement); + void visit(NotNull scope, AstStatTypeAlias* alias); - TypePackId checkPack(Scope2* scope, AstArray exprs); - TypePackId checkPack(Scope2* scope, AstExpr* expr); + TypePackId checkExprList(NotNull scope, const AstArray& exprs); + + TypePackId checkPack(NotNull scope, AstArray exprs); + TypePackId checkPack(NotNull scope, AstExpr* expr); /** * Checks an expression that is expected to evaluate to one type. @@ -102,19 +117,35 @@ struct ConstraintGraphBuilder * @param expr the expression to check. * @return the type of the expression. */ - TypeId check(Scope2* scope, AstExpr* expr); + TypeId check(NotNull scope, AstExpr* expr); - TypeId checkExprTable(Scope2* scope, AstExprTable* expr); - TypeId check(Scope2* scope, AstExprIndexName* indexName); + TypeId checkExprTable(NotNull scope, AstExprTable* expr); + TypeId check(NotNull scope, AstExprIndexName* indexName); + TypeId check(NotNull scope, AstExprIndexExpr* indexExpr); + TypeId check(NotNull scope, AstExprUnary* unary); + TypeId check(NotNull scope, AstExprBinary* binary); - std::pair checkFunctionSignature(Scope2* parent, AstExprFunction* fn); + struct FunctionSignature + { + // The type of the function. + TypeId signature; + // The scope that encompasses the function's signature. May be nullptr + // if there was no need for a signature scope (the function has no + // generics). + Scope2* signatureScope; + // The scope that encompasses the function's body. Is a child scope of + // signatureScope, if present. + NotNull bodyScope; + }; + + FunctionSignature checkFunctionSignature(NotNull parent, AstExprFunction* fn); /** * Checks the body of a function expression. * @param scope the interior scope of the body of the function. * @param fn the function expression to check. */ - void checkFunctionBody(Scope2* scope, AstExprFunction* fn); + void checkFunctionBody(NotNull scope, AstExprFunction* fn); /** * Resolves a type from its AST annotation. @@ -122,7 +153,7 @@ struct ConstraintGraphBuilder * @param ty the AST annotation to resolve. * @return the type of the AST annotation. **/ - TypeId resolveType(Scope2* scope, AstType* ty); + TypeId resolveType(NotNull scope, AstType* ty); /** * Resolves a type pack from its AST annotation. @@ -130,9 +161,25 @@ struct ConstraintGraphBuilder * @param tp the AST annotation to resolve. * @return the type pack of the AST annotation. **/ - TypePackId resolveTypePack(Scope2* scope, AstTypePack* tp); + TypePackId resolveTypePack(NotNull scope, AstTypePack* tp); - TypePackId resolveTypePack(Scope2* scope, const AstTypeList& list); + TypePackId resolveTypePack(NotNull scope, const AstTypeList& list); + + std::vector> createGenerics(NotNull scope, AstArray generics); + std::vector> createGenericPacks(NotNull scope, AstArray packs); + + TypeId flattenPack(NotNull scope, Location location, TypePackId tp); + + void reportError(Location location, TypeErrorData err); + void reportCodeTooComplex(Location location); + + /** Scan the program for global definitions. + * + * ConstraintGraphBuilder needs to differentiate between globals and accesses to undefined symbols. Doing this "for + * real" in a general way is going to be pretty hard, so we are choosing not to tackle that yet. For now, we do an + * initial scan of the AST and note what globals are defined. + */ + void prepopulateGlobalScope(NotNull globalScope, AstStatBlock* program); }; /** @@ -145,6 +192,6 @@ struct ConstraintGraphBuilder * @return a list of pointers to constraints contained within the scope graph. * None of these pointers should be null. */ -std::vector> collectConstraints(Scope2* rootScope); +std::vector> collectConstraints(NotNull rootScope); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 4870157..cf88efb 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -25,7 +25,7 @@ struct ConstraintSolver // is important to not add elements to this vector, lest the underlying // storage that we retain pointers to be mutated underneath us. const std::vector> constraints; - Scope2* rootScope; + NotNull rootScope; // This includes every constraint that has not been fully solved. // A constraint can be both blocked and unsolved, for instance. @@ -40,7 +40,7 @@ struct ConstraintSolver ConstraintSolverLogger logger; - explicit ConstraintSolver(TypeArena* arena, Scope2* rootScope); + explicit ConstraintSolver(TypeArena* arena, NotNull rootScope); /** * Attempts to dispatch all pending constraints and reach a type solution @@ -50,11 +50,17 @@ struct ConstraintSolver bool done(); + /** Attempt to dispatch a constraint. Returns true if it was successful. + * If tryDispatch() returns false, the constraint remains in the unsolved set and will be retried later. + */ bool tryDispatch(NotNull c, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); bool tryDispatch(const InstantiationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); void block(NotNull target, NotNull constraint); @@ -115,6 +121,6 @@ private: void unblock_(BlockedConstraintId progressed); }; -void dump(Scope2* rootScope, struct ToStringOptions& opts); +void dump(NotNull rootScope, struct ToStringOptions& opts); } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index a132396..4c81d33 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -369,24 +369,25 @@ struct InternalErrorReporter [[noreturn]] void ice(const std::string& message); }; -class InternalCompilerError : public std::exception { +class InternalCompilerError : public std::exception +{ public: - explicit InternalCompilerError(const std::string& message, const std::string& moduleName) - : message(message) - , moduleName(moduleName) - { - } - explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) - : message(message) - , moduleName(moduleName) - , location(location) - { - } - virtual const char* what() const throw(); + explicit InternalCompilerError(const std::string& message, const std::string& moduleName) + : message(message) + , moduleName(moduleName) + { + } + explicit InternalCompilerError(const std::string& message, const std::string& moduleName, const Location& location) + : message(message) + , moduleName(moduleName) + , location(location) + { + } + virtual const char* what() const throw(); - const std::string message; - const std::string moduleName; - const std::optional location; + const std::string message; + const std::string moduleName; + const std::optional location; }; } // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index f4226cc..f0d4309 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -5,6 +5,7 @@ #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" @@ -158,6 +159,8 @@ struct Frontend void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); + NotNull getGlobalScope2(); + private: ModulePtr check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope); @@ -173,6 +176,8 @@ private: std::unordered_map environments; std::unordered_map> builtinDefinitions; + std::unique_ptr globalScope2; + public: FileResolver* fileResolver; FrontendModuleResolver moduleResolver; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 39f8dfb..b3105b7 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -68,7 +68,7 @@ struct Module std::shared_ptr allocator; std::shared_ptr names; - std::vector> scopes; // never empty + std::vector> scopes; // never empty std::vector>> scope2s; // never empty DenseHashMap astTypes{nullptr}; diff --git a/Analysis/include/Luau/NotNull.h b/Analysis/include/Luau/NotNull.h index f6043e9..714fa14 100644 --- a/Analysis/include/Luau/NotNull.h +++ b/Analysis/include/Luau/NotNull.h @@ -26,7 +26,7 @@ namespace Luau * The explicit delete statement is permitted (but not recommended) on a * NotNull through this implicit conversion. */ -template +template struct NotNull { explicit NotNull(T* t) @@ -38,10 +38,11 @@ struct NotNull explicit NotNull(std::nullptr_t) = delete; void operator=(std::nullptr_t) = delete; - template + template NotNull(NotNull other) : ptr(other.get()) - {} + { + } operator T*() const noexcept { @@ -72,12 +73,13 @@ private: T* ptr; }; -} +} // namespace Luau namespace std { -template struct hash> +template +struct hash> { size_t operator()(const Luau::NotNull& p) const { @@ -85,4 +87,4 @@ template struct hash> } }; -} +} // namespace std diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index cef4b94..0eaecf1 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -3,6 +3,7 @@ #include "Luau/Constraint.h" #include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/TypeVar.h" #include @@ -71,15 +72,18 @@ struct Scope2 // is the module-level scope). Scope2* parent = nullptr; // All the children of this scope. - std::vector children; + std::vector> children; std::unordered_map bindings; // TODO: I think this can be a DenseHashMap std::unordered_map typeBindings; + std::unordered_map typePackBindings; TypePackId returnType; + std::optional varargPack; // All constraints belonging to this scope. std::vector constraints; std::optional lookup(Symbol sym); std::optional lookupTypeBinding(const Name& name); + std::optional lookupTypePackBinding(const Name& name); }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 559c55c..be36f19 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -34,6 +34,12 @@ struct TypeArena TypePackId addTypePack(std::vector types); TypePackId addTypePack(TypePack pack); TypePackId addTypePack(TypePackVar pack); + + template + TypePackId addTypePack(T tp) + { + return addTypePack(TypePackVar(std::move(tp))); + } }; void freeze(TypeArena& arena); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 28adc9d..455654d 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -173,7 +173,7 @@ struct TypeChecker TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, - std::optional originalNameLoc, std::optional expectedType); + std::optional originalNameLoc, std::optional selfType, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); void checkArgumentList( @@ -424,6 +424,8 @@ private: * (exported, name) to properly deal with the case where the two duplicates do not have the same export status. */ DenseHashSet, HashBoolNamePair> duplicateTypeAliases; + + std::vector> deferredQuantification; }; // Unit test hook diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 20f4107..6ad6b92 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -357,6 +357,9 @@ struct TableTypeVar std::optional boundTo; Tags tags; + + // Methods of this table that have an untyped self will use the same shared self type. + std::optional selfTy; }; // Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar. diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index d9e8d23..3b9000c 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -1,6 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ConstraintGraphBuilder.h" +#include "Luau/RecursionCounter.h" +#include "Luau/ToString.h" + +LUAU_FASTINT(LuauCheckRecursionLimit); #include "Luau/Scope.h" @@ -9,32 +13,33 @@ namespace Luau const AstStat* getFallthrough(const AstStat* node); // TypeInfer.cpp -ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena) - : singletonTypes(getSingletonTypes()) +ConstraintGraphBuilder::ConstraintGraphBuilder( + const ModuleName& moduleName, TypeArena* arena, NotNull ice, NotNull globalScope) + : moduleName(moduleName) + , singletonTypes(getSingletonTypes()) , arena(arena) , rootScope(nullptr) + , ice(ice) + , globalScope(globalScope) { LUAU_ASSERT(arena); } -TypeId ConstraintGraphBuilder::freshType(Scope2* scope) +TypeId ConstraintGraphBuilder::freshType(NotNull scope) { - LUAU_ASSERT(scope); return arena->addType(FreeTypeVar{scope}); } -TypePackId ConstraintGraphBuilder::freshTypePack(Scope2* scope) +TypePackId ConstraintGraphBuilder::freshTypePack(NotNull scope) { - LUAU_ASSERT(scope); FreeTypePack f{scope}; return arena->addTypePack(TypePackVar{std::move(f)}); } -Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) +NotNull ConstraintGraphBuilder::childScope(Location location, NotNull parent) { - LUAU_ASSERT(parent); auto scope = std::make_unique(); - Scope2* borrow = scope.get(); + NotNull borrow = NotNull(scope.get()); scopes.emplace_back(location, std::move(scope)); borrow->parent = parent; @@ -44,15 +49,13 @@ Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) return borrow; } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) +void ConstraintGraphBuilder::addConstraint(NotNull scope, ConstraintV cv) { - LUAU_ASSERT(scope); scope->constraints.emplace_back(new Constraint{std::move(cv)}); } -void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) +void ConstraintGraphBuilder::addConstraint(NotNull scope, std::unique_ptr c) { - LUAU_ASSERT(scope); scope->constraints.emplace_back(std::move(c)); } @@ -62,7 +65,11 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) LUAU_ASSERT(rootScope == nullptr); scopes.emplace_back(block->location, std::make_unique()); rootScope = scopes.back().second.get(); - rootScope->returnType = freshTypePack(rootScope); + NotNull borrow = NotNull(rootScope); + + rootScope->returnType = freshTypePack(borrow); + + prepopulateGlobalScope(borrow, block); // TODO: We should share the global scope. rootScope->typeBindings["nil"] = singletonTypes.nilType; @@ -71,12 +78,26 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block) rootScope->typeBindings["boolean"] = singletonTypes.booleanType; rootScope->typeBindings["thread"] = singletonTypes.threadType; - visit(rootScope, block); + visitBlockWithoutChildScope(borrow, block); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) +void ConstraintGraphBuilder::visitBlockWithoutChildScope(NotNull scope, AstStatBlock* block) { - LUAU_ASSERT(scope); + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(block->location); + return; + } + + for (AstStat* stat : block->body) + visit(scope, stat); +} + +void ConstraintGraphBuilder::visit(NotNull scope, AstStat* stat) +{ + RecursionLimiter limiter{&recursionCount, FInt::LuauCheckRecursionLimit}; if (auto s = stat->as()) visit(scope, s); @@ -100,10 +121,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) LUAU_ASSERT(0); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocal* local) { - LUAU_ASSERT(scope); - std::vector varTypes; for (AstLocal* local : local->vars) @@ -148,23 +167,19 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) } } -void addConstraints(Constraint* constraint, Scope2* scope) +void addConstraints(Constraint* constraint, NotNull scope) { - LUAU_ASSERT(scope); - scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); for (const auto& c : scope->constraints) constraint->dependencies.push_back(NotNull{c.get()}); - for (Scope2* childScope : scope->children) + for (NotNull childScope : scope->children) addConstraints(constraint, childScope); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatLocalFunction* function) { - LUAU_ASSERT(scope); - // Local // Global // Dotted path @@ -172,36 +187,31 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function TypeId functionType = nullptr; auto ty = scope->lookup(function->name); - if (ty.has_value()) - { - // TODO: This is duplicate definition of a local function. Is this allowed? - functionType = *ty; - } - else - { - functionType = arena->addType(BlockedTypeVar{}); - scope->bindings[function->name] = functionType; - } + LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. - auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); - innerScope->bindings[function->name] = actualFunctionType; + functionType = arena->addType(BlockedTypeVar{}); + scope->bindings[function->name] = functionType; - checkFunctionBody(innerScope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func); + sig.bodyScope->bindings[function->name] = sig.signature; - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; - addConstraints(c.get(), innerScope); + checkFunctionBody(sig.bodyScope, function->func); + + std::unique_ptr c{ + new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}}; + addConstraints(c.get(), sig.bodyScope); addConstraint(scope, std::move(c)); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatFunction* function) { // Name could be AstStatLocal, AstStatGlobal, AstStatIndexName. // With or without self TypeId functionType = nullptr; - auto [actualFunctionType, innerScope] = checkFunctionSignature(scope, function->func); + FunctionSignature sig = checkFunctionSignature(scope, function->func); if (AstExprLocal* localName = function->name->as()) { @@ -216,7 +226,7 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) functionType = arena->addType(BlockedTypeVar{}); scope->bindings[localName->local] = functionType; } - innerScope->bindings[localName->local] = actualFunctionType; + sig.bodyScope->bindings[localName->local] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { @@ -231,32 +241,48 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatFunction* function) functionType = arena->addType(BlockedTypeVar{}); rootScope->bindings[globalName->name] = functionType; } - innerScope->bindings[globalName->name] = actualFunctionType; + sig.bodyScope->bindings[globalName->name] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { - LUAU_ASSERT(0); // not yet implemented + TypeId containingTableType = check(scope, indexName->expr); + + functionType = arena->addType(BlockedTypeVar{}); + TypeId prospectiveTableType = + arena->addType(TableTypeVar{}); // TODO look into stack utilization. This is probably ok because it scales with AST depth. + NotNull prospectiveTable{getMutable(prospectiveTableType)}; + + Property& prop = prospectiveTable->props[indexName->index.value]; + prop.type = functionType; + prop.location = function->name->location; + + addConstraint(scope, SubtypeConstraint{containingTableType, prospectiveTableType}); + } + else if (AstExprError* err = function->name->as()) + { + functionType = singletonTypes.errorRecoveryType(); } - checkFunctionBody(innerScope, function->func); + LUAU_ASSERT(functionType != nullptr); - std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; - addConstraints(c.get(), innerScope); + checkFunctionBody(sig.bodyScope, function->func); + + std::unique_ptr c{ + new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope : sig.bodyScope}}}; + addConstraints(c.get(), sig.bodyScope); addConstraint(scope, std::move(c)); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatReturn* ret) { - LUAU_ASSERT(scope); - TypePackId exprTypes = checkPack(scope, ret->list); addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatBlock* block) { - LUAU_ASSERT(scope); + NotNull innerScope = childScope(block->location, scope); // In order to enable mutually-recursive type aliases, we need to // populate the type bindings before we actually check any of the @@ -271,11 +297,10 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) } } - for (AstStat* stat : block->body) - visit(scope, stat); + visitBlockWithoutChildScope(innerScope, block); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatAssign* assign) { TypePackId varPackId = checkExprList(scope, assign->vars); TypePackId valuePack = checkPack(scope, assign->values); @@ -283,21 +308,21 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatAssign* assign) addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatIf* ifStatement) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatIf* ifStatement) { check(scope, ifStatement->condition); - Scope2* thenScope = childScope(ifStatement->thenbody->location, scope); + NotNull thenScope = childScope(ifStatement->thenbody->location, scope); visit(thenScope, ifStatement->thenbody); if (ifStatement->elsebody) { - Scope2* elseScope = childScope(ifStatement->elsebody->location, scope); + NotNull elseScope = childScope(ifStatement->elsebody->location, scope); visit(elseScope, ifStatement->elsebody); } } -void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) +void ConstraintGraphBuilder::visit(NotNull scope, AstStatTypeAlias* alias) { // TODO: Exported type aliases // TODO: Generic type aliases @@ -307,6 +332,10 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) // AST to set up typeBindings. If it's not, we've somehow skipped // this alias in that first pass. LUAU_ASSERT(it != scope->typeBindings.end()); + if (it == scope->typeBindings.end()) + { + ice->ice("Type alias does not have a pre-populated binding", alias->location); + } TypeId ty = resolveType(scope, alias->type); @@ -319,10 +348,8 @@ void ConstraintGraphBuilder::visit(Scope2* scope, AstStatTypeAlias* alias) addConstraint(scope, NameConstraint{ty, alias->name.value}); } -TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) +TypePackId ConstraintGraphBuilder::checkPack(NotNull scope, AstArray exprs) { - LUAU_ASSERT(scope); - if (exprs.size == 0) return arena->addTypePack({}); @@ -342,7 +369,7 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray e return arena->addTypePack(TypePack{std::move(types), last}); } -TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray& exprs) +TypePackId ConstraintGraphBuilder::checkExprList(NotNull scope, const AstArray& exprs) { TypePackId result = arena->addTypePack({}); TypePack* resultPack = getMutable(result); @@ -363,9 +390,15 @@ TypePackId ConstraintGraphBuilder::checkExprList(Scope2* scope, const AstArray scope, AstExpr* expr) { - LUAU_ASSERT(scope); + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return singletonTypes.errorRecoveryTypePack(); + } TypePackId result = nullptr; @@ -384,7 +417,7 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) astOriginalCallTypes[call->func] = fnType; - TypeId instantiatedType = freshType(scope); + TypeId instantiatedType = arena->addType(BlockedTypeVar{}); addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); TypePackId rets = freshTypePack(scope); @@ -394,6 +427,13 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); result = rets; } + else if (AstExprVarargs* varargs = expr->as()) + { + if (scope->varargPack) + result = *scope->varargPack; + else + result = singletonTypes.errorRecoveryTypePack(); + } else { TypeId t = check(scope, expr); @@ -405,9 +445,15 @@ TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) return result; } -TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExpr* expr) { - LUAU_ASSERT(scope); + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(expr->location); + return singletonTypes.errorRecoveryType(); + } TypeId result = nullptr; @@ -435,37 +481,38 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) if (ty) result = *ty; else - result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? - } - else if (auto a = expr->as()) - { - TypePackId packResult = checkPack(scope, expr); - if (auto f = first(packResult)) - return *f; - else if (get(packResult)) { - TypeId typeResult = freshType(scope); - TypePack onePack{{typeResult}, freshTypePack(scope)}; - TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); - - addConstraint(scope, PackSubtypeConstraint{packResult, oneTypePack}); - - return typeResult; + /* prepopulateGlobalScope() has already added all global functions to the environment by this point, so any + * global that is not already in-scope is definitely an unknown symbol. + */ + reportError(g->location, UnknownSymbol{g->name.value}); + result = singletonTypes.errorRecoveryType(); // FIXME? Record an error at this point? } } + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); + else if (expr->is()) + result = flattenPack(scope, expr->location, checkPack(scope, expr)); else if (auto a = expr->as()) { - auto [fnType, functionScope] = checkFunctionSignature(scope, a); - checkFunctionBody(functionScope, a); - return fnType; + FunctionSignature sig = checkFunctionSignature(scope, a); + checkFunctionBody(sig.bodyScope, a); + return sig.signature; } else if (auto indexName = expr->as()) - { result = check(scope, indexName); - } + else if (auto indexExpr = expr->as()) + result = check(scope, indexExpr); else if (auto table = expr->as()) - { result = checkExprTable(scope, table); + else if (auto unary = expr->as()) + result = check(scope, unary); + else if (auto binary = expr->as()) + result = check(scope, binary); + else if (auto err = expr->as()) + { + // Open question: Should we traverse into this? + result = singletonTypes.errorRecoveryType(); } else { @@ -478,7 +525,7 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) return result; } -TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexName* indexName) { TypeId obj = check(scope, indexName->expr); TypeId result = freshType(scope); @@ -494,7 +541,67 @@ TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExprIndexName* indexName) return result; } -TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprIndexExpr* indexExpr) +{ + TypeId obj = check(scope, indexExpr->expr); + TypeId indexType = check(scope, indexExpr->index); + + TypeId result = freshType(scope); + + TableIndexer indexer{indexType, result}; + TypeId tableType = arena->addType(TableTypeVar{TableTypeVar::Props{}, TableIndexer{indexType, result}, TypeLevel{}, TableState::Free}); + + addConstraint(scope, SubtypeConstraint{obj, tableType}); + + return result; +} + +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprUnary* unary) +{ + TypeId operandType = check(scope, unary->expr); + + switch (unary->op) + { + case AstExprUnary::Minus: + { + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, UnaryConstraint{AstExprUnary::Minus, operandType, resultType}); + return resultType; + } + default: + LUAU_ASSERT(0); + } + + LUAU_UNREACHABLE(); + return singletonTypes.errorRecoveryType(); +} + +TypeId ConstraintGraphBuilder::check(NotNull scope, AstExprBinary* binary) +{ + TypeId leftType = check(scope, binary->left); + TypeId rightType = check(scope, binary->right); + switch (binary->op) + { + case AstExprBinary::Or: + { + addConstraint(scope, SubtypeConstraint{leftType, rightType}); + return leftType; + } + case AstExprBinary::Sub: + { + TypeId resultType = arena->addType(BlockedTypeVar{}); + addConstraint(scope, BinaryConstraint{AstExprBinary::Sub, leftType, rightType, resultType}); + return resultType; + } + default: + LUAU_ASSERT(0); + } + + LUAU_ASSERT(0); + return nullptr; +} + +TypeId ConstraintGraphBuilder::checkExprTable(NotNull scope, AstExprTable* expr) { TypeId ty = arena->addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(ty); @@ -515,6 +622,8 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) for (const AstExprTable::Item& item : expr->items) { TypeId itemTy = check(scope, item.value); + if (get(follow(itemTy))) + return ty; if (item.key) { @@ -542,47 +651,111 @@ TypeId ConstraintGraphBuilder::checkExprTable(Scope2* scope, AstExprTable* expr) return ty; } -std::pair ConstraintGraphBuilder::checkFunctionSignature(Scope2* parent, AstExprFunction* fn) +ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionSignature(NotNull parent, AstExprFunction* fn) { - Scope2* innerScope = childScope(fn->body->location, parent); - TypePackId returnType = freshTypePack(innerScope); - innerScope->returnType = returnType; + Scope2* signatureScope = nullptr; + Scope2* bodyScope = nullptr; + TypePackId returnType = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + + // If we don't have any generics, we can save some memory and compute by not + // creating the signatureScope, which is only used to scope the declared + // generics properly. + if (hasGenerics) + { + NotNull signatureBorrow = childScope(fn->location, parent); + signatureScope = signatureBorrow.get(); + + // We need to assign returnType before creating bodyScope so that the + // return type gets propogated to bodyScope. + returnType = freshTypePack(signatureBorrow); + signatureScope->returnType = returnType; + + bodyScope = childScope(fn->body->location, signatureBorrow).get(); + + std::vector> genericDefinitions = createGenerics(signatureBorrow, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks); + + // We do not support default values on function generics, so we only + // care about the types involved. + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + signatureScope->typeBindings[name] = g.ty; + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + signatureScope->typePackBindings[name] = g.tp; + } + } + else + { + NotNull bodyBorrow = childScope(fn->body->location, parent); + bodyScope = bodyBorrow.get(); + + returnType = freshTypePack(bodyBorrow); + bodyBorrow->returnType = returnType; + + // To eliminate the need to branch on hasGenerics below, we say that the + // signature scope is the body scope when there is no real signature + // scope. + signatureScope = bodyScope; + } + + NotNull bodyBorrow = NotNull(bodyScope); + NotNull signatureBorrow = NotNull(signatureScope); if (fn->returnAnnotation) { - TypePackId annotatedRetType = resolveTypePack(innerScope, *fn->returnAnnotation); - addConstraint(innerScope, PackSubtypeConstraint{returnType, annotatedRetType}); + TypePackId annotatedRetType = resolveTypePack(signatureBorrow, *fn->returnAnnotation); + addConstraint(signatureBorrow, PackSubtypeConstraint{returnType, annotatedRetType}); } std::vector argTypes; for (AstLocal* local : fn->args) { - TypeId t = freshType(innerScope); + TypeId t = freshType(signatureBorrow); argTypes.push_back(t); - innerScope->bindings[local] = t; + signatureScope->bindings[local] = t; if (local->annotation) { - TypeId argAnnotation = resolveType(innerScope, local->annotation); - addConstraint(innerScope, SubtypeConstraint{t, argAnnotation}); + TypeId argAnnotation = resolveType(signatureBorrow, local->annotation); + addConstraint(signatureBorrow, SubtypeConstraint{t, argAnnotation}); } } // TODO: Vararg annotation. + // TODO: Preserve argument names in the function's type. FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; + actualFunction.hasNoGenerics = !hasGenerics; + actualFunction.generics = std::move(genericTypes); + actualFunction.genericPacks = std::move(genericTypePacks); + TypeId actualFunctionType = arena->addType(std::move(actualFunction)); LUAU_ASSERT(actualFunctionType); astTypes[fn] = actualFunctionType; - return {actualFunctionType, innerScope}; + return { + /* signature */ actualFunctionType, + // Undo the workaround we made above: if there's no signature scope, + // don't report it. + /* signatureScope */ hasGenerics ? signatureScope : nullptr, + /* bodyScope */ bodyBorrow, + }; } -void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* fn) +void ConstraintGraphBuilder::checkFunctionBody(NotNull scope, AstExprFunction* fn) { - for (AstStat* stat : fn->body->body) - visit(scope, stat); + visitBlockWithoutChildScope(scope, fn->body); // If it is possible for execution to reach the end of the function, the return type must be compatible with () @@ -593,7 +766,7 @@ void ConstraintGraphBuilder::checkFunctionBody(Scope2* scope, AstExprFunction* f } } -TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) +TypeId ConstraintGraphBuilder::resolveType(NotNull scope, AstType* ty) { TypeId result = nullptr; @@ -636,29 +809,73 @@ TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) } else if (auto fn = ty->as()) { - // TODO: Generic functions. - // TODO: Scope (though it may not be needed). // TODO: Recursion limit. - TypePackId argTypes = resolveTypePack(scope, fn->argTypes); - TypePackId returnTypes = resolveTypePack(scope, fn->returnTypes); + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + Scope2* signatureScope = nullptr; - // TODO: Is this the right constructor to use? - result = arena->addType(FunctionTypeVar{argTypes, returnTypes}); + std::vector genericTypes; + std::vector genericTypePacks; - FunctionTypeVar* ftv = getMutable(result); - ftv->argNames.reserve(fn->argNames.size); + // If we don't have generics, we do not need to generate a child scope + // for the generic bindings to live on. + if (hasGenerics) + { + NotNull signatureBorrow = childScope(fn->location, scope); + signatureScope = signatureBorrow.get(); + + std::vector> genericDefinitions = createGenerics(signatureBorrow, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureBorrow, fn->genericPacks); + + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + signatureBorrow->typeBindings[name] = g.ty; + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + signatureBorrow->typePackBindings[name] = g.tp; + } + } + else + { + // To eliminate the need to branch on hasGenerics below, we say that + // the signature scope is the parent scope if we don't have + // generics. + signatureScope = scope.get(); + } + + NotNull signatureBorrow(signatureScope); + + TypePackId argTypes = resolveTypePack(signatureBorrow, fn->argTypes); + TypePackId returnTypes = resolveTypePack(signatureBorrow, fn->returnTypes); + + // TODO: FunctionTypeVar needs a pointer to the scope so that we know + // how to quantify/instantiate it. + FunctionTypeVar ftv{argTypes, returnTypes}; + + // This replicates the behavior of the appropriate FunctionTypeVar + // constructors. + ftv.hasNoGenerics = !hasGenerics; + ftv.generics = std::move(genericTypes); + ftv.genericPacks = std::move(genericTypePacks); + + ftv.argNames.reserve(fn->argNames.size); for (const auto& el : fn->argNames) { if (el) { const auto& [name, location] = *el; - ftv->argNames.push_back(FunctionArgument{name.value, location}); + ftv.argNames.push_back(FunctionArgument{name.value, location}); } else { - ftv->argNames.push_back(std::nullopt); + ftv.argNames.push_back(std::nullopt); } } + + result = arena->addType(std::move(ftv)); } else if (auto tof = ty->as()) { @@ -710,7 +927,7 @@ TypeId ConstraintGraphBuilder::resolveType(Scope2* scope, AstType* ty) return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* tp) +TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, AstTypePack* tp) { TypePackId result; if (auto expl = tp->as()) @@ -736,7 +953,7 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, AstTypePack* t return result; } -TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeList& list) +TypePackId ConstraintGraphBuilder::resolveTypePack(NotNull scope, const AstTypeList& list) { std::vector head; @@ -754,16 +971,108 @@ TypePackId ConstraintGraphBuilder::resolveTypePack(Scope2* scope, const AstTypeL return arena->addTypePack(TypePack{head, tail}); } -void collectConstraints(std::vector>& result, Scope2* scope) +std::vector> ConstraintGraphBuilder::createGenerics(NotNull scope, AstArray generics) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypeId genericTy = arena->addType(GenericTypeVar{scope, generic.name.value}); + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveType(scope, generic.defaultValue); + + result.push_back({generic.name.value, GenericTypeDefinition{ + genericTy, + defaultTy, + }}); + } + + return result; +} + +std::vector> ConstraintGraphBuilder::createGenericPacks( + NotNull scope, AstArray generics) +{ + std::vector> result; + for (const auto& generic : generics) + { + TypePackId genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope, generic.name.value}}); + std::optional defaultTy = std::nullopt; + + if (generic.defaultValue) + defaultTy = resolveTypePack(scope, generic.defaultValue); + + result.push_back({generic.name.value, GenericTypePackDefinition{ + genericTy, + defaultTy, + }}); + } + + return result; +} + +TypeId ConstraintGraphBuilder::flattenPack(NotNull scope, Location location, TypePackId tp) +{ + if (auto f = first(tp)) + return *f; + + TypeId typeResult = freshType(scope); + TypePack onePack{{typeResult}, freshTypePack(scope)}; + TypePackId oneTypePack = arena->addTypePack(std::move(onePack)); + + addConstraint(scope, PackSubtypeConstraint{tp, oneTypePack}); + + return typeResult; +} + +void ConstraintGraphBuilder::reportError(Location location, TypeErrorData err) +{ + errors.push_back(TypeError{location, moduleName, std::move(err)}); +} + +void ConstraintGraphBuilder::reportCodeTooComplex(Location location) +{ + errors.push_back(TypeError{location, moduleName, CodeTooComplex{}}); +} + +struct GlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull arena; + + GlobalPrepopulator(NotNull globalScope, NotNull arena) + : globalScope(globalScope) + , arena(arena) + { + } + + bool visit(AstStatFunction* function) override + { + if (AstExprGlobal* g = function->name->as()) + globalScope->bindings[g->name] = arena->addType(BlockedTypeVar{}); + + return true; + } +}; + +void ConstraintGraphBuilder::prepopulateGlobalScope(NotNull globalScope, AstStatBlock* program) +{ + GlobalPrepopulator gp{NotNull{globalScope}, arena}; + + program->visit(&gp); +} + +void collectConstraints(std::vector>& result, NotNull scope) { for (const auto& c : scope->constraints) result.push_back(NotNull{c.get()}); - for (Scope2* child : scope->children) + for (NotNull child : scope->children) collectConstraints(result, child); } -std::vector> collectConstraints(Scope2* rootScope) +std::vector> collectConstraints(NotNull rootScope) { std::vector> result; collectConstraints(result, rootScope); diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 9e35523..077a4e2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -13,7 +13,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau { -[[maybe_unused]] static void dumpBindings(Scope2* scope, ToStringOptions& opts) +[[maybe_unused]] static void dumpBindings(NotNull scope, ToStringOptions& opts) { for (const auto& [k, v] : scope->bindings) { @@ -22,22 +22,22 @@ namespace Luau printf("\t%s : %s\n", k.c_str(), d.name.c_str()); } - for (Scope2* child : scope->children) + for (NotNull child : scope->children) dumpBindings(child, opts); } -static void dumpConstraints(Scope2* scope, ToStringOptions& opts) +static void dumpConstraints(NotNull scope, ToStringOptions& opts) { for (const ConstraintPtr& c : scope->constraints) { printf("\t%s\n", toString(*c, opts).c_str()); } - for (Scope2* child : scope->children) + for (NotNull child : scope->children) dumpConstraints(child, opts); } -void dump(Scope2* rootScope, ToStringOptions& opts) +void dump(NotNull rootScope, ToStringOptions& opts) { printf("constraints:\n"); dumpConstraints(rootScope, opts); @@ -55,7 +55,7 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) } } -ConstraintSolver::ConstraintSolver(TypeArena* arena, Scope2* rootScope) +ConstraintSolver::ConstraintSolver(TypeArena* arena, NotNull rootScope) : arena(arena) , constraints(collectConstraints(rootScope)) , rootScope(rootScope) @@ -180,6 +180,10 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*gc, constraint, force); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); + else if (auto uc = get(*constraint)) + success = tryDispatch(*uc, constraint, force); + else if (auto bc = get(*constraint)) + success = tryDispatch(*bc, constraint, force); else if (auto nc = get(*constraint)) success = tryDispatch(*nc, constraint); else @@ -246,12 +250,65 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull instantiated = inst.substitute(c.superType); LUAU_ASSERT(instantiated); // TODO FIXME HANDLE THIS - unify(c.subType, *instantiated); + if (isBlocked(c.subType)) + asMutable(c.subType)->ty.emplace(*instantiated); + else + unify(c.subType, *instantiated); + unblock(c.subType); return true; } +bool ConstraintSolver::tryDispatch(const UnaryConstraint& c, NotNull constraint, bool force) +{ + TypeId operandType = follow(c.operandType); + + if (isBlocked(operandType)) + return block(operandType, constraint); + + if (get(operandType)) + return block(operandType, constraint); + + LUAU_ASSERT(get(c.resultType)); + + if (isNumber(operandType) || get(operandType) || get(operandType)) + { + asMutable(c.resultType)->ty.emplace(c.operandType); + return true; + } + + LUAU_ASSERT(0); // TODO metatable handling + return false; +} + +bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull constraint, bool force) +{ + TypeId leftType = follow(c.leftType); + TypeId rightType = follow(c.rightType); + + if (isBlocked(leftType) || isBlocked(rightType)) + { + block(leftType, constraint); + block(rightType, constraint); + return false; + } + + if (isNumber(leftType)) + { + unify(leftType, rightType); + asMutable(c.resultType)->ty.emplace(leftType); + return true; + } + + if (get(leftType) && !force) + return block(leftType, constraint); + + // TODO metatables, classes + + return true; +} + bool ConstraintSolver::tryDispatch(const NameConstraint& c, NotNull constraint) { if (isBlocked(c.namedType)) diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 2407e3e..1b5275f 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(LuauCheckLenMT) + namespace Luau { @@ -202,7 +204,13 @@ declare function unpack(tab: {V}, i: number?, j: number?): ...V std::string getBuiltinDefinitionSource() { - return kBuiltinDefinitionLuaSrc; + std::string result = kBuiltinDefinitionLuaSrc; + + // TODO: move this into kBuiltinDefinitionLuaSrc + if (FFlag::LuauCheckLenMT) + result += "declare function rawlen(obj: {[K]: V} | string): number\n"; + + return result; } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 85c5dbc..4cfaa11 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -787,14 +787,32 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } +NotNull Frontend::getGlobalScope2() +{ + if (!globalScope2) + { + const SingletonTypes& singletonTypes = getSingletonTypes(); + + globalScope2 = std::make_unique(); + globalScope2->typeBindings["nil"] = singletonTypes.nilType; + globalScope2->typeBindings["number"] = singletonTypes.numberType; + globalScope2->typeBindings["string"] = singletonTypes.stringType; + globalScope2->typeBindings["boolean"] = singletonTypes.booleanType; + globalScope2->typeBindings["thread"] = singletonTypes.threadType; + } + + return NotNull(globalScope2.get()); +} + ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const ScopePtr& environmentScope) { ModulePtr result = std::make_shared(); - ConstraintGraphBuilder cgb{&result->internalTypes}; + ConstraintGraphBuilder cgb{sourceModule.name, &result->internalTypes, NotNull(&iceHandler), getGlobalScope2()}; cgb.visit(sourceModule.root); + result->errors = std::move(cgb.errors); - ConstraintSolver cs{&result->internalTypes, cgb.rootScope}; + ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)}; cs.run(); result->scope2s = std::move(cgb.scopes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index d36665e..8ce7f74 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -5,7 +5,6 @@ #include #include "Luau/Clone.h" -#include "Luau/Substitution.h" #include "Luau/Unifier.h" #include "Luau/VisitTypeVar.h" @@ -16,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineEqFix, false); -LUAU_FASTFLAGVARIABLE(LuauReplaceReplacer, false); LUAU_FASTFLAG(LuauQuantifyConstrained) namespace Luau @@ -25,238 +23,33 @@ namespace Luau namespace { -struct Replacer : Substitution +struct Replacer { + TypeArena* arena; TypeId sourceType; TypeId replacedType; - DenseHashMap replacedTypes{nullptr}; - DenseHashMap replacedPacks{nullptr}; + DenseHashMap newTypes; Replacer(TypeArena* arena, TypeId sourceType, TypeId replacedType) - : Substitution(TxnLog::empty(), arena) + : arena(arena) , sourceType(sourceType) , replacedType(replacedType) + , newTypes(nullptr) { } - bool isDirty(TypeId ty) override - { - if (!sourceType) - return false; - - auto vecHasSourceType = [sourceType = sourceType](const auto& vec) { - return end(vec) != std::find(begin(vec), end(vec), sourceType); - }; - - // Walk every kind of TypeVar and find pointers to sourceType - if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return vecHasSourceType(t->parts); - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - { - if (vecHasSourceType(t->generics)) - return true; - - return false; - } - else if (auto t = get(ty)) - { - if (t->boundTo) - return *t->boundTo == sourceType; - - for (const auto& [_name, prop] : t->props) - { - if (prop.type == sourceType) - return true; - } - - if (auto indexer = t->indexer) - { - if (indexer->indexType == sourceType || indexer->indexResultType == sourceType) - return true; - } - - if (vecHasSourceType(t->instantiatedTypeParams)) - return true; - - return false; - } - else if (auto t = get(ty)) - return t->table == sourceType || t->metatable == sourceType; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return false; - else if (auto t = get(ty)) - return vecHasSourceType(t->options); - else if (auto t = get(ty)) - return vecHasSourceType(t->parts); - else if (auto t = get(ty)) - return false; - - LUAU_ASSERT(!"Luau::Replacer::isDirty internal error: Unknown TypeVar type"); - LUAU_UNREACHABLE(); - } - - bool isDirty(TypePackId tp) override - { - if (auto it = replacedPacks.find(tp)) - return false; - - if (auto pack = get(tp)) - { - for (TypeId ty : pack->head) - { - if (ty == sourceType) - return true; - } - return false; - } - else if (auto vtp = get(tp)) - return vtp->ty == sourceType; - else - return false; - } - - TypeId clean(TypeId ty) override - { - LUAU_ASSERT(sourceType && replacedType); - - // Walk every kind of TypeVar and create a copy with sourceType replaced by replacedType - // Before returning, memoize the result for later use. - - // Helpfully, Substitution::clone() only shallow-clones the kinds of types that we care to work with. This - // function returns the identity for things like primitives. - TypeId res = clone(ty); - - if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - for (TypeId& part : t->parts) - { - if (part == sourceType) - part = replacedType; - } - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - // The constituent typepacks are cleaned separately. We just need to walk the generics array. - for (TypeId& g : t->generics) - { - if (g == sourceType) - g = replacedType; - } - } - else if (auto t = getMutable(res)) - { - for (auto& [_key, prop] : t->props) - { - if (prop.type == sourceType) - prop.type = replacedType; - } - } - else if (auto t = getMutable(res)) - { - if (t->table == sourceType) - t->table = replacedType; - if (t->metatable == sourceType) - t->table = replacedType; - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else if (auto t = getMutable(res)) - { - for (TypeId& option : t->options) - { - if (option == sourceType) - option = replacedType; - } - } - else if (auto t = getMutable(res)) - { - for (TypeId& part : t->parts) - { - if (part == sourceType) - part = replacedType; - } - } - else if (auto t = get(res)) - LUAU_ASSERT(!"Impossible"); - else - LUAU_ASSERT(!"Luau::Replacer::clean internal error: Unknown TypeVar type"); - - replacedTypes[ty] = res; - return res; - } - - TypePackId clean(TypePackId tp) override - { - TypePackId res = clone(tp); - - if (auto pack = getMutable(res)) - { - for (TypeId& type : pack->head) - { - if (type == sourceType) - type = replacedType; - } - } - else if (auto vtp = getMutable(res)) - { - if (vtp->ty == sourceType) - vtp->ty = replacedType; - } - - replacedPacks[tp] = res; - return res; - } - TypeId smartClone(TypeId t) { - if (FFlag::LuauReplaceReplacer) - { - // The new smartClone is just a memoized clone() - // TODO: Remove the Substitution base class and all other methods from this struct. - // Add DenseHashMap newTypes; - t = log->follow(t); - TypeId* res = newTypes.find(t); - if (res) - return *res; - - TypeId result = shallowClone(t, *arena, TxnLog::empty()); - newTypes[t] = result; - newTypes[result] = result; - - return result; - } - else - { - std::optional res = replace(t); - LUAU_ASSERT(res.has_value()); // TODO think about this - if (*res == t) - return clone(t); + t = follow(t); + TypeId* res = newTypes.find(t); + if (res) return *res; - } + + TypeId result = shallowClone(t, *arena, TxnLog::empty()); + newTypes[t] = result; + newTypes[result] = result; + + return result; } }; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 40e14c6..294c479 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -8,6 +8,7 @@ #include "Luau/VisitTypeVar.h" LUAU_FASTFLAG(LuauAlwaysQuantify); +LUAU_FASTFLAG(DebugLuauSharedSelf) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false) @@ -158,24 +159,61 @@ struct Quantifier final : TypeVarOnceVisitor void quantify(TypeId ty, TypeLevel level) { - Quantifier q{level}; - q.traverse(ty); - - FunctionTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - if (FFlag::LuauAlwaysQuantify) + if (FFlag::DebugLuauSharedSelf) { - ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); - ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + ty = follow(ty); + + if (auto ttv = getTableType(ty); ttv && ttv->selfTy) + { + Quantifier selfQ{level}; + selfQ.traverse(*ttv->selfTy); + + Quantifier q{level}; + q.traverse(ty); + + for (const auto& [_, prop] : ttv->props) + { + auto ftv = getMutable(follow(prop.type)); + if (!ftv || !ftv->hasSelf) + continue; + + if (Luau::first(ftv->argTypes) == ttv->selfTy) + { + ftv->generics.insert(ftv->generics.end(), selfQ.generics.begin(), selfQ.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), selfQ.genericPacks.begin(), selfQ.genericPacks.end()); + } + } + } + else if (auto ftv = getMutable(ty)) + { + Quantifier q{level}; + q.traverse(ty); + + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + + if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) + ftv->hasNoGenerics = true; + } } else { - ftv->generics = q.generics; - ftv->genericPacks = q.genericPacks; - } + Quantifier q{level}; + q.traverse(ty); - if (ftv->generics.empty() && ftv->genericPacks.empty() && !q.seenMutableType && !q.seenGenericType) - ftv->hasNoGenerics = true; + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + if (FFlag::LuauAlwaysQuantify) + { + ftv->generics.insert(ftv->generics.end(), q.generics.begin(), q.generics.end()); + ftv->genericPacks.insert(ftv->genericPacks.end(), q.genericPacks.begin(), q.genericPacks.end()); + } + else + { + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; + } + } } void quantify(TypeId ty, Scope2* scope) @@ -206,8 +244,8 @@ struct PureQuantifier : Substitution std::vector insertedGenerics; std::vector insertedGenericPacks; - PureQuantifier(const TxnLog* log, TypeArena* arena, Scope2* scope) - : Substitution(log, arena) + PureQuantifier(TypeArena* arena, Scope2* scope) + : Substitution(TxnLog::empty(), arena) , scope(scope) { } @@ -286,7 +324,7 @@ struct PureQuantifier : Substitution TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) { - PureQuantifier quantifier{TxnLog::empty(), arena, scope}; + PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); LUAU_ASSERT(result); @@ -294,8 +332,7 @@ TypeId quantify(TypeArena* arena, TypeId ty, Scope2* scope) LUAU_ASSERT(ftv); ftv->generics.insert(ftv->generics.end(), quantifier.insertedGenerics.begin(), quantifier.insertedGenerics.end()); ftv->genericPacks.insert(ftv->genericPacks.end(), quantifier.insertedGenericPacks.begin(), quantifier.insertedGenericPacks.end()); - - // TODO: Set hasNoGenerics. + ftv->hasNoGenerics = ftv->generics.empty() && ftv->genericPacks.empty(); return *result; } diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 66aaee1..247a9dd 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -153,4 +153,19 @@ std::optional Scope2::lookupTypeBinding(const Name& name) return std::nullopt; } +std::optional Scope2::lookupTypePackBinding(const Name& name) +{ + Scope2* s = this; + while (s) + { + auto it = s->typePackBindings.find(name); + if (it != s->typePackBindings.end()) + return it->second; + + s = s->parent; + } + + return std::nullopt; +} + } // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 7a45896..bbbad9d 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1377,51 +1377,74 @@ std::string generateName(size_t i) return n; } -std::string toString(const Constraint& c, ToStringOptions& opts) +std::string toString(const Constraint& constraint, ToStringOptions& opts) { - if (const SubtypeConstraint* sc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(sc->subType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(sc->superType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " <: " + superStr.name; - } - else if (const PackSubtypeConstraint* psc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(psc->subPack, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(psc->superPack, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " <: " + superStr.name; - } - else if (const GeneralizationConstraint* gc = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(gc->generalizedType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(gc->sourceType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " ~ gen " + superStr.name; - } - else if (const InstantiationConstraint* ic = Luau::get_if(&c.c)) - { - ToStringResult subStr = toStringDetailed(ic->subType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(ic->superType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " ~ inst " + superStr.name; - } - else if (const NameConstraint* nc = Luau::get(c)) - { - ToStringResult namedStr = toStringDetailed(nc->namedType, opts); - opts.nameMap = std::move(namedStr.nameMap); - return "@name(" + namedStr.name + ") = " + nc->name; - } - else - { - LUAU_ASSERT(false); - return ""; - } + auto go = [&opts](auto&& c) { + using T = std::decay_t; + + if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subPack, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superPack, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " <: " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.generalizedType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.sourceType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ gen " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult subStr = toStringDetailed(c.subType, opts); + opts.nameMap = std::move(subStr.nameMap); + ToStringResult superStr = toStringDetailed(c.superType, opts); + opts.nameMap = std::move(superStr.nameMap); + return subStr.name + " ~ inst " + superStr.name; + } + else if constexpr (std::is_same_v) + { + ToStringResult resultStr = toStringDetailed(c.resultType, opts); + opts.nameMap = std::move(resultStr.nameMap); + ToStringResult operandStr = toStringDetailed(c.operandType, opts); + opts.nameMap = std::move(operandStr.nameMap); + + return resultStr.name + " ~ Unary<" + toString(c.op) + ", " + operandStr.name + ">"; + } + else if constexpr (std::is_same_v) + { + ToStringResult resultStr = toStringDetailed(c.resultType); + opts.nameMap = std::move(resultStr.nameMap); + ToStringResult leftStr = toStringDetailed(c.leftType); + opts.nameMap = std::move(leftStr.nameMap); + ToStringResult rightStr = toStringDetailed(c.rightType); + opts.nameMap = std::move(rightStr.nameMap); + + return resultStr.name + " ~ Binary<" + toString(c.op) + ", " + leftStr.name + ", " + rightStr.name + ">"; + } + else if constexpr (std::is_same_v) + { + ToStringResult namedStr = toStringDetailed(c.namedType, opts); + opts.nameMap = std::move(namedStr.nameMap); + return "@name(" + namedStr.name + ") = " + c.name; + } + else + static_assert(always_false_v, "Non-exhaustive constraint switch"); + }; + + return visit(go, constraint.c); } std::string dump(const Constraint& c) diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 63e5800..30e498a 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -6,8 +6,10 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Clone.h" +#include "Luau/Instantiation.h" #include "Luau/Normalize.h" -#include "Luau/ConstraintGraphBuilder.h" // FIXME move Scope2 into its own header +#include "Luau/TxnLog.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier.h" #include "Luau/ToString.h" @@ -19,10 +21,12 @@ struct TypeChecker2 : public AstVisitor const SourceModule* sourceModule; Module* module; InternalErrorReporter ice; // FIXME accept a pointer from Frontend + SingletonTypes& singletonTypes; TypeChecker2(const SourceModule* sourceModule, Module* module) : sourceModule(sourceModule) , module(module) + , singletonTypes(getSingletonTypes()) { } @@ -30,16 +34,30 @@ struct TypeChecker2 : public AstVisitor TypePackId lookupPack(AstExpr* expr) { + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. TypePackId* tp = module->astTypePacks.find(expr); - LUAU_ASSERT(tp); - return follow(*tp); + if (tp) + return follow(*tp); + else + return singletonTypes.anyTypePack; } TypeId lookupType(AstExpr* expr) { + // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. + // We'll just return anyType in these cases. Typechecking against any is very fast and this + // allows us not to think about this very much in the actual typechecking logic. TypeId* ty = module->astTypes.find(expr); - LUAU_ASSERT(ty); - return follow(*ty); + if (ty) + return follow(*ty); + + TypePackId* tp = module->astTypePacks.find(expr); + if (tp) + return flattenPack(*tp); + + return singletonTypes.anyType; } TypeId lookupAnnotation(AstType* annotation) @@ -78,7 +96,7 @@ struct TypeChecker2 : public AstVisitor bestLocation = scopeBounds; } } - else + else if (scopeBounds.begin > location.end) { // TODO: Is this sound? This relies on the fact that scopes are inserted // into the scope list in the order that they appear in the AST. @@ -147,16 +165,14 @@ struct TypeChecker2 : public AstVisitor for (size_t i = 0; i < count; ++i) { AstExpr* lhs = assign->vars.data[i]; - TypeId* lhsType = module->astTypes.find(lhs); - LUAU_ASSERT(lhsType); + TypeId lhsType = lookupType(lhs); AstExpr* rhs = assign->values.data[i]; - TypeId* rhsType = module->astTypes.find(rhs); - LUAU_ASSERT(rhsType); + TypeId rhsType = lookupType(rhs); - if (!isSubtype(*rhsType, *lhsType, ice)) + if (!isSubtype(rhsType, lhsType, ice)) { - reportError(TypeMismatch{*lhsType, *rhsType}, rhs->location); + reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } } @@ -181,7 +197,7 @@ struct TypeChecker2 : public AstVisitor if (!ok) { for (const TypeError& e : u.errors) - module->errors.push_back(e); + reportError(e); } return true; @@ -189,10 +205,14 @@ struct TypeChecker2 : public AstVisitor bool visit(AstExprCall* call) override { + TypeArena arena; + Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}}; + TypePackId expectedRetType = lookupPack(call); TypeId functionType = lookupType(call->func); + TypeId instantiatedFunctionType = instantiation.substitute(functionType).value_or(nullptr); + LUAU_ASSERT(functionType); - TypeArena arena; TypePack args; for (const auto& arg : call->args) { @@ -204,7 +224,7 @@ struct TypeChecker2 : public AstVisitor TypePackId argsTp = arena.addTypePack(args); FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(expectedType, functionType, ice)) + if (!isSubtype(expectedType, instantiatedFunctionType, ice)) { unfreeze(module->interfaceTypes); CloneState cloneState; @@ -252,16 +272,12 @@ struct TypeChecker2 : public AstVisitor // leftType must have a property called indexName->index - if (auto ttv = get(leftType)) + std::optional t = findTablePropertyRespectingMeta(module->errors, leftType, indexName->index.value, indexName->location); + if (t) { - auto it = ttv->props.find(indexName->index.value); - if (it == ttv->props.end()) + if (!isSubtype(resultType, *t, ice)) { - reportError(UnknownProperty{leftType, indexName->index.value}, indexName->location); - } - else if (!isSubtype(resultType, it->second.type, ice)) - { - reportError(TypeMismatch{resultType, it->second.type}, indexName->location); + reportError(TypeMismatch{resultType, *t}, indexName->location); } } else @@ -277,7 +293,7 @@ struct TypeChecker2 : public AstVisitor TypeId actualType = lookupType(number); TypeId numberType = getSingletonTypes().numberType; - if (!isSubtype(actualType, numberType, ice)) + if (!isSubtype(numberType, actualType, ice)) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -290,7 +306,7 @@ struct TypeChecker2 : public AstVisitor TypeId actualType = lookupType(string); TypeId stringType = getSingletonTypes().stringType; - if (!isSubtype(actualType, stringType, ice)) + if (!isSubtype(stringType, actualType, ice)) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -298,6 +314,41 @@ struct TypeChecker2 : public AstVisitor return true; } + /** Extract a TypeId for the first type of the provided pack. + * + * Note that this may require modifying some types. I hope this doesn't cause problems! + */ + TypeId flattenPack(TypePackId pack) + { + pack = follow(pack); + + while (auto tp = get(pack)) + { + if (tp->head.empty() && tp->tail) + pack = *tp->tail; + } + + if (auto ty = first(pack)) + return *ty; + else if (auto vtp = get(pack)) + return vtp->ty; + else if (auto ftp = get(pack)) + { + TypeId result = module->internalTypes.addType(FreeTypeVar{ftp->scope}); + TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); + + TypePack& resultPack = asMutable(pack)->ty.emplace(); + resultPack.head.assign(1, result); + resultPack.tail = freeTail; + + return result; + } + else if (get(pack)) + return singletonTypes.errorRecoveryType(); + else + ice.ice("flattenPack got a weird pack!"); + } + bool visit(AstType* ty) override { return true; @@ -321,6 +372,11 @@ struct TypeChecker2 : public AstVisitor { module->errors.emplace_back(location, sourceModule->name, std::move(data)); } + + void reportError(TypeError e) + { + module->errors.emplace_back(std::move(e)); + } }; void check(const SourceModule& sourceModule, Module* module) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 44635e8..d9486a4 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -38,11 +38,13 @@ LUAU_FASTFLAGVARIABLE(LuauReduceUnionRecursion, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) LUAU_FASTFLAGVARIABLE(LuauReturnTypeInferenceInNonstrict, false) +LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); LUAU_FASTFLAGVARIABLE(LuauAlwaysQuantify, false); LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false) LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false) LUAU_FASTFLAGVARIABLE(LuauNonCopyableTypeVarFields, false) +LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false) namespace Luau { @@ -238,7 +240,7 @@ static bool isMetamethod(const Name& name) { return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode"; + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; } size_t HashBoolNamePair::operator()(const std::pair& pair) const @@ -327,10 +329,19 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo currentModule->timeout = true; } + if (FFlag::DebugLuauSharedSelf) + { + for (auto& [ty, scope] : deferredQuantification) + Luau::quantify(ty, scope->level); + deferredQuantification.clear(); + } + if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else + { moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); + } for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); @@ -537,18 +548,43 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } else if (auto fun = (*protoIter)->as()) { + std::optional selfType; std::optional expectedType; - if (!fun->func->self) + if (FFlag::DebugLuauSharedSelf) { if (auto name = fun->name->as()) { - TypeId exprTy = checkExpr(scope, *name->expr).type; - expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + TypeId baseTy = checkExpr(scope, *name->expr).type; + tablify(baseTy); + + if (!fun->func->self) + expectedType = getIndexTypeFromType(scope, baseTy, name->index.value, name->indexLocation, false); + else if (auto ttv = getMutableTableType(baseTy)) + { + if (!baseTy->persistent && ttv->state != TableState::Sealed && !ttv->selfTy) + { + ttv->selfTy = anyIfNonstrict(freshType(ttv->level)); + deferredQuantification.push_back({baseTy, scope}); + } + + selfType = ttv->selfTy; + } + } + } + else + { + if (!fun->func->self) + { + if (auto name = fun->name->as()) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + expectedType = getIndexTypeFromType(scope, exprTy, name->index.value, name->indexLocation, false); + } } } - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, expectedType); + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, selfType, expectedType); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -560,7 +596,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A } else if (auto fun = (*protoIter)->as()) { - auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt, std::nullopt); auto [funTy, funScope] = pair; functionDecls[*protoIter] = pair; @@ -2076,7 +2112,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) { - auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); + auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, std::nullopt, expectedType); checkFunctionBody(funScope, funTy, expr); @@ -2296,6 +2332,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); state.log.commit(); + reportErrors(state.errors); + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) retType = errorRecoveryType(retType); @@ -2322,6 +2360,23 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp DenseHashSet seen{nullptr}; + if (FFlag::LuauCheckLenMT && typeCouldHaveMetatable(operandType)) + { + if (auto fnt = findMetatableEntry(operandType, "__len", expr.location)) + { + TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); + TypePackId arguments = addTypePack({operandType}); + TypePackId retTypePack = addTypePack({numberType}); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); + + Unifier state = mkUnifier(expr.location); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); + state.log.commit(); + + reportErrors(state.errors); + } + } + if (!hasLength(operandType, seen, &recursionCount)) reportError(TypeError{expr.location, NotATable{operandType}}); @@ -2530,17 +2585,15 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (!matches) { reportError( expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", - toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); return errorRecoveryType(booleanType); } } - if (leftMetatable) { std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); @@ -3139,8 +3192,8 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T // `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` // to get type `(X) -> X`, then we quantify the free types to get the final // generic type `(a) -> a`. -std::pair TypeChecker::checkFunctionSignature( - const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalName, std::optional expectedType) +std::pair TypeChecker::checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, + std::optional originalName, std::optional selfType, std::optional expectedType) { ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); @@ -3241,12 +3294,25 @@ std::pair TypeChecker::checkFunctionSignature( funScope->returnType = retPack; - if (expr.self) + if (FFlag::DebugLuauSharedSelf) { - // TODO: generic self types: CLI-39906 - TypeId selfType = anyIfNonstrict(freshType(funScope)); - funScope->bindings[expr.self] = {selfType, expr.self->location}; - argTypes.push_back(selfType); + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfTy = anyIfNonstrict(selfType ? *selfType : freshType(funScope)); + funScope->bindings[expr.self] = {selfTy, expr.self->location}; + argTypes.push_back(selfTy); + } + } + else + { + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfType = anyIfNonstrict(freshType(funScope)); + funScope->bindings[expr.self] = {selfType, expr.self->location}; + argTypes.push_back(selfType); + } } // Prepare expected argument type iterators if we have an expected function type @@ -4457,25 +4523,43 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location { ty = follow(ty); - const FunctionTypeVar* ftv = get(ty); - - if (FFlag::LuauAlwaysQuantify) + if (FFlag::DebugLuauSharedSelf) { - if (ftv) + if (auto ftv = get(ty)) Luau::quantify(ty, scope->level); + else if (auto ttv = getTableType(ty); ttv && ttv->selfTy) + Luau::quantify(ty, scope->level); + + if (FFlag::LuauLowerBoundsCalculation) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } } else { - if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) - Luau::quantify(ty, scope->level); - } + const FunctionTypeVar* ftv = get(ty); - if (FFlag::LuauLowerBoundsCalculation && ftv) - { - auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); - if (!ok) - reportError(location, NormalizationTooComplex{}); - return t; + if (FFlag::LuauAlwaysQuantify) + { + if (ftv) + Luau::quantify(ty, scope->level); + } + else + { + if (ftv && ftv->generics.empty() && ftv->genericPacks.empty()) + Luau::quantify(ty, scope->level); + } + + if (FFlag::LuauLowerBoundsCalculation && ftv) + { + auto [t, ok] = Luau::normalize(ty, currentModule, *iceHandler); + if (!ok) + reportError(location, NormalizationTooComplex{}); + return t; + } } return ty; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6147e11..0792a35 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -740,7 +740,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I std::optional unificationTooComplex; std::optional firstFailedOption; - // T <: A & B if A <: T and B <: T + // T <: A & B if T <: A and T <: B for (TypeId type : uv->parts) { Unifier innerState = makeChildUnifier(); @@ -765,7 +765,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) { - // A & B <: T if T <: A or T <: B + // A & B <: T if A <: T or B <: T bool found = false; std::optional unificationTooComplex; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 95bce3e..70c9255 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -5,6 +5,9 @@ #include +#include +#include + // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. @@ -14,6 +17,18 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false) LUAU_FASTFLAGVARIABLE(LuauReturnTypeTokenConfusion, false) +LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) + +bool lua_telemetry_parsed_named_non_function_type = false; + +LUAU_FASTFLAGVARIABLE(LuauErrorParseIntegerIssues, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) + +bool lua_telemetry_parsed_out_of_range_bin_integer = false; +bool lua_telemetry_parsed_out_of_range_hex_integer = false; +bool lua_telemetry_parsed_double_prefix_hex_integer = false; + namespace Luau { @@ -1330,7 +1345,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - bool monomorphic = lexer.current().type != '<'; + bool forceFunctionType = lexer.current().type == '<'; Lexeme begin = lexer.current(); @@ -1355,21 +1370,33 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) AstArray paramTypes = copy(params); + if (FFlag::LuauFixNamedFunctionParse && !names.empty()) + forceFunctionType = true; + bool returnTypeIntroducer = FFlag::LuauReturnTypeTokenConfusion ? lexer.current().type == Lexeme::SkinnyArrow || lexer.current().type == ':' : false; // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element - if (params.size() == 1 && !varargAnnotation && monomorphic && + if (params.size() == 1 && !varargAnnotation && !forceFunctionType && (FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow)) { + if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) + lua_telemetry_parsed_named_non_function_type = true; + if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; else return {params[0], {}}; } - if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && monomorphic && allowPack) + if ((FFlag::LuauReturnTypeTokenConfusion ? !returnTypeIntroducer : lexer.current().type != Lexeme::SkinnyArrow) && !forceFunctionType && + allowPack) + { + if (DFFlag::LuaReportParseWrongNamedType && !names.empty()) + lua_telemetry_parsed_named_non_function_type = true; + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; + } AstArray> paramNames = copy(names); @@ -2010,7 +2037,63 @@ AstExpr* Parser::parseAssertionExpr() return expr; } -static bool parseNumber(double& result, const char* data) +static const char* parseInteger(double& result, const char* data, int base) +{ + char* end = nullptr; + unsigned long long value = strtoull(data, &end, base); + + if (value == ULLONG_MAX && errno == ERANGE) + { + // 'errno' might have been set before we called 'strtoull', but we don't want the overhead of resetting a TLS variable on each call + // so we only reset it when we get a result that might be an out-of-range error and parse again to make sure + errno = 0; + value = strtoull(data, &end, base); + + if (errno == ERANGE) + { + if (DFFlag::LuaReportParseIntegerIssues) + { + if (base == 2) + lua_telemetry_parsed_out_of_range_bin_integer = true; + else + lua_telemetry_parsed_out_of_range_hex_integer = true; + } + + if (FFlag::LuauErrorParseIntegerIssues) + return "Integer number value is out of range"; + } + } + + result = double(value); + return *end == 0 ? nullptr : "Malformed number"; +} + +static const char* parseNumber(double& result, const char* data) +{ + // binary literal + if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) + return parseInteger(result, data + 2, 2); + + // hexadecimal literal + if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) + { + if (DFFlag::LuaReportParseIntegerIssues && data[2] == '0' && (data[3] == 'x' || data[3] == 'X')) + lua_telemetry_parsed_double_prefix_hex_integer = true; + + if (FFlag::LuauErrorParseIntegerIssues) + return parseInteger(result, data, 16); // keep prefix, it's handled by 'strtoull' + else + return parseInteger(result, data + 2, 16); + } + + char* end = nullptr; + double value = strtod(data, &end); + + result = value; + return *end == 0 ? nullptr : "Malformed number"; +} + +static bool parseNumber_DEPRECATED(double& result, const char* data) { // binary literal if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) @@ -2080,18 +2163,37 @@ AstExpr* Parser::parseSimpleExpr() scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); } - double value = 0; - if (parseNumber(value, scratchData.c_str())) + if (DFFlag::LuaReportParseIntegerIssues || FFlag::LuauErrorParseIntegerIssues) { - nextLexeme(); + double value = 0; + if (const char* error = parseNumber(value, scratchData.c_str())) + { + nextLexeme(); - return allocator.alloc(start, value); + return reportExprError(start, {}, "%s", error); + } + else + { + nextLexeme(); + + return allocator.alloc(start, value); + } } else { - nextLexeme(); + double value = 0; + if (parseNumber_DEPRECATED(value, scratchData.c_str())) + { + nextLexeme(); - return reportExprError(start, {}, "Malformed number"); + return allocator.alloc(start, value); + } + else + { + nextLexeme(); + + return reportExprError(start, {}, "Malformed number"); + } } } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 218bb5d..0cb7e1d 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -276,7 +276,7 @@ enum LuauOpcode // FORGLOOP: adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue // A: target register; generic for loops assume a register layout [generator, state, index, variables...] // D: jump offset (-32768..32767) - // AUX: variable count (1..255) + // AUX: variable count (1..255) in the low 8 bits, high bit indicates whether to use ipairs-style traversal in the fast path // loop variables are adjusted by calling generator(state, index) and expecting it to return a tuple that's copied to the user variables // the first variable is then copied into index; generator/state are immutable, index isn't visible to user code LOP_FORGLOOP, @@ -490,6 +490,9 @@ enum LuauBuiltinFunction // select(_, ...) LBF_SELECT_VARARG, + + // rawlen + LBF_RAWLEN, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index ff75311..6bd24b6 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,8 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauCompileRawlen, false) + namespace Luau { namespace Compile @@ -58,6 +60,8 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) return LBF_RAWGET; if (builtin.isGlobal("rawequal")) return LBF_RAWEQUAL; + if (FFlag::LuauCompileRawlen && builtin.isGlobal("rawlen")) + return LBF_RAWLEN; if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 301cf25..5e2669b 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1302,20 +1302,22 @@ void BytecodeBuilder::validate() const case LOP_FORNPREP: case LOP_FORNLOOP: - VREG(LUAU_INSN_A(insn) + 2); // for loop protocol: A, A+1, A+2 are used for iteration + // for loop protocol: A, A+1, A+2 are used for iteration + VREG(LUAU_INSN_A(insn) + 2); VJUMP(LUAU_INSN_D(insn)); break; case LOP_FORGPREP: - VREG(LUAU_INSN_A(insn) + 2 + 1); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VREG(LUAU_INSN_A(insn) + 2 + 1); VJUMP(LUAU_INSN_D(insn)); break; case LOP_FORGLOOP: - VREG( - LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VREG(LUAU_INSN_A(insn) + 2 + uint8_t(insns[i + 1])); VJUMP(LUAU_INSN_D(insn)); - LUAU_ASSERT(insns[i + 1] >= 1); + LUAU_ASSERT(uint8_t(insns[i + 1]) >= 1); break; case LOP_FORGPREP_INEXT: @@ -1679,7 +1681,8 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, break; case LOP_FORGLOOP: - formatAppend(result, "FORGLOOP R%d L%d %d\n", LUAU_INSN_A(insn), targetLabel, *code++); + formatAppend(result, "FORGLOOP R%d L%d %d%s\n", LUAU_INSN_A(insn), targetLabel, uint8_t(*code), int(*code) < 0 ? " [inext]" : ""); + code++; break; case LOP_FORGPREP_INEXT: diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index e732256..d7c8155 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -23,6 +23,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) +LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false) + namespace Luau { @@ -2665,7 +2667,7 @@ struct Compiler if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) { skipOp = LOP_FORGPREP_INEXT; - loopOp = LOP_FORGLOOP_INEXT; + loopOp = FFlag::LuauCompileNoIpairs ? LOP_FORGLOOP : LOP_FORGLOOP_INEXT; } else if (builtin.isGlobal("pairs")) // for .. in pairs(t) { @@ -2709,8 +2711,16 @@ struct Compiler bytecode.emitAD(loopOp, regs, 0); + if (FFlag::LuauCompileNoIpairs) + { + // TODO: remove loopOp as it's a constant now + LUAU_ASSERT(loopOp == LOP_FORGLOOP); + + // FORGLOOP uses aux to encode variable count and fast path flag for ipairs traversal in the high bit + bytecode.emitAux((skipOp == LOP_FORGPREP_INEXT ? 0x80000000 : 0) | uint32_t(stat->vars.size)); + } // note: FORGLOOP needs variable count encoded in AUX field, other loop instructions assume a fixed variable count - if (loopOp == LOP_FORGLOOP) + else if (loopOp == LOP_FORGLOOP) bytecode.emitAux(uint32_t(stat->vars.size)); size_t endLabel = bytecode.emitLabel(); @@ -3341,7 +3351,7 @@ struct Compiler std::vector upvals; }; - struct ReturnVisitor: AstVisitor + struct ReturnVisitor : AstVisitor { Compiler* self; bool returnsOne = true; diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index f791761..4fc5033 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -11,6 +11,8 @@ #include #include +LUAU_FASTFLAG(LuauLenTM) + static void writestring(const char* s, size_t l) { fwrite(s, 1, l, stdout); @@ -178,6 +180,18 @@ static int luaB_rawset(lua_State* L) return 1; } +static int luaB_rawlen(lua_State* L) +{ + if (!FFlag::LuauLenTM) + luaL_error(L, "'rawlen' is not available"); + + int tt = lua_type(L, 1); + luaL_argcheck(L, tt == LUA_TTABLE || tt == LUA_TSTRING, 1, "table or string expected"); + int len = lua_objlen(L, 1); + lua_pushinteger(L, len); + return 1; +} + static int luaB_gcinfo(lua_State* L) { lua_pushinteger(L, lua_gc(L, LUA_GCCOUNT, 0)); @@ -428,6 +442,7 @@ static const luaL_Reg base_funcs[] = { {"rawequal", luaB_rawequal}, {"rawget", luaB_rawget}, {"rawset", luaB_rawset}, + {"rawlen", luaB_rawlen}, {"select", luaB_select}, {"setfenv", luaB_setfenv}, {"setmetatable", luaB_setmetatable}, diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index deaf140..e98660a 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1117,6 +1117,27 @@ static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, Stk return -1; } +static int luauF_rawlen(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + if (ttistable(arg0)) + { + Table* h = hvalue(arg0); + setnvalue(res, double(luaH_getn(h))); + return 1; + } + else if (ttisstring(arg0)) + { + TString* ts = tsvalue(arg0); + setnvalue(res, double(ts->len)); + return 1; + } + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1188,4 +1209,6 @@ luau_FastFunction luauF_table[256] = { luauF_countrz, luauF_select, + + luauF_rawlen, }; diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h index cf905e9..75bb8dc 100644 --- a/VM/src/ldebug.h +++ b/VM/src/ldebug.h @@ -19,6 +19,7 @@ LUAI_FUNC l_noret luaG_concaterror(lua_State* L, StkId p1, StkId p2); LUAI_FUNC l_noret luaG_aritherror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); LUAI_FUNC l_noret luaG_ordererror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); LUAI_FUNC l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2); + LUAI_FUNC LUA_PRINTF_ATTR(2, 3) l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...); LUAI_FUNC void luaG_pusherror(lua_State* L, const char* error); diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index e7df4e5..49982b2 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -39,6 +39,7 @@ const char* const luaT_eventname[] = { "__namecall", "__call", "__iter", + "__len", "__eq", @@ -52,7 +53,6 @@ const char* const luaT_eventname[] = { "__unm", - "__len", "__lt", "__le", "__concat", diff --git a/VM/src/ltm.h b/VM/src/ltm.h index a522394..e11ddb3 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -18,6 +18,7 @@ typedef enum TM_NAMECALL, TM_CALL, TM_ITER, + TM_LEN, TM_EQ, /* last tag method with `fast' access */ @@ -31,7 +32,6 @@ typedef enum TM_UNM, - TM_LEN, TM_LT, TM_LE, TM_CONCAT, diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index e0a9647..85829ca 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauLenTM, false) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -2082,13 +2084,25 @@ static void luau_execute(lua_State* L) // fast-path #1: tables if (ttistable(rb)) { - setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); - VM_NEXT(); + Table* h = hvalue(rb); + + if (!FFlag::LuauLenTM || fastnotm(h->metatable, TM_LEN)) + { + setnvalue(ra, cast_num(luaH_getn(h))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_dolen(L, ra, rb)); + VM_NEXT(); + } } // fast-path #2: strings (not very important but easy to do) else if (ttisstring(rb)) { - setnvalue(ra, cast_num(tsvalue(rb)->len)); + TString* ts = tsvalue(rb); + setnvalue(ra, cast_num(ts->len)); VM_NEXT(); } else @@ -2226,6 +2240,15 @@ static void luau_execute(lua_State* L) VM_PROTECT(luaD_call(L, ra, 3)); L->top = L->ci->top; + + /* recompute ra since stack might have been reallocated */ + ra = VM_REG(LUAU_INSN_A(insn)); + + /* protect against __iter returning nil, since nil is used as a marker for builtin iteration in FORGLOOP */ + if (ttisnil(ra)) + { + VM_PROTECT(luaG_typeerror(L, ra, "call")); + } } else if (fasttm(L, mt, TM_CALL)) { @@ -2258,27 +2281,38 @@ static void luau_execute(lua_State* L) uint32_t aux = *pc; // fast-path: builtin table iteration - if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2)) + // note: ra=nil guarantees ra+1=table and ra+2=userdata because of the setup by FORGPREP* opcodes + // TODO: remove the table check per guarantee above + if (ttisnil(ra) && ttistable(ra + 1)) { Table* h = hvalue(ra + 1); int index = int(reinterpret_cast(pvalue(ra + 2))); int sizearray = h->sizearray; - int sizenode = 1 << h->lsizenode; // clear extra variables since we might have more than two - if (LUAU_UNLIKELY(aux > 2)) + // note: while aux encodes ipairs bit, when set we always use 2 variables, so it's safe to check this via a signed comparison + if (LUAU_UNLIKELY(int(aux) > 2)) for (int i = 2; i < int(aux); ++i) setnilvalue(ra + 3 + i); + // terminate ipairs-style traversal early when encountering nil + if (int(aux) < 0 && (unsigned(index) >= unsigned(sizearray) || ttisnil(&h->array[index]))) + { + pc++; + VM_NEXT(); + } + // first we advance index through the array portion while (unsigned(index) < unsigned(sizearray)) { - if (!ttisnil(&h->array[index])) + TValue* e = &h->array[index]; + + if (!ttisnil(e)) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); setnvalue(ra + 3, double(index + 1)); - setobj2s(L, ra + 4, &h->array[index]); + setobj2s(L, ra + 4, e); pc += LUAU_INSN_D(insn); LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); @@ -2288,6 +2322,8 @@ static void luau_execute(lua_State* L) index++; } + int sizenode = 1 << h->lsizenode; + // then we advance index through the hash portion while (unsigned(index - sizearray) < unsigned(sizenode)) { @@ -2321,7 +2357,7 @@ static void luau_execute(lua_State* L) L->top = ra + 3 + 3; /* func + 2 args (state and index) */ LUAU_ASSERT(L->top <= L->stack_last); - VM_PROTECT(luaD_call(L, ra + 3, aux)); + VM_PROTECT(luaD_call(L, ra + 3, uint8_t(aux))); L->top = L->ci->top; // recompute ra since stack might have been reallocated diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 8a18a4d..b9e762e 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -10,6 +10,9 @@ #include "lnumutils.h" #include +#include + +LUAU_FASTFLAG(LuauLenTM) /* limit for table tag-method chains (to avoid loops) */ #define MAXTAGLOOP 100 @@ -51,7 +54,7 @@ const float* luaV_tovector(const TValue* obj) return nullptr; } -static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) +static StkId callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) { ptrdiff_t result = savestack(L, res); // using stack room beyond top is technically safe here, but for very complicated reasons: @@ -71,6 +74,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 res = restorestack(L, result); L->top--; setobjs2s(L, res, L->top); + return res; } static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) @@ -472,22 +476,56 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) { + if (!FFlag::LuauLenTM) + { + switch (ttype(rb)) + { + case LUA_TTABLE: + { + setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); + break; + } + case LUA_TSTRING: + { + setnvalue(ra, cast_num(tsvalue(rb)->len)); + break; + } + default: + { /* try metamethod */ + if (!call_binTM(L, rb, luaO_nilobject, ra, TM_LEN)) + luaG_typeerror(L, rb, "get length of"); + } + } + return; + } + + const TValue* tm = NULL; switch (ttype(rb)) { case LUA_TTABLE: { - setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); + Table* h = hvalue(rb); + if ((tm = fasttm(L, h->metatable, TM_LEN)) == NULL) + { + setnvalue(ra, cast_num(luaH_getn(h))); + return; + } break; } case LUA_TSTRING: { - setnvalue(ra, cast_num(tsvalue(rb)->len)); - break; + TString* ts = tsvalue(rb); + setnvalue(ra, cast_num(ts->len)); + return; } default: - { /* try metamethod */ - if (!call_binTM(L, rb, luaO_nilobject, ra, TM_LEN)) - luaG_typeerror(L, rb, "get length of"); - } + tm = luaT_gettmbyobj(L, rb, TM_LEN); } + + if (ttisnil(tm)) + luaG_typeerror(L, rb, "get length of"); + + StkId res = callTMres(L, ra, tm, rb, luaO_nilobject); + if (!ttisnumber(res)) + luaG_runerror(L, "'__len' must return a number"); /* note, we can't access rb since stack may have been reallocated */ } diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 66a89f2..d4d5227 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -11,6 +11,7 @@ static const std::string kNames[] = { "__div", "__eq", "__index", + "__iter", "__le", "__len", "__lt", @@ -41,13 +42,18 @@ static const std::string kNames[] = { "ceil", "char", "charpattern", + "clamp", "clock", + "clone", + "close", "codepoint", "codes", "concat", "coroutine", "cos", "cosh", + "countlz", + "countrz", "create", "date", "debug", @@ -63,6 +69,7 @@ static const std::string kNames[] = { "foreachi", "format", "frexp", + "freeze", "function", "gcinfo", "getfenv", @@ -72,8 +79,10 @@ static const std::string kNames[] = { "gmatch", "gsub", "huge", + "info", "insert", "ipairs", + "isfrozen", "isyieldable", "ldexp", "len", @@ -93,6 +102,7 @@ static const std::string kNames[] = { "newproxy", "next", "nil", + "noise", "number", "offset", "os", @@ -121,6 +131,7 @@ static const std::string kNames[] = { "select", "setfenv", "setmetatable", + "sign", "sin", "sinh", "sort", diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 655e48c..fafafd7 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -261,6 +261,8 @@ L1: RETURN R0 0 TEST_CASE("ForBytecode") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + // basic for loop: variable directly refers to internal iteration index (R2) CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( LOADN R2 1 @@ -313,7 +315,7 @@ L0: GETIMPORT R5 3 MOVE R6 R3 MOVE R7 R4 CALL R5 2 0 -L1: FORGLOOP_INEXT R0 L0 +L1: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -347,13 +349,15 @@ RETURN R0 0 TEST_CASE("ForBytecodeBuiltin") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + // we generally recognize builtins like pairs/ipairs and emit special opcodes CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( GETIMPORT R0 1 NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_INEXT R0 L0 -L0: FORGLOOP_INEXT R0 L0 +L0: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -364,7 +368,7 @@ MOVE R1 R0 NEWTABLE R2 0 0 CALL R1 1 3 FORGPREP_INEXT R1 L0 -L0: FORGLOOP_INEXT R1 L0 +L0: FORGLOOP R1 L0 2 [inext] RETURN R0 0 )"); @@ -374,7 +378,7 @@ GETUPVAL R0 0 NEWTABLE R1 0 0 CALL R0 1 3 FORGPREP_INEXT R0 L0 -L0: FORGLOOP_INEXT R0 L0 +L0: FORGLOOP R0 L0 2 [inext] RETURN R0 0 )"); @@ -2107,6 +2111,8 @@ RETURN R3 -1 TEST_CASE("UpvaluesLoopsBytecode") { + ScopedFastFlag sff("LuauCompileNoIpairs", true); + CHECK_EQ("\n" + compileFunction(R"( function test() for i=1,10 do @@ -2169,7 +2175,7 @@ JUMPIFNOT R5 L1 CLOSEUPVALS R3 JUMP L3 L1: CLOSEUPVALS R3 -L2: FORGLOOP_INEXT R0 L0 +L2: FORGLOOP R0 L0 1 [inext] L3: LOADN R0 0 RETURN R0 1 )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 96a2775..3f41514 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -231,6 +231,8 @@ TEST_CASE("Assert") TEST_CASE("Basic") { + ScopedFastFlag sff("LuauLenTM", true); + runConformance("basic.lua"); } @@ -301,6 +303,8 @@ TEST_CASE("Errors") TEST_CASE("Events") { + ScopedFastFlag sff("LuauLenTM", true); + runConformance("events.lua"); } @@ -475,6 +479,8 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { + ScopedFastFlag sff("LuauCheckLenMT", true); + runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; diff --git a/tests/ConstraintGraphBuilder.test.cpp b/tests/ConstraintGraphBuilder.test.cpp index 96b2161..00c3309 100644 --- a/tests/ConstraintGraphBuilder.test.cpp +++ b/tests/ConstraintGraphBuilder.test.cpp @@ -17,7 +17,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello_world") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(2 == constraints.size()); @@ -36,7 +36,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "primitives") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(3 == constraints.size()); @@ -54,15 +54,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "nil_primitive") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); ToStringOptions opts; REQUIRE(5 <= constraints.size()); CHECK("*blocked-1* ~ gen () -> (a...)" == toString(*constraints[0], opts)); - CHECK("b ~ inst *blocked-1*" == toString(*constraints[1], opts)); - CHECK("() -> (c...) <: b" == toString(*constraints[2], opts)); - CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("*blocked-2* ~ inst *blocked-1*" == toString(*constraints[1], opts)); + CHECK("() -> (b...) <: *blocked-2*" == toString(*constraints[2], opts)); + CHECK("b... <: c" == toString(*constraints[3], opts)); CHECK("nil <: a..." == toString(*constraints[4], opts)); } @@ -74,15 +74,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "function_application") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(4 == constraints.size()); ToStringOptions opts; CHECK("string <: a" == toString(*constraints[0], opts)); - CHECK("b ~ inst a" == toString(*constraints[1], opts)); - CHECK("(string) -> (c...) <: b" == toString(*constraints[2], opts)); - CHECK("c... <: d" == toString(*constraints[3], opts)); + CHECK("*blocked-1* ~ inst a" == toString(*constraints[1], opts)); + CHECK("(string) -> (b...) <: *blocked-1*" == toString(*constraints[2], opts)); + CHECK("b... <: c" == toString(*constraints[3], opts)); } TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") @@ -94,7 +94,7 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "local_function_definition") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(2 == constraints.size()); @@ -112,15 +112,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "recursive_function") )"); cgb.visit(block); - auto constraints = collectConstraints(cgb.rootScope); + auto constraints = collectConstraints(NotNull(cgb.rootScope)); REQUIRE(4 == constraints.size()); ToStringOptions opts; CHECK("*blocked-1* ~ gen (a) -> (b...)" == toString(*constraints[0], opts)); - CHECK("c ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); - CHECK("(a) -> (d...) <: c" == toString(*constraints[2], opts)); - CHECK("d... <: b..." == toString(*constraints[3], opts)); + CHECK("*blocked-2* ~ inst (a) -> (b...)" == toString(*constraints[1], opts)); + CHECK("(a) -> (c...) <: *blocked-2*" == toString(*constraints[2], opts)); + CHECK("c... <: b..." == toString(*constraints[3], opts)); } TEST_SUITE_END(); diff --git a/tests/ConstraintSolver.test.cpp b/tests/ConstraintSolver.test.cpp index 5959f55..f521c66 100644 --- a/tests/ConstraintSolver.test.cpp +++ b/tests/ConstraintSolver.test.cpp @@ -9,7 +9,7 @@ using namespace Luau; -static TypeId requireBinding(Scope2* scope, const char* name) +static TypeId requireBinding(NotNull scope, const char* name) { auto b = linearSearchForBinding(scope, name); LUAU_ASSERT(b.has_value()); @@ -26,12 +26,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "hello") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId bType = requireBinding(cgb.rootScope, "b"); + TypeId bType = requireBinding(rootScope, "b"); CHECK("number" == toString(bType)); } @@ -45,12 +46,13 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "generic_function") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId idType = requireBinding(cgb.rootScope, "id"); + TypeId idType = requireBinding(rootScope, "id"); CHECK("(a) -> a" == toString(idType)); } @@ -71,14 +73,15 @@ TEST_CASE_FIXTURE(ConstraintGraphBuilderFixture, "proper_let_generalization") )"); cgb.visit(block); + NotNull rootScope = NotNull(cgb.rootScope); ToStringOptions opts; - ConstraintSolver cs{&arena, cgb.rootScope}; + ConstraintSolver cs{&arena, rootScope}; cs.run(); - TypeId idType = requireBinding(cgb.rootScope, "b"); + TypeId idType = requireBinding(rootScope, "b"); CHECK("(a) -> number" == toString(idType, opts)); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index ac22f65..c92c445 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -195,12 +195,15 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin sourceModule.reset(new SourceModule); ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); - REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); - CHECK_EQ(result.errors.front().getMessage(), message); + if (!result.errors.empty()) + { + CHECK_EQ(result.errors.front().getMessage(), message); - if (location) - CHECK_EQ(result.errors.front().getLocation(), *location); + if (location) + CHECK_EQ(result.errors.front().getLocation(), *location); + } return result; } @@ -213,11 +216,14 @@ ParseResult Fixture::matchParseErrorPrefix(const std::string& source, const std: sourceModule.reset(new SourceModule); ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); - REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + CHECK_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); - const std::string& message = result.errors.front().getMessage(); - CHECK_GE(message.length(), prefix.length()); - CHECK_EQ(prefix, message.substr(0, prefix.size())); + if (!result.errors.empty()) + { + const std::string& message = result.errors.front().getMessage(); + CHECK_GE(message.length(), prefix.length()); + CHECK_EQ(prefix, message.substr(0, prefix.size())); + } return result; } @@ -428,6 +434,7 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() : Fixture() + , cgb(mainModuleName, &arena, NotNull(&ice), frontend.getGlobalScope2()) , forceTheFlag{"DebugLuauDeferredConstraintResolution", true} { BlockedTypeVar::nextIndex = 0; diff --git a/tests/Fixture.h b/tests/Fixture.h index 0e3735f..1bc573d 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -133,6 +133,7 @@ struct Fixture TestConfigResolver configResolver; std::unique_ptr sourceModule; Frontend frontend; + InternalErrorReporter ice; TypeChecker& typeChecker; std::string decorateWithTypes(const std::string& code); @@ -160,7 +161,7 @@ struct BuiltinsFixture : Fixture struct ConstraintGraphBuilderFixture : Fixture { TypeArena arena; - ConstraintGraphBuilder cgb{&arena}; + ConstraintGraphBuilder cgb; ScopedFastFlag forceTheFlag; diff --git a/tests/NotNull.test.cpp b/tests/NotNull.test.cpp index ed1c25e..e77ba78 100644 --- a/tests/NotNull.test.cpp +++ b/tests/NotNull.test.cpp @@ -30,23 +30,23 @@ struct Test int Test::count = 0; -} +} // namespace int foo(NotNull p) { return *p; } -void bar(int* q) -{} +void bar(int* q) {} TEST_SUITE_BEGIN("NotNull"); TEST_CASE("basic_stuff") { - NotNull a = NotNull{new int(55)}; // Does runtime test - NotNull b{new int(55)}; // As above - // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not good. + NotNull a = NotNull{new int(55)}; // Does runtime test + NotNull b{new int(55)}; // As above + // NotNull c = new int(55); // Nope. Mildly regrettable, but implicit conversion from T* to NotNull in the general case is not + // good. // a = nullptr; // nope diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 878023e..c3c7599 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -6,6 +6,8 @@ #include "doctest.h" +#include + using namespace Luau; namespace @@ -786,33 +788,46 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal") TEST_CASE_FIXTURE(Fixture, "parse_numbers_hexadecimal") { - AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff"); + AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff, 0xffffffffffffffff"); REQUIRE(stat != nullptr); AstStatReturn* str = stat->as()->body.data[0]->as(); - CHECK(str->list.size == 3); + CHECK(str->list.size == 4); CHECK_EQ(str->list.data[0]->as()->value, 0xab); CHECK_EQ(str->list.data[1]->as()->value, 0xAB05); CHECK_EQ(str->list.data[2]->as()->value, 0xFFFF); + CHECK_EQ(str->list.data[3]->as()->value, double(ULLONG_MAX)); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") { - AstStat* stat = parse("return 0b1, 0b0, 0b101010"); + AstStat* stat = parse("return 0b1, 0b0, 0b101010, 0b1111111111111111111111111111111111111111111111111111111111111111"); REQUIRE(stat != nullptr); AstStatReturn* str = stat->as()->body.data[0]->as(); - CHECK(str->list.size == 3); + CHECK(str->list.size == 4); CHECK_EQ(str->list.data[0]->as()->value, 1); CHECK_EQ(str->list.data[1]->as()->value, 0); CHECK_EQ(str->list.data[2]->as()->value, 42); + CHECK_EQ(str->list.data[3]->as()->value, double(ULLONG_MAX)); } TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") { + ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true}; + CHECK_EQ(getParseError("return 0b123"), "Malformed number"); CHECK_EQ(getParseError("return 123x"), "Malformed number"); CHECK_EQ(getParseError("return 0xg"), "Malformed number"); + CHECK_EQ(getParseError("return 0x0x123"), "Malformed number"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_numbers_range_error") +{ + ScopedFastFlag luauErrorParseIntegerIssues{"LuauErrorParseIntegerIssues", true}; + + CHECK_EQ(getParseError("return 0x10000000000000000"), "Integer number value is out of range"); + CHECK_EQ(getParseError("return 0b10000000000000000000000000000000000000000000000000000000000000000"), "Integer number value is out of range"); } TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") @@ -2111,6 +2126,15 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } +TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") +{ + ScopedFastFlag luauFixNamedFunctionParse{"LuauFixNamedFunctionParse", true}; + + matchParseError("type A = (b: number)", "Expected '->' when parsing function type, got "); + matchParseError("type P = () -> T... type B = P<(x: number, y: string)>", "Expected '->' when parsing function type, got '>'"); + matchParseError("type F = (T...) -> ()", "Expected '->' when parsing function type, got '>'"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e03069a..387e07c 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -409,6 +409,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { + ScopedFastFlag sff2{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -424,7 +426,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name); + CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); CHECK_EQ(0, r.nameMap.typeVars.size()); ToStringOptions opts; @@ -455,11 +457,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") std::string twoResult = toString(tMeta6->props["two"].type, opts); - REQUIRE_EQ("(a) -> number", oneResult.name); - REQUIRE_EQ("(b) -> number", twoResult); + CHECK_EQ("(a) -> number", oneResult.name); + CHECK_EQ("(b) -> number", twoResult); } - TEST_CASE_FIXTURE(Fixture, "toStringErrorPack") { CheckResult result = check(R"( @@ -688,6 +689,10 @@ TEST_CASE_FIXTURE(Fixture, "pick_distinct_names_for_mixed_explicit_and_implicit_ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") { + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () @@ -701,9 +706,12 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_include_self_param") CHECK_EQ("foo:method(self: a, arg: string): ()", toStringNamedFunction("foo:method", *ftv)); } - TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_self_param") { + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local foo = {} function foo:method(arg: string): () diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index d6f0a0c..bdd4d6f 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -73,24 +73,24 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") if (FFlag::DebugLuauDeferredConstraintResolution) { CHECK(result.errors[0] == TypeError{ - Location{{1, 21}, {1, 26}}, - getMainSourceModule()->name, - TypeMismatch{ - getSingletonTypes().numberType, - getSingletonTypes().stringType, - }, - }); + Location{{1, 21}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); } else { CHECK(result.errors[0] == TypeError{ - Location{{1, 8}, {1, 26}}, - getMainSourceModule()->name, - TypeMismatch{ - getSingletonTypes().numberType, - getSingletonTypes().stringType, - }, - }); + Location{{1, 8}, {1, 26}}, + getMainSourceModule()->name, + TypeMismatch{ + getSingletonTypes().numberType, + getSingletonTypes().stringType, + }, + }); } } @@ -716,6 +716,10 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") { + ScopedFastFlag sff[] = { + {"DebugLuauSharedSelf", true}, + }; + CheckResult result = check(R"( local B = {} B.bar = 4 @@ -737,7 +741,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni type FutureIntersection = A & B )"); - LUAU_REQUIRE_NO_ERRORS(result); + // TODO: shared self causes this test to break in bizarre ways. + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 3e2ad6d..8a86ee5 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -134,13 +134,13 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_reference_generates_error") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(result.errors[0] == TypeError{ - Location{{1, 17}, {1, 28}}, - getMainSourceModule()->name, - UnknownSymbol{ - "IDoNotExist", - UnknownSymbol::Context::Type, - }, - }); + Location{{1, 17}, {1, 28}}, + getMainSourceModule()->name, + UnknownSymbol{ + "IDoNotExist", + UnknownSymbol::Context::Type, + }, + }); } TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 036a667..401a6c6 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -37,6 +37,27 @@ TEST_CASE_FIXTURE(Fixture, "check_function_bodies") }})); } +TEST_CASE_FIXTURE(Fixture, "cannot_hoist_interior_defns_into_signature") +{ + // This test verifies that the signature does not have access to types + // declared within the body. Under DCR, if the function's inner scope + // encompasses the entire function expression, it would be possible for this + // to type check (but the solver output is somewhat undefined). This test + // ensures that this isn't the case. + CheckResult result = check(R"( + local function f(x: T) + type T = number + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(result.errors[0] == TypeError{Location{{1, 28}, {1, 29}}, getMainSourceModule()->name, + UnknownSymbol{ + "T", + UnknownSymbol::Context::Type, + }}); +} + TEST_CASE_FIXTURE(Fixture, "infer_return_type") { CheckResult result = check("function take_five() return 5 end"); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 97ba080..e9e94cf 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -271,13 +271,16 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") { + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( local x = {} function x:id(x) return x end function x:f(): string return self:id("hello") end function x:g(): number return self:id(37) end )"); - LUAU_REQUIRE_NO_ERRORS(result); + // TODO: Quantification should be doing the conversion, not normalization. + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index fd9b1dd..e6174df 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -461,6 +461,61 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus") REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") +{ + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + + local mt = {} + setmetatable(foo, mt) + + mt.__unm = function(val: boolean): string + return "test" + end + + local a = -foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.booleanType); + // given type is the typeof(foo) which is complex to compare against +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") +{ + ScopedFastFlag sff("LuauCheckLenMT", true); + + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + local mt = {} + setmetatable(foo, mt) + + mt.__len = function(val: any): string + return "test" + end + + local a = #foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("number", toString(requireType("a"))); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); + REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 487e597..059aed2 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -499,4 +499,26 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent") CHECK_EQ("(t1) -> {| [t1]: boolean |} where t1 = t2 ; t2 = {+ m1: (t1) -> (a...), m2: (t2) -> (b...) +}", toString(requireType("f"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_function_with_no_returns") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local T = {} + T.__index = T + + function T.new() + local self = setmetatable({}, T) + return self:ctor() or self + end + + function T:ctor() + -- oops, no return! + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Not all codepaths in this function return '{ @metatable T, {| |} }, a...'.", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 77a2928..eead5b3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1863,6 +1863,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "quantifying_a_bound_var_works") TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") { + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + CheckResult result = check(R"( --!strict @@ -1890,7 +1892,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "less_exponential_blowup_please") newData:First() )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call") @@ -2868,6 +2870,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") { ScopedFastFlag sff[] = { {"LuauLowerBoundsCalculation", true}, + {"DebugLuauSharedSelf", true}, }; check(R"( @@ -2887,7 +2890,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_return_type_of_free_table") end )"); - CHECK_EQ("(t1) -> {| Byte: (b) -> (a...), PeekByte: (c) -> (a...) |} where t1 = {+ byte: (t1, number) -> (a...) +}", + CHECK_EQ("(t1) -> {| Byte: (a) -> (b...), PeekByte: (a) -> (b...) |} where t1 = {+ byte: (t1, number) -> (b...) +}", toString(requireType("Base64FileReader"))); } @@ -2904,6 +2907,66 @@ TEST_CASE_FIXTURE(Fixture, "mixed_tables_with_implicit_numbered_keys") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[2])); } +TEST_CASE_FIXTURE(Fixture, "shared_selfs") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local t = {} + t.x = 5 + + function t:m1() return self.x end + function t:m2() return self.y end + + return t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b, x: number |}", toString(requireType("t"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "shared_selfs_from_free_param") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local function f(t) + function t:m1() return self.x end + function t:m2() return self.y end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("({+ m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b +}) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "shared_selfs_through_metatables") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local t = {} + t.__index = t + setmetatable({}, t) + + function t:m1() return self.x end + function t:m2() return self.y end + + return t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ( + toString(requireType("t"), opts), "t1 where t1 = {| __index: t1, m1: ({+ x: a, y: b +}) -> a, m2: ({+ x: a, y: b +}) -> b |}"); +} + TEST_CASE_FIXTURE(Fixture, "expected_indexer_value_type_extra") { CheckResult result = check(R"( @@ -2953,4 +3016,58 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_ty CHECK_EQ("Type '{number} | {| [boolean]: number |}' does not have key 'x'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(BuiltinsFixture, "quantify_metatables_of_metatables_of_table") +{ + ScopedFastFlag sff[]{ + {"DebugLuauSharedSelf", true}, + }; + + CheckResult result = check(R"( + local T = {} + + function T:m() + return self.x, self.y + end + + function T:n() + end + + local U = setmetatable({}, {__index = T}) + + local V = setmetatable({}, {__index = U}) + + return V + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ(toString(requireType("V"), opts), "{ @metatable { __index: { @metatable { __index: {| m: ({+ x: a, y: b +}) -> (a, b), n: ({+ x: a, y: b +}) -> () |} }, { } } }, { } }"); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_even_that_table_was_never_exported_at_all") +{ + ScopedFastFlag sff{"DebugLuauSharedSelf", true}; + + CheckResult result = check(R"( + local T = {} + + function T:m() + return self.x + end + + function T:n() + return self.y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + CHECK_EQ("{| m: ({+ x: a, y: b +}) -> a, n: ({+ x: a, y: b +}) -> b |}", toString(requireType("T"), opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6a048b2..efdfe0b 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -369,14 +369,14 @@ TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_scope_locals_do") +TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") { CheckResult result = check(R"( do local a = 1 end - print(a) -- oops! + local b = a -- oops! )"); LUAU_REQUIRE_ERROR_COUNT(1, result); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index f803c31..385a045 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -118,10 +118,12 @@ assert((function() return #_G end)() == 0) assert((function() return #{1,2} end)() == 2) assert((function() return #'g' end)() == 1) -assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) - assert((function() local a = 1 a = -a return a end)() == -1) +-- __len metamethod +assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) +assert((function() local t = {} setmetatable(t, { __len = function() return 42 end }) return #t end)() == 42) + -- while/repeat assert((function() local a = 10 local b = 1 while a > 1 do b = b * 2 a = a - 1 end return b end)() == 512) assert((function() local a = 10 local b = 1 repeat b = b * 2 a = a - 1 until a == 1 return b end)() == 512) @@ -889,6 +891,10 @@ assert((function() return table.concat(res, ',') end)() == "6,8,10") +-- typeof and type require an argument +assert(pcall(typeof) == false) +assert(pcall(type) == false) + -- typeof == type in absence of custom userdata assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 32e090a..42f1bed 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -386,4 +386,42 @@ do assert(t.X) -- fails if table flags are set incorrectly end +do + -- verify __len behavior & error handling + local t = {1} + + setmetatable(t, {}) + assert(#t == 1) + + setmetatable(t, { __len = rawlen }) + assert(#t == 1) + + setmetatable(t, { __len = function() return 42 end }) + assert(#t == 42) + + setmetatable(t, { __len = 42 }) + local ok, err = pcall(function() return #t end) + assert(not ok and err:match("attempt to call a number value")) + + setmetatable(t, { __len = function() end }) + local ok, err = pcall(function() return #t end) + assert(not ok and err:match("'__len' must return a number")) + + setmetatable(t, { __len = error }) + local ok, err = pcall(function() return #t end) + assert(not ok and err == t) +end + +-- verify rawlen behavior +do + local t = {1} + setmetatable(t, { __len = 42 }) + + assert(rawlen(t) == 1) + assert(rawlen("foo") == 3) + + local ok, err = pcall(function() return rawlen(42) end) + assert(not ok and err:match("table or string expected")) +end + return 'OK' From 6467c855e8fd4dff99ae868230d9be61991ccbd9 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 30 Jun 2022 17:07:56 -0700 Subject: [PATCH 02/10] Update compatibility.md (#566) Update `__len` metamethod (pending the code change that implements this) --- docs/_pages/compatibility.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_pages/compatibility.md b/docs/_pages/compatibility.md index d1686c2..bdb3d81 100644 --- a/docs/_pages/compatibility.md +++ b/docs/_pages/compatibility.md @@ -54,7 +54,7 @@ Sandboxing challenges are [covered in the dedicated section](sandbox). | goto statement | ❌ | this complicates the compiler, makes control flow unstructured and doesn't address a significant need | | finalizers for tables | ❌ | no `__gc` support due to sandboxing and performance/complexity | | no more fenv for threads or functions | 😞 | we love this, but it breaks compatibility | -| tables honor the `__len` metamethod | 🤷‍♀️ | performance implications, no strong use cases +| tables honor the `__len` metamethod | ✔️ | | | hex and `\z` escapes in strings | ✔️ | | | support for hexadecimal floats | 🤷‍♀️ | no strong use cases | | order metamethods work for different types | ❌ | no strong use cases and more complicated semantics, compatibility and performance implications | From 48aa7a51622d7eb8125a78cad549cdd384bcdb1e Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Mon, 4 Jul 2022 11:13:07 -0700 Subject: [PATCH 03/10] bench: Implement first class support for callgrind (#570) Since callgrind allows to control stats collection from the guest, this allows us to reset the collection right before the benchmark starts. This change exposes this to the benchmark runner and integrates callgrind data parsing into bench.py, so that we can run bench.py with --callgrind argument and, as long as the runner was built with callgrind support, we get instruction counts from the run. We convert instruction counts to seconds using 10G instructions/second rate; there's no correct way to do this without simulating the full CPU pipeline but it results in time units on a similar scale to real runs. --- .github/workflows/benchmark-dev.yml | 270 ++++++++++++++++++++++++++++ .github/workflows/benchmark.yml | 233 ++---------------------- CLI/Repl.cpp | 37 ++++ Makefile | 4 + bench/bench.py | 27 +++ bench/bench_support.lua | 10 ++ 6 files changed, 363 insertions(+), 218 deletions(-) create mode 100644 .github/workflows/benchmark-dev.yml diff --git a/.github/workflows/benchmark-dev.yml b/.github/workflows/benchmark-dev.yml new file mode 100644 index 0000000..21f9559 --- /dev/null +++ b/.github/workflows/benchmark-dev.yml @@ -0,0 +1,270 @@ +name: benchmark-dev + +on: + push: + branches: + - master + paths-ignore: + - "docs/**" + - "papers/**" + - "rfcs/**" + - "*.md" + - "prototyping/**" + +jobs: + windows: + name: windows-${{matrix.arch}} + strategy: + fail-fast: false + matrix: + os: [windows-latest] + arch: [Win32, x64] + bench: + - { + script: "run-benchmarks", + timeout: 12, + title: "Luau Benchmarks", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau repository + uses: actions/checkout@v3 + + - name: Build Luau + shell: bash # necessary for fail-fast + run: | + mkdir build && cd build + cmake .. -DCMAKE_BUILD_TYPE=Release + cmake --build . --target Luau.Repl.CLI --config Release + cmake --build . --target Luau.Analyze.CLI --config Release + + - name: Move build files to root + run: | + move build/Release/* . + + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + python -m pip install requests + python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Run benchmark + run: | + python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} (Windows ${{matrix.arch}}) + tool: "benchmarkluau" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Push benchmark results + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. + + unix: + name: ${{matrix.os}} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + bench: + - { + script: "run-benchmarks", + timeout: 12, + title: "Luau Benchmarks", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Luau repository + uses: actions/checkout@v3 + + - name: Build Luau + run: make config=release luau luau-analyze + + - uses: actions/setup-python@v3 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + python -m pip install requests + python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Run benchmark + run: | + python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt + + - name: Install valgrind + if: matrix.os == 'ubuntu-latest' + run: | + sudo apt-get install valgrind + + - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) + if: matrix.os == 'ubuntu-latest' + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) + if: matrix.os == 'ubuntu-latest' + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle }}" ${{ matrix.bench.cachegrindIterCount }} | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "benchmarkluau" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + + - name: Store ${{ matrix.bench.title }} result (CacheGrind) + if: matrix.os == 'ubuntu-latest' + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} (CacheGrind) + tool: "roblox" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + + - name: Push benchmark results + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. + + static-analysis: + name: luau-analyze + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + bench: + - { + script: "run-analyze", + timeout: 12, + title: "Luau Analyze", + cachegrindTitle: "Performance", + cachegrindIterCount: 20, + } + benchResultsRepo: + - { name: "luau-lang/benchmark-data", branch: "main" } + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + with: + token: "${{ secrets.BENCH_GITHUB_TOKEN }}" + + - name: Build Luau + run: make config=release luau luau-analyze + + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + architecture: "x64" + + - name: Install python dependencies + run: | + sudo pip install requests numpy scipy matplotlib ipython jupyter pandas sympy nose + + - name: Install valgrind + run: | + sudo apt-get install valgrind + + - name: Run Luau Analyze on static file + run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) + run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + + - name: Checkout Benchmark Results repository + uses: actions/checkout@v3 + with: + repository: ${{ matrix.benchResultsRepo.name }} + ref: ${{ matrix.benchResultsRepo.branch }} + token: ${{ secrets.BENCH_GITHUB_TOKEN }} + path: "./gh-pages" + + - name: Store ${{ matrix.bench.title }} result + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "benchmarkluau" + + gh-pages-branch: "main" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + + - name: Store ${{ matrix.bench.title }} result (CacheGrind) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: ${{ matrix.bench.title }} + tool: "roblox" + gh-pages-branch: "main" + output-file-path: ./${{ matrix.bench.script }}-output.txt + external-data-json-path: ./gh-pages/dev/bench/data.json + github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + + - name: Push benchmark results + if: github.event_name == 'push' + run: | + echo "Pushing benchmark results..." + cd gh-pages + git config user.name github-actions + git config user.email github@users.noreply.github.com + git add ./dev/bench/data.json + git commit -m "Add benchmarks results for ${{ github.sha }}" + git push + cd .. diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 4df7b2f..4829682 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -12,21 +12,13 @@ on: - "prototyping/**" jobs: - windows: - name: windows-${{matrix.arch}} + callgrind: + name: callgrind ${{ matrix.compiler }} strategy: fail-fast: false matrix: - os: [windows-latest] - arch: [Win32, x64] - bench: - - { - script: "run-benchmarks", - timeout: 12, - title: "Luau Benchmarks", - cachegrindTitle: "Performance", - cachegrindIterCount: 20, - } + os: [ubuntu-22.04] + compiler: [g++] benchResultsRepo: - { name: "luau-lang/benchmark-data", branch: "main" } @@ -35,200 +27,18 @@ jobs: - name: Checkout Luau repository uses: actions/checkout@v3 - - name: Build Luau - shell: bash # necessary for fail-fast - run: | - mkdir build && cd build - cmake .. -DCMAKE_BUILD_TYPE=Release - cmake --build . --target Luau.Repl.CLI --config Release - cmake --build . --target Luau.Analyze.CLI --config Release - - - name: Move build files to root - run: | - move build/Release/* . - - - uses: actions/setup-python@v3 - with: - python-version: "3.9" - architecture: "x64" - - - name: Install python dependencies - run: | - python -m pip install requests - python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose - - - name: Run benchmark - run: | - python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt - - - name: Checkout Benchmark Results repository - uses: actions/checkout@v3 - with: - repository: ${{ matrix.benchResultsRepo.name }} - ref: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - - - name: Store ${{ matrix.bench.title }} result - uses: Roblox/rhysd-github-action-benchmark@v-luau - with: - name: ${{ matrix.bench.title }} (Windows ${{matrix.arch}}) - tool: "benchmarkluau" - output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json - github-token: ${{ secrets.GITHUB_TOKEN }} - - - name: Push benchmark results - if: github.event_name == 'push' - run: | - echo "Pushing benchmark results..." - cd gh-pages - git config user.name github-actions - git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json - git commit -m "Add benchmarks results for ${{ github.sha }}" - git push - cd .. - - unix: - name: ${{matrix.os}} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, macos-latest] - bench: - - { - script: "run-benchmarks", - timeout: 12, - title: "Luau Benchmarks", - cachegrindTitle: "Performance", - cachegrindIterCount: 20, - } - benchResultsRepo: - - { name: "luau-lang/benchmark-data", branch: "main" } - - runs-on: ${{ matrix.os }} - steps: - - name: Checkout Luau repository - uses: actions/checkout@v3 - - - name: Build Luau - run: make config=release luau luau-analyze - - - uses: actions/setup-python@v3 - with: - python-version: "3.9" - architecture: "x64" - - - name: Install python dependencies - run: | - python -m pip install requests - python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose - - - name: Run benchmark - run: | - python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt - - - name: Install valgrind - if: matrix.os == 'ubuntu-latest' - run: | - sudo apt-get install valgrind - - - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) - if: matrix.os == 'ubuntu-latest' - run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 | tee -a ${{ matrix.bench.script }}-output.txt - - - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) - if: matrix.os == 'ubuntu-latest' - run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/bench.py "${{ matrix.bench.cachegrindTitle }}" ${{ matrix.bench.cachegrindIterCount }} | tee -a ${{ matrix.bench.script }}-output.txt - - - name: Checkout Benchmark Results repository - uses: actions/checkout@v3 - with: - repository: ${{ matrix.benchResultsRepo.name }} - ref: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - - - name: Store ${{ matrix.bench.title }} result - uses: Roblox/rhysd-github-action-benchmark@v-luau - with: - name: ${{ matrix.bench.title }} - tool: "benchmarkluau" - output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json - github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - - - name: Store ${{ matrix.bench.title }} result (CacheGrind) - if: matrix.os == 'ubuntu-latest' - uses: Roblox/rhysd-github-action-benchmark@v-luau - with: - name: ${{ matrix.bench.title }} (CacheGrind) - tool: "roblox" - output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json - github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - - - name: Push benchmark results - if: github.event_name == 'push' - run: | - echo "Pushing benchmark results..." - cd gh-pages - git config user.name github-actions - git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json - git commit -m "Add benchmarks results for ${{ github.sha }}" - git push - cd .. - - static-analysis: - name: luau-analyze - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest] - bench: - - { - script: "run-analyze", - timeout: 12, - title: "Luau Analyze", - cachegrindTitle: "Performance", - cachegrindIterCount: 20, - } - benchResultsRepo: - - { name: "luau-lang/benchmark-data", branch: "main" } - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v3 - with: - token: "${{ secrets.BENCH_GITHUB_TOKEN }}" - - - name: Build Luau - run: make config=release luau luau-analyze - - - uses: actions/setup-python@v4 - with: - python-version: "3.9" - architecture: "x64" - - - name: Install python dependencies - run: | - sudo pip install requests numpy scipy matplotlib ipython jupyter pandas sympy nose - - name: Install valgrind run: | sudo apt-get install valgrind - - name: Run Luau Analyze on static file - run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt + - name: Build Luau + run: CXX=${{ matrix.compiler }} make config=release CALLGRIND=1 luau - - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) - run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + - name: Run benchmark + run: | + python bench/bench.py --callgrind --vm "./luau -O2" | tee output.txt - - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) - run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt - - - name: Checkout Benchmark Results repository + - name: Checkout benchmark results uses: actions/checkout@v3 with: repository: ${{ matrix.benchResultsRepo.name }} @@ -236,26 +46,13 @@ jobs: token: ${{ secrets.BENCH_GITHUB_TOKEN }} path: "./gh-pages" - - name: Store ${{ matrix.bench.title }} result + - name: Store results uses: Roblox/rhysd-github-action-benchmark@v-luau with: - name: ${{ matrix.bench.title }} + name: callgrind ${{ matrix.compiler }} tool: "benchmarkluau" - - gh-pages-branch: "main" - output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json - github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - - - name: Store ${{ matrix.bench.title }} result (CacheGrind) - uses: Roblox/rhysd-github-action-benchmark@v-luau - with: - name: ${{ matrix.bench.title }} - tool: "roblox" - gh-pages-branch: "main" - output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json - github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} + output-file-path: ./output.txt + external-data-json-path: ./gh-pages/bench/data.json - name: Push benchmark results if: github.event_name == 'push' @@ -264,7 +61,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json + git add ./bench/data.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd .. diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 83060f5..5fe12be 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -21,6 +21,10 @@ #include #endif +#ifdef CALLGRIND +#include +#endif + #include LUAU_FASTFLAG(DebugLuauTimeTracing) @@ -166,6 +170,36 @@ static int lua_collectgarbage(lua_State* L) luaL_error(L, "collectgarbage must be called with 'count' or 'collect'"); } +#ifdef CALLGRIND +static int lua_callgrind(lua_State* L) +{ + const char* option = luaL_checkstring(L, 1); + + if (strcmp(option, "running") == 0) + { + int r = RUNNING_ON_VALGRIND; + lua_pushboolean(L, r); + return 1; + } + + if (strcmp(option, "zero") == 0) + { + CALLGRIND_ZERO_STATS; + return 0; + } + + if (strcmp(option, "dump") == 0) + { + const char* name = luaL_checkstring(L, 2); + + CALLGRIND_DUMP_STATS_AT(name); + return 0; + } + + luaL_error(L, "callgrind must be called with one of 'running', 'zero', 'dump'"); +} +#endif + void setupState(lua_State* L) { luaL_openlibs(L); @@ -174,6 +208,9 @@ void setupState(lua_State* L) {"loadstring", lua_loadstring}, {"require", lua_require}, {"collectgarbage", lua_collectgarbage}, +#ifdef CALLGRIND + {"callgrind", lua_callgrind}, +#endif {NULL, NULL}, }; diff --git a/Makefile b/Makefile index 1082666..b807789 100644 --- a/Makefile +++ b/Makefile @@ -93,6 +93,10 @@ ifeq ($(config),fuzz) LDFLAGS+=-fsanitize=address,fuzzer endif +ifneq ($(CALLGRIND),) + CXXFLAGS+=-DCALLGRIND=$(CALLGRIND) +endif + # target-specific flags $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include diff --git a/bench/bench.py b/bench/bench.py index 67fc8cf..e78e96a 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -40,6 +40,7 @@ argumentParser.add_argument('--results', dest='results',type=str,nargs='*',help= argumentParser.add_argument('--run-test', action='store', default=None, help='Regex test filter') argumentParser.add_argument('--extra-loops', action='store',type=int,default=0, help='Amount of times to loop over one test (one test already performs multiple runs)') argumentParser.add_argument('--filename', action='store',type=str,default='bench', help='File name for graph and results file') +argumentParser.add_argument('--callgrind', dest='callgrind',action='store_const',const=1,default=0,help='Use callgrind to run benchmarks') if matplotlib != None: argumentParser.add_argument('--absolute', dest='absolute',action='store_const',const=1,default=0,help='Display absolute values instead of relative (enabled by default when benchmarking a single VM)') @@ -55,6 +56,9 @@ argumentParser.add_argument('--no-print-influx-debugging', action='store_false', argumentParser.add_argument('--no-print-final-summary', action='store_false', dest='print_final_summary', help="Don't print a table summarizing the results after all tests are run") +# Assume 2.5 IPC on a 4 GHz CPU; this is obviously incorrect but it allows us to display simulated instruction counts using regular time units +CALLGRIND_INSN_PER_SEC = 2.5 * 4e9 + def arrayRange(count): result = [] @@ -71,6 +75,21 @@ def arrayRangeOffset(count, offset): return result +def getCallgrindOutput(lines): + result = [] + name = None + + for l in lines: + if l.startswith("desc: Trigger: Client Request: "): + name = l[31:].strip() + elif l.startswith("summary: ") and name != None: + insn = int(l[9:]) + # Note: we only run each bench once under callgrind so we only report a single time per run; callgrind instruction count variance is ~0.01% so it might as well be zero + result += "|><|" + name + "|><|" + str(insn / CALLGRIND_INSN_PER_SEC * 1000.0) + "||_||" + name = None + + return "".join(result) + def getVmOutput(cmd): if os.name == "nt": try: @@ -79,6 +98,14 @@ def getVmOutput(cmd): exit(1) except: return "" + elif arguments.callgrind: + try: + subprocess.check_call("valgrind --tool=callgrind --callgrind-out-file=callgrind.out --combine-dumps=yes --dump-line=no " + cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, cwd=scriptdir) + file = open(os.path.join(scriptdir, "callgrind.out"), "r") + lines = file.readlines() + return getCallgrindOutput(lines) + except: + return "" else: with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=scriptdir) as p: # Try to lock to a single processor diff --git a/bench/bench_support.lua b/bench/bench_support.lua index 171b8da..a9608ec 100644 --- a/bench/bench_support.lua +++ b/bench/bench_support.lua @@ -5,6 +5,16 @@ bench.runs = 20 bench.extraRuns = 4 function bench.runCode(f, description) + -- Under Callgrind, run the test only once and measure just the execution cost + if callgrind and callgrind("running") then + if collectgarbage then collectgarbage() end + + callgrind("zero") + f() -- unfortunately we can't easily separate setup cost from runtime cost in f unless it calls callgrind() + callgrind("dump", description) + return + end + local timeTable = {} for i = 1,bench.runs + bench.extraRuns do From 2460e389989558f1db89eceb15c9267a4df58704 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Tue, 5 Jul 2022 08:23:09 -0700 Subject: [PATCH 04/10] bench: Implement luau-analyze and luau --compile benchmarks (#575) This change adds another file for benchmarking luau-analyze and sets up benchmarks for both non-strict/strict modes for analysis and all three optimization levels for compilation performance. To avoid issues with race conditions on repository update we do all this in the same job in benchmark.yml. To be able to benchmark both modes from a single file, luau-analyze gains --mode argument which allows to override the default typechecking mode. Not sure if we'll want this to be a hard override on top of the module-specified mode in the future, but this works for now. --- .github/workflows/benchmark-dev.yml | 6 +- .github/workflows/benchmark.yml | 52 +- CLI/Analyze.cpp | 18 +- .../LuauPolyfillMap.lua | 1 - bench/other/regex.lua | 2089 +++++++++++++++++ 5 files changed, 2152 insertions(+), 14 deletions(-) rename bench/{static_analysis => other}/LuauPolyfillMap.lua (99%) create mode 100644 bench/other/regex.lua diff --git a/.github/workflows/benchmark-dev.yml b/.github/workflows/benchmark-dev.yml index 21f9559..03ff9d6 100644 --- a/.github/workflows/benchmark-dev.yml +++ b/.github/workflows/benchmark-dev.yml @@ -220,13 +220,13 @@ jobs: sudo apt-get install valgrind - name: Run Luau Analyze on static file - run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt + run: sudo python ./bench/measure_time.py ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee ${{ matrix.bench.script }}-output.txt - name: Run ${{ matrix.bench.title }} (Cold Cachegrind) - run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + run: sudo ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}Cold" 1 ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt - name: Run ${{ matrix.bench.title }} (Warm Cachegrind) - run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/static_analysis/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt + run: sudo bash ./scripts/run-with-cachegrind.sh python ./bench/measure_time.py "${{ matrix.bench.cachegrindTitle}}" 1 ./build/release/luau-analyze bench/other/LuauPolyfillMap.lua | tee -a ${{ matrix.bench.script }}-output.txt - name: Checkout Benchmark Results repository uses: actions/checkout@v3 diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 4829682..9d26186 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -32,11 +32,33 @@ jobs: sudo apt-get install valgrind - name: Build Luau - run: CXX=${{ matrix.compiler }} make config=release CALLGRIND=1 luau + run: CXX=${{ matrix.compiler }} make config=release CALLGRIND=1 luau luau-analyze - - name: Run benchmark + - name: Run benchmark (bench) run: | - python bench/bench.py --callgrind --vm "./luau -O2" | tee output.txt + python bench/bench.py --callgrind --vm "./luau -O2" | tee -a bench-output.txt + + - name: Run benchmark (analyze) + run: | + filter() { + awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau-analyze"}' + } + valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt + + - name: Run benchmark (compile) + run: | + filter() { + awk '/.*I\s+refs:\s+[0-9,]+/ {gsub(",", "", $4); X=$4} END {print "SUCCESS: '$1' : " X/1e7 "ms +/- 0% on luau --compile"}' + } + valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O0 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O1 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-O2 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O0 bench/other/regex.lua 2>&1 | filter regex-O0 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O1 bench/other/regex.lua 2>&1 | filter regex-O1 | tee -a compile-output.txt + valgrind --tool=callgrind ./luau --compile=null -O2 bench/other/regex.lua 2>&1 | filter regex-O2 | tee -a compile-output.txt - name: Checkout benchmark results uses: actions/checkout@v3 @@ -46,13 +68,29 @@ jobs: token: ${{ secrets.BENCH_GITHUB_TOKEN }} path: "./gh-pages" - - name: Store results + - name: Store results (bench) uses: Roblox/rhysd-github-action-benchmark@v-luau with: name: callgrind ${{ matrix.compiler }} tool: "benchmarkluau" - output-file-path: ./output.txt - external-data-json-path: ./gh-pages/bench/data.json + output-file-path: ./bench-output.txt + external-data-json-path: ./gh-pages/bench.json + + - name: Store results (analyze) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: luau-analyze + tool: "benchmarkluau" + output-file-path: ./analyze-output.txt + external-data-json-path: ./gh-pages/analyze.json + + - name: Store results (compile) + uses: Roblox/rhysd-github-action-benchmark@v-luau + with: + name: luau --compile + tool: "benchmarkluau" + output-file-path: ./compile-output.txt + external-data-json-path: ./gh-pages/compile.json - name: Push benchmark results if: github.event_name == 'push' @@ -61,7 +99,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./bench/data.json + git add *.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd .. diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 81db7c3..f07f9a0 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -8,6 +8,10 @@ #include "FileUtils.h" +#ifdef CALLGRIND +#include +#endif + LUAU_FASTFLAG(DebugLuauTimeTracing) LUAU_FASTFLAG(LuauTypeMismatchModuleNameResolution) @@ -112,6 +116,7 @@ static void displayHelp(const char* argv0) printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --mode=strict: default to strict mode when typechecking\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -178,9 +183,9 @@ struct CliConfigResolver : Luau::ConfigResolver mutable std::unordered_map configCache; mutable std::vector> configErrors; - CliConfigResolver() + CliConfigResolver(Luau::Mode mode) { - defaultConfig.mode = Luau::Mode::Nonstrict; + defaultConfig.mode = mode; } const Luau::Config& getConfig(const Luau::ModuleName& name) const override @@ -229,6 +234,7 @@ int main(int argc, char** argv) } ReportFormat format = ReportFormat::Default; + Luau::Mode mode = Luau::Mode::Nonstrict; bool annotate = false; for (int i = 1; i < argc; ++i) @@ -240,6 +246,8 @@ int main(int argc, char** argv) format = ReportFormat::Luacheck; else if (strcmp(argv[i], "--formatter=gnu") == 0) format = ReportFormat::Gnu; + else if (strcmp(argv[i], "--mode=strict") == 0) + mode = Luau::Mode::Strict; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; else if (strcmp(argv[i], "--timetrace") == 0) @@ -258,12 +266,16 @@ int main(int argc, char** argv) frontendOptions.retainFullTypeGraphs = annotate; CliFileResolver fileResolver; - CliConfigResolver configResolver; + CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); +#ifdef CALLGRIND + CALLGRIND_ZERO_STATS; +#endif + std::vector files = getSourceFiles(argc, argv); int failed = 0; diff --git a/bench/static_analysis/LuauPolyfillMap.lua b/bench/other/LuauPolyfillMap.lua similarity index 99% rename from bench/static_analysis/LuauPolyfillMap.lua rename to bench/other/LuauPolyfillMap.lua index 1cfd018..1f957d4 100644 --- a/bench/static_analysis/LuauPolyfillMap.lua +++ b/bench/other/LuauPolyfillMap.lua @@ -1,5 +1,4 @@ -- This file is part of the Roblox luau-polyfill repository and is licensed under MIT License; see LICENSE.txt for details ---!nonstrict -- #region Array -- Array related local Array = {} diff --git a/bench/other/regex.lua b/bench/other/regex.lua new file mode 100644 index 0000000..270ab3d --- /dev/null +++ b/bench/other/regex.lua @@ -0,0 +1,2089 @@ +--[[ + PCRE2-based RegEx implemention for Luau + Version 1.0.0a2 + BSD 2-Clause Licence + Copyright © 2020 - Blockzez (devforum.roblox.com/u/Blockzez and github.com/Blockzez) + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +]] +--[[ Settings ]]-- +-- You can change them here +local options = { + -- The maximum cache size for regex so the patterns are cached so it doesn't recompile the pattern + -- The only accepted value are number values >= 0, strings that can be automatically coered to numbers that are >= 0, false and nil + -- Do note that empty regex patterns (comment-only patterns included) are never cached regardless + -- The default is 256 + cacheSize = 256, + + -- A boolean that determines whether this use unicode data + -- If this value evalulates to false, you can remove _unicodechar_category, _scripts and _xuc safely and it'll now error if: + -- - You try to compile a RegEx with unicode flag + -- - You try to use the \p pattern + -- The default is true + unicodeData = false, +}; + +-- +local u_categories = options.unicodeData and require(script:WaitForChild("_unicodechar_category")); +local chr_scripts = options.unicodeData and require(script:WaitForChild("_scripts")); +local xuc_chr = options.unicodeData and require(script:WaitForChild("_xuc")); +local proxy = setmetatable({ }, { __mode = 'k' }); +local re, re_m, match_m = { }, { }, { }; +local lockmsg; + +--[[ Functions ]]-- +local function to_str_arr(self, init) + if init then + self = string.sub(self, utf8.offset(self, init)); + end; + local len = utf8.len(self); + if len <= 1999 then + return { n = len, s = self, utf8.codepoint(self, 1, #self) }; + end; + local clen = math.ceil(len / 1999); + local ret = table.create(len); + local p = 1; + for i = 1, clen do + local c = table.pack(utf8.codepoint(self, utf8.offset(self, i * 1999 - 1998), utf8.offset(self, i * 1999 - (i == clen and 1998 - ((len - 1) % 1999 + 1) or - 1)) - 1)); + table.move(c, 1, c.n, p, ret); + p += c.n; + end; + ret.s, ret.n = self, len; + return ret; +end; + +local function from_str_arr(self) + local len = self.n or #self; + if len <= 7997 then + return utf8.char(table.unpack(self)); + end; + local clen = math.ceil(len / 7997); + local r = table.create(clen); + for i = 1, clen do + r[i] = utf8.char(table.unpack(self, i * 7997 - 7996, i * 7997 - (i == clen and 7997 - ((len - 1) % 7997 + 1) or 0))); + end; + return table.concat(r); +end; + +local function utf8_sub(self, i, j) + j = utf8.offset(self, j); + return string.sub(self, utf8.offset(self, i), j and j - 1); +end; + +-- +local flag_map = { + a = 'anchored', i = 'caseless', m = 'multiline', s = 'dotall', u = 'unicode', U = 'ungreedy', x ='extended', +}; + +local posix_class_names = { + alnum = true, alpha = true, ascii = true, blank = true, cntrl = true, digit = true, graph = true, lower = true, print = true, punct = true, space = true, upper = true, word = true, xdigit = true, +}; + +local escape_chars = { + -- grouped + -- digit, spaces and words + [0x44] = { "class", "digit", true }, [0x53] = { "class", "space", true }, [0x57] = { "class", "word", true }, + [0x64] = { "class", "digit", false }, [0x73] = { "class", "space", false }, [0x77] = { "class", "word", false }, + -- horizontal/vertical whitespace and newline + [0x48] = { "class", "blank", true }, [0x56] = { "class", "vertical_tab", true }, + [0x68] = { "class", "blank", false }, [0x76] = { "class", "vertical_tab", false }, + [0x4E] = { 0x4E }, [0x52] = { 0x52 }, + + -- not grouped + [0x42] = 0x08, + [0x6E] = 0x0A, [0x72] = 0x0D, [0x74] = 0x09, +}; + +local b_escape_chars = { + -- word boundary and not word boundary + [0x62] = { 0x62, { "class", "word", false } }, [0x42] = { 0x42, { "class", "word", false } }, + + -- keep match out + [0x4B] = { 0x4B }, + + -- start & end of string + [0x47] = { 0x47 }, [0x4A] = { 0x4A }, [0x5A] = { 0x5A }, [0x7A] = { 0x7A }, +}; + +local valid_categories = { + C = true, Cc = true, Cf = true, Cn = true, Co = true, Cs = true, + L = true, Ll = true, Lm = true, Lo = true, Lt = true, Lu = true, + M = true, Mc = true, Me = true, Mn = true, + N = true, Nd = true, Nl = true, No = true, + P = true, Pc = true, Pd = true, Pe = true, Pf = true, Pi = true, Po = true, Ps = true, + S = true, Sc = true, Sk = true, Sm = true, So = true, + Z = true, Zl = true, Zp = true, Zs = true, + + Xan = true, Xps = true, Xsp = true, Xuc = true, Xwd = true, +}; + +local class_ascii_punct = { + [0x21] = true, [0x22] = true, [0x23] = true, [0x24] = true, [0x25] = true, [0x26] = true, [0x27] = true, [0x28] = true, [0x29] = true, [0x2A] = true, [0x2B] = true, [0x2C] = true, [0x2D] = true, [0x2E] = true, [0x2F] = true, + [0x3A] = true, [0x3B] = true, [0x3C] = true, [0x3D] = true, [0x3E] = true, [0x3F] = true, [0x40] = true, [0x5B] = true, [0x5C] = true, [0x5D] = true, [0x5E] = true, [0x5F] = true, [0x60] = true, [0x7B] = true, [0x7C] = true, + [0x7D] = true, [0x7E] = true, +}; + +local end_str = { 0x24 }; +local dot = { 0x2E }; +local beginning_str = { 0x5E }; +local alternation = { 0x7C }; + +local function check_re(re_type, name, func) + if re_type == "Match" then + return function(...) + local arg_n = select('#', ...); + if arg_n < 1 then + error("missing argument #1 (Match expected)", 2); + end; + local arg0, arg1 = ...; + if not (proxy[arg0] and proxy[arg0].name == "Match") then + error(string.format("invalid argument #1 to %q (Match expected, got %s)", name, typeof(arg0)), 2); + else + arg0 = proxy[arg0]; + end; + if name == "group" or name == "span" then + if arg1 == nil then + arg1 = 0; + end; + end; + return func(arg0, arg1); + end; + end; + return function(...) + local arg_n = select('#', ...); + if arg_n < 1 then + error("missing argument #1 (RegEx expected)", 2); + elseif arg_n < 2 then + error("missing argument #2 (string expected)", 2); + end; + local arg0, arg1, arg2, arg3, arg4, arg5 = ...; + if not (proxy[arg0] and proxy[arg0].name == "RegEx") then + if type(arg0) ~= "string" and type(arg0) ~= "number" then + error(string.format("invalid argument #1 to %q (RegEx expected, got %s)", name, typeof(arg0)), 2); + end; + arg0 = re.fromstring(arg0); + elseif name == "sub" then + if type(arg2) == "number" then + arg2 ..= ''; + elseif type(arg2) ~= "string" then + error(string.format("invalid argument #3 to 'sub' (string expected, got %s)", typeof(arg2)), 2); + end; + elseif type(arg1) == "number" then + arg1 ..= ''; + elseif type(arg1) ~= "string" then + error(string.format("invalid argument #2 to %q (string expected, got %s)", name, typeof(arg1)), 2); + end; + if name ~= "sub" and name ~= "split" then + local init_type = typeof(arg2); + if init_type ~= 'nil' then + arg2 = tonumber(arg2); + if not arg2 then + error(string.format("invalid argument #3 to %q (number expected, got %s)", name, init_type), 2); + elseif arg2 < 0 then + arg2 = #arg1 + math.floor(arg2 + 0.5) + 1; + else + arg2 = math.max(math.floor(arg2 + 0.5), 1); + end; + end; + end; + arg0 = proxy[arg0]; + if name == "match" or name == "matchiter" then + arg3 = ...; + elseif name == "sub" then + arg5 = ...; + end; + return func(arg0, arg1, arg2, arg3, arg4, arg5); + end; +end; + +--[[ Matches ]]-- +local function match_tostr(self) + local spans = proxy[self].spans; + local s_start, s_end = spans[0][1], spans[0][2]; + if s_end <= s_start then + return string.format("Match (%d..%d, empty)", s_start, s_end - 1); + end; + return string.format("Match (%d..%d): %s", s_start, s_end - 1, utf8_sub(spans.input, s_start, s_end)); +end; + +local function new_match(span_arr, group_id, re, str) + span_arr.source, span_arr.input = re, str; + local object = newproxy(true); + local object_mt = getmetatable(object); + object_mt.__metatable = lockmsg; + object_mt.__index = setmetatable(span_arr, match_m); + object_mt.__tostring = match_tostr; + + proxy[object] = { name = "Match", spans = span_arr, group_id = group_id }; + return object; +end; + +match_m.group = check_re('Match', 'group', function(self, group_id) + local span = self.spans[type(group_id) == "number" and group_id or self.group_id[group_id]]; + if not span then + return nil; + end; + return utf8_sub(self.spans.input, span[1], span[2]); +end); + +match_m.span = check_re('Match', 'span', function(self, group_id) + local span = self.spans[type(group_id) == "number" and group_id or self.group_id[group_id]]; + if not span then + return nil; + end; + return span[1], span[2] - 1; +end); + +match_m.groups = check_re('Match', 'groups', function(self) + local spans = self.spans; + if spans.n > 0 then + local ret = table.create(spans.n); + for i = 0, spans.n do + local v = spans[i]; + if v then + ret[i] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + return table.unpack(ret, 1, spans.n); + end; + return utf8_sub(spans.input, spans[0][1], spans[0][2]); +end); + +match_m.groupdict = check_re('Match', 'groupdict', function(self) + local spans = self.spans; + local ret = { }; + for k, v in pairs(self.group_id) do + v = spans[v]; + if v then + ret[k] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + return ret; +end); + +match_m.grouparr = check_re('Match', 'groupdict', function(self) + local spans = self.spans; + local ret = table.create(spans.n); + for i = 0, spans.n do + local v = spans[i]; + if v then + ret[i] = utf8_sub(spans.input, v[1], v[2]); + end; + end; + ret.n = spans.n; + return ret; +end); + +-- +local line_verbs = { + CR = 0, LF = 1, CRLF = 2, ANYRLF = 3, ANY = 4, NUL = 5, +}; +local function is_newline(str_arr, i, verb_flags) + local line_verb_n = verb_flags.newline; + local chr = str_arr[i]; + if line_verb_n == 0 then + -- carriage return + return chr == 0x0D; + elseif line_verb_n == 2 then + -- carriage return followed by line feed + return chr == 0x0A and str_arr[i - 1] == 0x20; + elseif line_verb_n == 3 then + -- any of the above + return chr == 0x0A or chr == 0x0D; + elseif line_verb_n == 4 then + -- any of Unicode newlines + return chr == 0x0A or chr == 0x0B or chr == 0x0C or chr == 0x0D or chr == 0x85 or chr == 0x2028 or chr == 0x2029; + elseif line_verb_n == 5 then + -- null + return chr == 0; + end; + -- linefeed + return chr == 0x0A; +end; + + +local function tkn_char_match(tkn_part, str_arr, i, flags, verb_flags) + local chr = str_arr[i]; + if not chr then + return false; + elseif flags.ignoreCase and chr >= 0x61 and chr <= 0x7A then + chr -= 0x20; + end; + if type(tkn_part) == "number" then + return tkn_part == chr; + elseif tkn_part[1] == "charset" then + for _, v in ipairs(tkn_part[3]) do + if tkn_char_match(v, str_arr, i, flags, verb_flags) then + return not tkn_part[2]; + end; + end; + return tkn_part[2]; + elseif tkn_part[1] == "range" then + return chr >= tkn_part[2] and chr <= tkn_part[3] or flags.ignoreCase and chr >= 0x41 and chr <= 0x5A and (chr + 0x20) >= tkn_part[2] and (chr + 0x20) <= tkn_part[3]; + elseif tkn_part[1] == "class" then + local char_class = tkn_part[2]; + local negate = tkn_part[3]; + local match = false; + -- if and elseifs :( + -- Might make these into tables in the future + if char_class == "xdigit" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x46 or chr >= 0x61 and chr <= 0x66; + elseif char_class == "ascii" then + match = chr <= 0x7F; + -- cannot be accessed through POSIX classes + elseif char_class == "vertical_tab" then + match = chr >= 0x0A and chr <= 0x0D or chr == 0x2028 or chr == 0x2029; + -- + elseif flags.unicode then + local current_category = u_categories[chr] or 'Cn'; + local first_category = current_category:sub(1, 1); + if char_class == "alnum" then + match = first_category == 'L' or current_category == 'Nl' or current_category == 'Nd'; + elseif char_class == "alpha" then + match = first_category == 'L' or current_category == 'Nl'; + elseif char_class == "blank" then + match = current_category == 'Zs' or chr == 0x09; + elseif char_class == "cntrl" then + match = current_category == 'Cc'; + elseif char_class == "digit" then + match = current_category == 'Nd'; + elseif char_class == "graph" then + match = first_category ~= 'P' and first_category ~= 'C'; + elseif char_class == "lower" then + match = current_category == 'Ll'; + elseif char_class == "print" then + match = first_category ~= 'C'; + elseif char_class == "punct" then + match = first_category == 'P'; + elseif char_class == "space" then + match = first_category == 'Z' or chr >= 0x09 and chr <= 0x0D; + elseif char_class == "upper" then + match = current_category == 'Lu'; + elseif char_class == "word" then + match = first_category == 'L' or current_category == 'Nl' or current_category == 'Nd' or current_category == 'Pc'; + end; + elseif char_class == "alnum" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A; + elseif char_class == "alpha" then + match = chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A; + elseif char_class == "blank" then + match = chr == 0x09 or chr == 0x20; + elseif char_class == "cntrl" then + match = chr <= 0x1F or chr == 0x7F; + elseif char_class == "digit" then + match = chr >= 0x30 and chr <= 0x39; + elseif char_class == "graph" then + match = chr >= 0x21 and chr <= 0x7E; + elseif char_class == "lower" then + match = chr >= 0x61 and chr <= 0x7A; + elseif char_class == "print" then + match = chr >= 0x20 and chr <= 0x7E; + elseif char_class == "punct" then + match = class_ascii_punct[chr]; + elseif char_class == "space" then + match = chr >= 0x09 and chr <= 0x0D or chr == 0x20; + elseif char_class == "upper" then + match = chr >= 0x41 and chr <= 0x5A; + elseif char_class == "word" then + match = chr >= 0x30 and chr <= 0x39 or chr >= 0x41 and chr <= 0x5A or chr >= 0x61 and chr <= 0x7A or chr == 0x5F; + end; + if negate then + return not match; + end; + return match; + elseif tkn_part[1] == "category" then + local chr_category = u_categories[chr] or 'Cn'; + local category_v = tkn_part[3]; + local category_len = #category_v; + if category_len == 3 then + local match = false; + if category_v == "Xan" or category_v == "Xwd" then + match = chr_category:find("^[LN]") or category_v == "Xwd" and chr == 0x5F; + elseif category_v == "Xps" or category_v == "Xsp" then + match = chr_category:sub(1, 1) == 'Z' or chr >= 0x09 and chr <= 0x0D; + elseif category_v == "Xuc" then + match = tkn_char_match(xuc_chr, str_arr, i, flags, verb_flags); + end; + if tkn_part[2] then + return not match; + end + return match; + elseif chr_category:sub(1, category_len) == category_v then + return not tkn_part[2]; + end; + return tkn_part[2]; + elseif tkn_part[1] == 0x2E then + return flags.dotAll or not is_newline(str_arr, i, verb_flags); + elseif tkn_part[1] == 0x4E then + return not is_newline(str_arr, i, verb_flags); + elseif tkn_part[1] == 0x52 then + if verb_flags.newline_seq == 0 then + -- CR, LF or CRLF + return chr == 0x0A or chr == 0x0D; + end; + -- any unicode newline + return chr == 0x0A or chr == 0x0B or chr == 0x0C or chr == 0x0D or chr == 0x85 or chr == 0x2028 or chr == 0x2029; + end; + return false; +end; + +local function find_alternation(token, i, count) + while true do + local v = token[i]; + local is_table = type(v) == "table"; + if v == alternation then + return i, count; + elseif is_table and v[1] == 0x28 then + if count then + count += v.count; + end; + i = v[3]; + elseif is_table and v[1] == "quantifier" and type(v[5]) == "table" and v[5][1] == 0x28 then + if count then + count += v[5].count; + end; + i = v[5][3]; + elseif not v or is_table and v[1] == 0x29 then + return nil, count; + elseif count then + if is_table and v[1] == "quantifier" then + count += v[3]; + else + count += 1; + end; + end; + i += 1; + end; +end; + +local function re_rawfind(token, str_arr, init, flags, verb_flags, as_bool) + local tkn_i, str_i, start_i = 0, init, init; + local states = { }; + while tkn_i do + if tkn_i == 0 then + tkn_i += 1; + local next_alt = find_alternation(token, tkn_i); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + continue; + end; + local ctkn = token[tkn_i]; + local tkn_type = type(ctkn) == "table" and ctkn[1]; + if not ctkn then + break; + elseif ctkn == "ACCEPT" then + local not_lookaround = true; + local close_i = tkn_i; + repeat + close_i += 1; + local is_table = type(token[close_i]) == "table"; + local close_i_tkn = token[close_i]; + if is_table and (close_i_tkn[1] == 0x28 or close_i_tkn[1] == "quantifier" and type(close_i_tkn[5]) == "table" and close_i_tkn[5][1] == 0x28) then + close_i = close_i_tkn[1] == "quantifier" and close_i_tkn[5][3] or close_i_tkn[3]; + elseif is_table and close_i_tkn[1] == 0x29 and (close_i_tkn[4] == 0x21 or close_i_tkn[4] == 0x3D) then + not_lookaround = false; + tkn_i = close_i; + break; + end; + until not close_i_tkn; + if not_lookaround then + break; + end; + elseif ctkn == "PRUNE" or ctkn == "SKIP" then + table.insert(states, 1, { ctkn, str_i }); + tkn_i += 1; + elseif tkn_type == 0x28 then + table.insert(states, 1, { "group", tkn_i, str_i, nil, ctkn[2], ctkn[3], ctkn[4] }); + tkn_i += 1; + local next_alt, count = find_alternation(token, tkn_i, (ctkn[4] == 0x21 or ctkn[4] == 0x3D) and ctkn[5] and 0); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + if count then + str_i -= count; + end; + elseif tkn_type == 0x29 and ctkn[4] ~= 0x21 then + if ctkn[4] == 0x21 or ctkn[4] == 0x3D then + while true do + local selected_match_start; + local selected_state = table.remove(states, 1); + if selected_state[1] == "group" and selected_state[2] == ctkn[3] then + if (ctkn[4] == 0x21 or ctkn[4] == 0x3D) and not ctkn[5] then + str_i = selected_state[3]; + end; + if selected_match_start then + table.insert(states, 1, selected_match_start); + end; + break; + elseif selected_state[1] == "matchStart" and not selected_match_start and ctkn[4] == 0x3D then + selected_match_start = selected_state; + end; + end; + elseif ctkn[4] == 0x3E then + repeat + local selected_state = table.remove(states, 1); + until not selected_state or selected_state[1] == "group" and selected_state[2] == ctkn[3]; + else + for i, v in ipairs(states) do + if v[1] == "group" and v[2] == ctkn[3] then + if v.jmp then + -- recursive match + tkn_i = v.jmp; + end; + v[4] = str_i; + if v[7] == "quantifier" and v[10] + 1 < v[9] then + if token[ctkn[3]][4] ~= "lazy" or v[10] + 1 < v[8] then + tkn_i = ctkn[3]; + end; + local ctkn1 = token[ctkn[3]]; + local new_group = { "group", v[2], str_i, nil, ctkn1[5][2], ctkn1[5][3], "quantifier", ctkn1[2], ctkn1[3], v[10] + 1, v[11], ctkn1[4] }; + table.insert(states, 1, new_group); + if v[11] then + table.insert(states, 1, { "alternation", v[11], str_i }); + end; + end; + break; + end; + end; + end; + tkn_i += 1; + elseif tkn_type == 0x4B then + table.insert(states, 1, { "matchStart", str_i }); + tkn_i += 1; + elseif tkn_type == 0x7C then + local close_i = tkn_i; + repeat + close_i += 1; + local is_table = type(token[close_i]) == "table"; + local close_i_tkn = token[close_i]; + if is_table and (close_i_tkn[1] == 0x28 or close_i_tkn[1] == "quantifier" and type(close_i_tkn[5]) == "table" and close_i_tkn[5][1] == 0x28) then + close_i = close_i_tkn[1] == "quantifier" and close_i_tkn[5][3] or close_i_tkn[3]; + end; + until is_table and close_i_tkn[1] == 0x29 or not close_i_tkn; + if token[close_i] then + for _, v in ipairs(states) do + if v[1] == "group" and v[6] == close_i then + tkn_i = v[6]; + break; + end; + end; + else + tkn_i = close_i; + end; + elseif tkn_type == "recurmatch" then + table.insert(states, 1, { "group", ctkn[3], str_i, nil, nil, token[ctkn[3]][3], nil, jmp = tkn_i }); + tkn_i = ctkn[3] + 1; + local next_alt, count = find_alternation(token, tkn_i); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + else + local match; + if ctkn == "FAIL" then + match = false; + elseif tkn_type == 0x29 then + repeat + local selected_state = table.remove(states, 1); + until selected_state[1] == "group" and selected_state[2] == ctkn[3]; + elseif tkn_type == "quantifier" then + if type(ctkn[5]) == "table" and ctkn[5][1] == 0x28 then + local next_alt = find_alternation(token, tkn_i + 1); + if next_alt then + table.insert(states, 1, { "alternation", next_alt, str_i }); + end; + table.insert(states, next_alt and 2 or 1, { "group", tkn_i, str_i, nil, ctkn[5][2], ctkn[5][3], "quantifier", ctkn[2], ctkn[3], 0, next_alt, ctkn[4] }); + if ctkn[4] == "lazy" and ctkn[2] == 0 then + tkn_i = ctkn[5][3]; + end; + match = true; + else + local start_i, end_i; + local pattern_count = 1; + local is_backref = type(ctkn[5]) == "table" and ctkn[5][1] == "backref"; + if is_backref then + pattern_count = 0; + local group_n = ctkn[5][2]; + for _, v in ipairs(states) do + if v[1] == "group" and v[5] == group_n then + start_i, end_i = v[3], v[4]; + pattern_count = end_i - start_i; + break; + end; + end; + end; + local min_max_i = str_i + ctkn[2] * pattern_count; + local mcount = 0; + while mcount < ctkn[3] do + if is_backref then + if start_i and end_i then + local org_i = str_i; + if utf8_sub(str_arr.s, start_i, end_i) ~= utf8_sub(str_arr.s, org_i, str_i + pattern_count) then + break; + end; + else + break; + end; + elseif not tkn_char_match(ctkn[5], str_arr, str_i, flags, verb_flags) then + break; + end; + str_i += pattern_count; + mcount += 1; + end; + match = mcount >= ctkn[2]; + if match and ctkn[4] ~= "possessive" then + if ctkn[4] == "lazy" then + min_max_i, str_i = str_i, min_max_i; + end; + table.insert(states, 1, { "quantifier", tkn_i, str_i, math.min(min_max_i, str_arr.n + 1), (ctkn[4] == "lazy" and 1 or -1) * pattern_count }); + end; + end; + elseif tkn_type == "backref" then + local start_i, end_i; + local group_n = ctkn[2]; + for _, v in ipairs(states) do + if v[1] == "group" and v[5] == group_n then + start_i, end_i = v[3], v[4]; + break; + end; + end; + if start_i and end_i then + local org_i = str_i; + str_i += end_i - start_i; + match = utf8_sub(str_arr.s, start_i, end_i) == utf8_sub(str_arr.s, org_i, str_i); + end; + else + local chr = str_arr[str_i]; + if tkn_type == 0x24 or tkn_type == 0x5A or tkn_type == 0x7A then + match = str_i == str_arr.n + 1 or tkn_type == 0x24 and flags.multiline and is_newline(str_arr, str_i + 1, verb_flags) or tkn_type == 0x5A and str_i == str_arr.n and is_newline(str_arr, str_i, verb_flags); + elseif tkn_type == 0x5E or tkn_type == 0x41 or tkn_type == 0x47 then + match = str_i == 1 or tkn_type == 0x5E and flags.multiline and is_newline(str_arr, str_i - 1, verb_flags) or tkn_type == 0x47 and str_i == init; + elseif tkn_type == 0x42 or tkn_type == 0x62 then + local start_m = str_i == 1 or flags.multiline and is_newline(str_arr, str_i - 1, verb_flags); + local end_m = str_i == str_arr.n + 1 or flags.multiline and is_newline(str_arr, str_i, verb_flags); + local w_m = tkn_char_match(ctkn[2], str_arr[str_i - 1], flags) and 0 or tkn_char_match(ctkn[2], chr, flags) and 1; + if w_m == 0 then + match = end_m or not tkn_char_match(ctkn[2], chr, flags); + elseif w_m then + match = start_m or not tkn_char_match(ctkn[2], str_arr[str_i - 1], flags); + end; + if tkn_type == 0x42 then + match = not match; + end; + else + match = tkn_char_match(ctkn, str_arr, str_i, flags, verb_flags); + str_i += 1; + end; + end; + if not match then + while true do + local prev_type, prev_state = states[1] and states[1][1], states[1]; + if not prev_type or prev_type == "PRUNE" or prev_type == "SKIP" then + if prev_type then + table.clear(states); + end; + if start_i > str_arr.n then + if as_bool then + return false; + end; + return nil; + end; + start_i = prev_type == "SKIP" and prev_state[2] or start_i + 1; + tkn_i, str_i = 0, start_i; + break; + elseif prev_type == "alternation" then + tkn_i, str_i = prev_state[2], prev_state[3]; + local next_alt, count = find_alternation(token, tkn_i + 1); + if next_alt then + prev_state[2] = next_alt; + else + table.remove(states, 1); + end; + if count then + str_i -= count; + end; + break; + elseif prev_type == "group" then + if prev_state[7] == "quantifier" then + if prev_state[12] == "greedy" and prev_state[10] >= prev_state[8] + or prev_state[12] == "lazy" and prev_state[10] < prev_state[9] and not prev_state[13] then + tkn_i, str_i = prev_state[12] == "greedy" and prev_state[6] or prev_state[2], prev_state[3]; + if prev_state[12] == "greedy" then + table.remove(states, 1); + break; + elseif prev_state[10] >= prev_state[8] then + prev_state[13] = true; + break; + end; + end; + elseif prev_state[7] == 0x21 then + table.remove(states, 1); + tkn_i, str_i = prev_state[6], prev_state[3]; + break; + end; + elseif prev_type == "quantifier" then + if math.sign(prev_state[4] - prev_state[3]) == math.sign(prev_state[5]) then + prev_state[3] += prev_state[5]; + tkn_i, str_i = prev_state[2], prev_state[3]; + break; + end; + end; + -- keep match out state and recursive state, can be safely removed + -- prevents infinite loop + table.remove(states, 1); + end; + end; + tkn_i += 1; + end; + end; + if as_bool then + return true; + end; + local match_start_ran = false; + local span = table.create(token.group_n); + span[0], span.n = { start_i, str_i }, token.group_n; + for _, v in ipairs(states) do + if v[1] == "matchStart" and not match_start_ran then + span[0][1], match_start_ran = v[2], true; + elseif v[1] == "group" and v[5] and not span[v[5]] then + span[v[5]] = { v[3], v[4] }; + end; + end; + return span; +end; + +--[[ Methods ]]-- +re_m.test = check_re('RegEx', 'test', function(self, str, init) + return re_rawfind(self.token, to_str_arr(str, init), 1, self.flags, self.verb_flags, true); +end); + +re_m.match = check_re('RegEx', 'match', function(self, str, init, source) + local span = re_rawfind(self.token, to_str_arr(str, init), 1, self.flags, self.verb_flags, false); + if not span then + return nil; + end; + return new_match(span, self.group_id, source, str); +end); + +re_m.matchall = check_re('RegEx', 'matchall', function(self, str, init, source) + str = to_str_arr(str, init); + local i = 1; + return function() + local span = i <= str.n + 1 and re_rawfind(self.token, str, i, self.flags, self.verb_flags, false); + if not span then + return nil; + end; + i = span[0][2] + (span[0][1] >= span[0][2] and 1 or 0); + return new_match(span, self.group_id, source, str.s); + end; +end); + +local function insert_tokenized_sub(repl_r, str, span, tkn) + for _, v in ipairs(tkn) do + if type(v) == "table" then + if v[1] == "condition" then + if span[v[2]] then + if v[3] then + insert_tokenized_sub(repl_r, str, span, v[3]); + else + table.move(str, span[v[2]][1], span[v[2]][2] - 1, #repl_r + 1, repl_r); + end; + elseif v[4] then + insert_tokenized_sub(repl_r, str, span, v[4]); + end; + else + table.move(v, 1, #v, #repl_r + 1, repl_r); + end; + elseif span[v] then + table.move(str, span[v][1], span[v][2] - 1, #repl_r + 1, repl_r); + end; + end; + repl_r.n = #repl_r; + return repl_r; +end; + +re_m.sub = check_re('RegEx', 'sub', function(self, repl, str, n, repl_flag_str, source) + if repl_flag_str ~= nil and type(repl_flag_str) ~= "number" and type(repl_flag_str) ~= "string" then + error(string.format("invalid argument #5 to 'sub' (string expected, got %s)", typeof(repl_flag_str)), 3); + end + local repl_flags = { + l = false, o = false, u = false, + }; + for f in string.gmatch(repl_flag_str or '', utf8.charpattern) do + if repl_flags[f] ~= false then + error("invalid regular expression substitution flag " .. f, 3); + end; + repl_flags[f] = true; + end; + local repl_type = type(repl); + if repl_type == "number" then + repl ..= ''; + elseif repl_type ~= "string" and repl_type ~= "function" and (not repl_flags.o or repl_type ~= "table") then + error(string.format("invalid argument #2 to 'sub' (string/function%s expected, got %s)", repl_flags.o and "/table" or '', typeof(repl)), 3); + end; + if tonumber(n) then + n = tonumber(n); + if n <= -1 or n ~= n then + n = math.huge; + end; + elseif n ~= nil then + error(string.format("invalid argument #4 to 'sub' (number expected, got %s)", typeof(n)), 3); + else + n = math.huge; + end; + if n < 1 then + return str, 0; + end; + local min_repl_n = 0; + if repl_type == "string" then + repl = to_str_arr(repl); + if not repl_flags.l then + local i1 = 0; + local repl_r = table.create(3); + local group_n = self.token.group_n; + local conditional_c = { }; + while i1 < repl.n do + local i2 = i1; + repeat + i2 += 1; + until not repl[i2] or repl[i2] == 0x24 or repl[i2] == 0x5C or (repl[i2] == 0x3A or repl[i2] == 0x7D) and conditional_c[1]; + min_repl_n += i2 - i1 - 1; + if i2 - i1 > 1 then + table.insert(repl_r, table.move(repl, i1 + 1, i2 - 1, 1, table.create(i2 - i1 - 1))); + end; + if repl[i2] == 0x3A then + local current_conditional_c = conditional_c[1]; + if current_conditional_c[2] then + error("malformed substitution pattern", 3); + end; + current_conditional_c[2] = table.move(repl_r, current_conditional_c[3], #repl_r, 1, table.create(#repl_r + 1 - current_conditional_c[3])); + for i3 = #repl_r, current_conditional_c[3], -1 do + repl_r[i3] = nil; + end; + elseif repl[i2] == 0x7D then + local current_conditional_c = table.remove(conditional_c, 1); + local second_c = table.move(repl_r, current_conditional_c[3], #repl_r, 1, table.create(#repl_r + 1 - current_conditional_c[3])); + for i3 = #repl_r, current_conditional_c[3], -1 do + repl_r[i3] = nil; + end; + table.insert(repl_r, { "condition", current_conditional_c[1], current_conditional_c[2] ~= true and (current_conditional_c[2] or second_c), current_conditional_c[2] and second_c }); + elseif repl[i2] then + i2 += 1; + local subst_c = repl[i2]; + if not subst_c then + if repl[i2 - 1] == 0x5C then + error("replacement string must not end with a trailing backslash", 3); + end; + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, repl[i2 - 1]); + else + table.insert(repl_r, { repl[i2 - 1] }); + end; + elseif subst_c == 0x5C and repl[i2 - 1] == 0x24 then + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, 0x24); + else + table.insert(repl_r, { 0x24 }); + end; + i2 -= 1; + min_repl_n += 1; + elseif subst_c == 0x30 then + table.insert(repl_r, 0); + elseif subst_c > 0x30 and subst_c <= 0x39 then + local start_i2 = i2; + local group_i = subst_c - 0x30; + while repl[i2 + 1] and repl[i2 + 1] >= 0x30 and repl[i2 + 1] <= 0x39 do + group_i ..= repl[i2 + 1] - 0x30; + i2 += 1; + end; + group_i = tonumber(group_i); + if not repl_flags.u and group_i > group_n then + error("reference to non-existent subpattern", 3); + end; + table.insert(repl_r, group_i); + elseif subst_c == 0x7B and repl[i2 - 1] == 0x24 then + i2 += 1; + local start_i2 = i2; + while repl[i2] and + (repl[i2] >= 0x30 and repl[i2] <= 0x39 + or repl[i2] >= 0x41 and repl[i2] <= 0x5A + or repl[i2] >= 0x61 and repl[i2] <= 0x7A + or repl[i2] == 0x5F) do + i2 += 1; + end; + if (repl[i2] == 0x7D or repl[i2] == 0x3A and (repl[i2 + 1] == 0x2B or repl[i2 + 1] == 0x2D)) and i2 ~= start_i2 then + local group_k = utf8_sub(repl.s, start_i2, i2); + if repl[start_i2] >= 0x30 and repl[start_i2] <= 0x39 then + group_k = tonumber(group_k); + if not repl_flags.u and group_k > group_n then + error("reference to non-existent subpattern", 3); + end; + else + group_k = self.group_id[group_k]; + if not repl_flags.u and (not group_k or group_k > group_n) then + error("reference to non-existent subpattern", 3); + end; + end; + if repl[i2] == 0x3A then + i2 += 1; + table.insert(conditional_c, { group_k, repl[i2] == 0x2D, #repl_r + 1 }); + else + table.insert(repl_r, group_k); + end; + else + error("malformed substitution pattern", 3); + end; + else + local c_escape_char; + if repl[i2 - 1] == 0x24 then + if subst_c ~= 0x24 then + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, 0x24); + else + table.insert(repl_r, { 0x24 }); + end; + end; + else + c_escape_char = escape_chars[repl[i2]]; + if type(c_escape_char) ~= "number" then + c_escape_char = nil; + end; + end; + local prev_repl_f = repl_r[#repl_r]; + if type(prev_repl_f) == "table" then + table.insert(prev_repl_f, c_escape_char or repl[i2]); + else + table.insert(repl_r, { c_escape_char or repl[i2] }); + end; + min_repl_n += 1; + end; + end; + i1 = i2; + end; + if conditional_c[1] then + error("malformed substitution pattern", 3); + end; + if not repl_r[2] and type(repl_r[1]) == "table" and repl_r[1][1] ~= "condition" then + repl, repl.n = repl_r[1], #repl_r[1]; + else + repl, repl_type = repl_r, "subst_string"; + end; + end; + end; + str = to_str_arr(str); + local incr, i0, count = 0, 1, 0; + while i0 <= str.n + incr + 1 do + local span = re_rawfind(self.token, str, i0, self.flags, self.verb_flags, false); + if not span then + break; + end; + local repl_r; + if repl_type == "string" then + repl_r = repl; + elseif repl_type == "subst_string" then + repl_r = insert_tokenized_sub(table.create(min_repl_n), str, span, repl); + else + local re_match; + local repl_c; + if repl_type == "table" then + re_match = utf8_sub(str.s, span[0][1], span[0][2]); + repl_c = repl[re_match]; + else + re_match = new_match(span, self.group_id, source, str.s); + repl_c = repl(re_match); + end; + if repl_c == re_match or repl_flags.o and not repl_c then + local repl_n = span[0][2] - span[0][1]; + repl_r = table.move(str, span[0][1], span[0][2] - 1, 1, table.create(repl_n)); + repl_r.n = repl_n; + elseif type(repl_c) == "string" then + repl_r = to_str_arr(repl_c); + elseif type(repl_c) == "number" then + repl_r = to_str_arr(repl_c .. ''); + elseif repl_flags.o then + error(string.format("invalid replacement value (a %s)", type(repl_c)), 3); + else + repl_r = { n = 0 }; + end; + end; + local match_len = span[0][2] - span[0][1]; + local repl_len = math.min(repl_r.n, match_len); + for i1 = 0, repl_len - 1 do + str[span[0][1] + i1] = repl_r[i1 + 1]; + end; + local i1 = span[0][1] + repl_len; + i0 = span[0][2]; + if match_len > repl_r.n then + for i2 = 1, match_len - repl_r.n do + table.remove(str, i1); + incr -= 1; + i0 -= 1; + end; + elseif repl_r.n > match_len then + for i2 = 1, repl_r.n - match_len do + table.insert(str, i1 + i2 - 1, repl_r[repl_len + i2]); + incr += 1; + i0 += 1; + end; + end; + if match_len <= 0 then + i0 += 1; + end; + count += 1; + if n < count + 1 then + break; + end; + end; + return from_str_arr(str), count; +end); + +re_m.split = check_re('RegEx', 'split', function(self, str, n) + if tonumber(n) then + n = tonumber(n); + if n <= -1 or n ~= n then + n = math.huge; + end; + elseif n ~= nil then + error(string.format("invalid argument #3 to 'split' (number expected, got %s)", typeof(n)), 3); + else + n = math.huge; + end; + str = to_str_arr(str); + local i, count = 1, 0; + local ret = { }; + local prev_empty = 0; + while i <= str.n + 1 do + count += 1; + local span = n >= count and re_rawfind(self.token, str, i, self.flags, self.verb_flags, false); + if not span then + break; + end; + table.insert(ret, utf8_sub(str.s, i - prev_empty, span[0][1])); + prev_empty = span[0][1] >= span[0][2] and 1 or 0; + i = span[0][2] + prev_empty; + end; + table.insert(ret, string.sub(str.s, utf8.offset(str.s, i - prev_empty))); + return ret; +end); + +-- +local function re_index(self, index) + return re_m[index] or proxy[self].flags[index]; +end; + +local function re_tostr(self) + return proxy[self].pattern_repr .. proxy[self].flag_repr; +end; +-- + +local other_valid_group_char = { + -- non-capturing group + [0x3A] = true, + -- lookarounds + [0x21] = true, [0x3D] = true, + -- atomic + [0x3E] = true, + -- branch reset + [0x7C] = true, +}; + +local function tokenize_ptn(codes, flags) + if flags.unicode and not options.unicodeData then + return "options.unicodeData cannot be turned off while having unicode flag"; + end; + local i, len = 1, codes.n; + local group_n = 0; + local outln, group_id, verb_flags = { }, { }, { + newline = 1, newline_seq = 1, not_empty = 0, + }; + while i <= len do + local c = codes[i]; + if c == 0x28 then + -- Match + local ret; + if codes[i + 1] == 0x2A then + i += 2; + local start_i = i; + while codes[i] + and (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F or codes[i] == 0x3A) do + i += 1; + end; + if codes[i] ~= 0x29 and codes[i - 1] ~= 0x3A then + -- fallback as normal and ( can't be repeated + return "quantifier doesn't follow a repeatable pattern"; + end; + local selected_verb = utf8_sub(codes.s, start_i, i); + if selected_verb == "positive_lookahead:" or selected_verb == "negative_lookhead:" + or selected_verb == "positive_lookbehind:" or selected_verb == "negative_lookbehind:" + or selected_verb:find("^[pn]l[ab]:$") then + ret = { 0x28, nil, nil, selected_verb:find('^n') and 0x21 or 0x3D, selected_verb:find('b', 3, true) and 1 }; + elseif selected_verb == "atomic:" then + ret = { 0x28, nil, nil, 0x3E, nil }; + elseif selected_verb == "ACCEPT" or selected_verb == "FAIL" or selected_verb == 'F' or selected_verb == "PRUNE" or selected_verb == "SKIP" then + ret = selected_verb == 'F' and "FAIL" or selected_verb; + else + if line_verbs[selected_verb] then + verb_flags.newline = selected_verb; + elseif selected_verb == "BSR_ANYCRLF" or selected_verb == "BSR_UNICODE" then + verb_flags.newline_seq = selected_verb == "BSR_UNICODE" and 1 or 0; + elseif selected_verb == "NOTEMPTY" or selected_verb == "NOTEMPTY_ATSTART" then + verb_flags.not_empty = selected_verb == "NOTEMPTY" and 1 or 2; + else + return "unknown or malformed verb"; + end; + if outln[1] then + return "this verb must be placed at the beginning of the regex"; + end; + end; + elseif codes[i + 1] == 0x3F then + -- ? syntax + i += 2; + if codes[i] == 0x23 then + -- comments + i = table.find(codes, 0x29, i); + if not i then + return "unterminated parenthetical"; + end; + i += 1; + continue; + elseif not codes[i] then + return "unterminated parenthetical"; + end; + ret = { 0x28, nil, nil, codes[i], nil }; + if codes[i] == 0x30 and codes[i + 1] == 0x29 then + -- recursive match entire pattern + ret[1], ret[2], ret[3], ret[5] = "recurmatch", 0, 0, nil; + elseif codes[i] > 0x30 and codes[i] <= 0x39 then + -- recursive match + local org_i = i; + i += 1; + while codes[i] >= 0x30 and codes[i] <= 0x30 do + i += 1; + end; + if codes[i] ~= 0x29 then + return "invalid group structure"; + end; + ret[1], ret[2], ret[4] = "recurmatch", tonumber(utf8_sub(codes.s, org_i, i)), nil; + elseif codes[i] == 0x3C and codes[i + 1] == 0x21 or codes[i + 1] == 0x3D then + -- lookbehinds + i += 1; + ret[4], ret[5] = codes[i], 1; + elseif codes[i] == 0x7C then + -- branch reset + ret[5] = group_n; + elseif codes[i] == 0x50 or codes[i] == 0x3C or codes[i] == 0x27 then + if codes[i] == 0x50 then + i += 1; + end; + if codes[i] == 0x3D then + -- backref + local start_i = i + 1; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if not codes[i] then + return "unterminated parenthetical"; + elseif codes[i] ~= 0x29 or i == start_i then + return "invalid group structure"; + end; + ret = { "backref", utf8_sub(codes.s, start_i, i) }; + elseif codes[i] == 0x3C or codes[i - 1] ~= 0x50 and codes[i] == 0x27 then + -- named capture + local delimiter = codes[i] == 0x27 and 0x27 or 0x3E; + local start_i = i + 1; + i += 1; + if codes[i] == 0x29 then + return "missing character in subpattern"; + elseif codes[i] >= 0x30 and codes[i] <= 0x39 then + return "subpattern name must not begin with a digit"; + elseif not (codes[i] >= 0x41 and codes[i] <= 0x5A or codes[i] >= 0x61 and codes[i] <= 0x7A or codes[i] == 0x5F) then + return "invalid character in subpattern"; + end; + i += 1; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if not codes[i] then + return "unterminated parenthetical"; + elseif codes[i] ~= delimiter then + return "invalid character in subpattern"; + end; + local name = utf8_sub(codes.s, start_i, i); + group_n += 1; + if (group_id[name] or group_n) ~= group_n then + return "subpattern name already exists"; + end; + for name1, group_n1 in pairs(group_id) do + if name ~= name1 and group_n == group_n1 then + return "different names for subpatterns of the same number aren't permitted"; + end; + end; + group_id[name] = group_n; + ret[2], ret[4] = group_n, nil; + else + return "invalid group structure"; + end; + elseif not other_valid_group_char[codes[i]] then + return "invalid group structure"; + end; + else + group_n += 1; + ret = { 0x28, group_n, nil, nil }; + end; + if ret then + table.insert(outln, ret); + end; + elseif c == 0x29 then + -- Close parenthesis + local i1 = #outln + 1; + local lookbehind_c = -1; + local current_lookbehind_c = 0; + local max_c, group_c = 0, 0; + repeat + i1 -= 1; + local v, is_table = outln[i1], type(outln[i1]) == "table"; + if is_table and v[1] == 0x28 then + group_c += 1; + if current_lookbehind_c and v.count then + current_lookbehind_c += v.count; + end; + if not v[3] then + if v[4] == 0x7C then + group_n = v[5] + math.max(max_c, group_c); + end; + if current_lookbehind_c ~= lookbehind_c and lookbehind_c ~= -1 then + lookbehind_c = nil; + else + lookbehind_c = current_lookbehind_c; + end; + break; + end; + elseif v == alternation then + if current_lookbehind_c ~= lookbehind_c and lookbehind_c ~= -1 then + lookbehind_c, current_lookbehind_c = nil, nil; + else + lookbehind_c, current_lookbehind_c = current_lookbehind_c, 0; + end; + max_c, group_c = math.max(max_c, group_c), 0; + elseif current_lookbehind_c then + if is_table and v[1] == "quantifier" then + if v[2] == v[3] then + current_lookbehind_c += v[2]; + else + current_lookbehind_c = nil; + end; + else + current_lookbehind_c += 1; + end; + end; + until i1 < 1; + if i1 < 1 then + return "unmatched ) in regular expression"; + end; + local v = outln[i1]; + local outln_len_p_1 = #outln + 1; + local ret = { 0x29, v[2], i1, v[4], v[5], count = lookbehind_c }; + if (v[4] == 0x21 or v[4] == 0x3D) and v[5] and not lookbehind_c then + return "lookbehind assertion is not fixed width"; + end; + v[3] = outln_len_p_1; + table.insert(outln, ret); + elseif c == 0x2E then + table.insert(outln, dot); + elseif c == 0x5B then + -- Character set + local negate, char_class = false, nil; + i += 1; + local start_i = i; + if codes[i] == 0x5E then + negate = true; + i += 1; + elseif codes[i] == 0x2E or codes[i] == 0x3A or codes[i] == 0x3D then + -- POSIX character classes + char_class = codes[i]; + end; + local ret; + if codes[i] == 0x5B or codes[i] == 0x5C then + ret = { }; + else + ret = { codes[i] }; + i += 1; + end; + while codes[i] ~= 0x5D do + if not codes[i] then + return "unterminated character class"; + elseif codes[i] == 0x2D and ret[1] and type(ret[1]) == "number" then + if codes[i + 1] == 0x5D then + table.insert(ret, 1, 0x2D); + else + i += 1; + local ret_c = codes[i]; + if ret_c == 0x5B then + if codes[i + 1] == 0x2E or codes[i + 1] == 0x3A or codes[i + 1] == 0x3D then + -- Check for POSIX character class, name does not matter + local i1 = i + 2; + repeat + i1 = table.find(codes, 0x5D, i1); + until not i1 or codes[i1 - 1] ~= 0x5C; + if not i1 then + return "unterminated character class"; + elseif codes[i1 - 1] == codes[i + 1] and i1 - 1 ~= i + 1 then + return "invalid range in character class"; + end; + end; + if ret[1] > 0x5B then + return "invalid range in character class"; + end; + elseif ret_c == 0x5C then + i += 1; + if codes[i] == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + ret_c = radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0; + elseif codes[i] >= 0x30 and codes[i] <= 0x37 then + local radix0, radix1, radix2 = codes[i] - 0x30, nil, nil; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + ret_c = radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0; + else + ret_c = escape_chars[codes[i]] or codes[i]; + if type(ret_c) ~= "number" then + return "invalid range in character class"; + end; + end; + elseif ret[1] > ret_c then + return "invalid range in character class"; + end; + ret[1] = { "range", ret[1], ret_c }; + end; + elseif codes[i] == 0x5B then + if codes[i + 1] == 0x2E or codes[i + 1] == 0x3A or codes[i + 1] == 0x3D then + local i1 = i + 2; + repeat + i1 = table.find(codes, 0x5D, i1); + until not i1 or codes[i1 - 1] ~= 0x5C; + if not i1 then + return "unterminated character class"; + elseif codes[i1 - 1] ~= codes[i + 1] or i1 - 1 == i + 1 then + table.insert(ret, 1, 0x5B); + elseif codes[i1 - 1] == 0x2E or codes[i1 - 1] == 0x3D then + return "POSIX collating elements aren't supported"; + elseif codes[i1 - 1] == 0x3A then + -- I have no plans to support escape codes (\) in character class names + local negate = codes[i + 3] == 0x5E; + local class_name = utf8_sub(codes.s, i + (negate and 3 or 2), i1 - 1); + -- If not valid then throw an error + if not posix_class_names[class_name] then + return "unknown POSIX class name"; + end; + table.insert(ret, 1, { "class", class_name, negate }); + i = i1; + end; + else + table.insert(ret, 1, 0x5B); + end; + elseif codes[i] == 0x5C then + i += 1; + if codes[i] == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] == 0x7B then + i += 1; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed hexadecimal character"; + elseif i - org_i > 4 then + return "character offset too large"; + end; + table.insert(ret, 1, tonumber(utf8_sub(codes.s, org_i, i), 16)); + else + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66 then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(ret, 1, radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0); + end; + elseif codes[i] >= 0x30 and codes[i] <= 0x37 then + local radix0, radix1, radix2 = codes[i] - 0x30, nil, nil; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(ret, 1, radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0); + elseif codes[i] == 0x45 then + -- intentionally left blank, \E that's not preceded \Q is ignored + elseif codes[i] == 0x51 then + local start_i = i + 1; + repeat + i = table.find(codes, 0x5C, i + 1); + until not i or codes[i + 1] == 0x45; + table.move(codes, start_i, i and i - 1 or #codes, #outln + 1, outln); + if not i then + break; + end; + i += 1; + elseif codes[i] == 0x4E then + if codes[i + 1] == 0x7B and codes[i + 2] == 0x55 and codes[i + 3] == 0x2B and flags.unicode then + i += 4; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == start_i then + return "malformed Unicode code point"; + end; + local code_point = tonumber(utf8_sub(codes.s, start_i, i)); + table.insert(ret, 1, code_point); + else + return "invalid escape sequence"; + end; + elseif codes[i] == 0x50 or codes[i] == 0x70 then + if not options.unicodeData then + return "options.unicodeData cannot be turned off when using \\p"; + end; + i += 1; + if codes[i] ~= 0x7B then + local c_name = utf8.char(codes[i] or 0); + if not valid_categories[c_name] then + return "unknown or malformed script name"; + end; + table.insert(ret, 1, { "category", false, c_name }); + else + local negate = codes[i] == 0x50; + i += 1; + if codes[i] == 0x5E then + i += 1; + negate = not negate; + end; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if codes[i] ~= 0x7D then + return "unknown or malformed script name"; + end; + local c_name = utf8_sub(codes.s, start_i, i); + local script_set = chr_scripts[c_name]; + if script_set then + table.insert(ret, 1, { "charset", negate, script_set }); + elseif not valid_categories[c_name] then + return "unknown or malformed script name"; + else + table.insert(ret, 1, { "category", negate, c_name }); + end; + end; + elseif codes[i] == 0x6F then + i += 1; + if codes[i] ~= 0x7B then + return "malformed octal code"; + end; + i += 1; + local org_i = i; + while codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed octal code"; + end; + local ret_chr = tonumber(utf8_sub(codes.s, org_i, i), 8); + if ret_chr > 0xFFFF then + return "character offset too large"; + end; + table.insert(ret, 1, ret_chr); + else + local esc_char = escape_chars[codes[i]]; + table.insert(ret, 1, type(esc_char) == "string" and { "class", esc_char, false } or esc_char or codes[i]); + end; + elseif flags.ignoreCase and codes[i] >= 0x61 and codes[i] <= 0x7A then + table.insert(ret, 1, codes[i] - 0x20); + else + table.insert(ret, 1, codes[i]); + end; + i += 1; + end; + if codes[i - 1] == char_class and i - 1 ~= start_i then + return char_class == 0x3A and "POSIX named classes are only support within a character set" or "POSIX collating elements aren't supported"; + end; + if not ret[2] and not negate then + table.insert(outln, ret[1]); + else + table.insert(outln, { "charset", negate, ret }); + end; + elseif c == 0x5C then + -- Escape char + i += 1; + local escape_c = codes[i]; + if not escape_c then + return "pattern may not end with a trailing backslash"; + elseif escape_c >= 0x30 and escape_c <= 0x39 then + local org_i = i; + while codes[i + 1] and codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39 do + i += 1; + end; + local escape_d = tonumber(utf8_sub(codes.s, org_i, i + 1)); + if escape_d > group_n and i ~= org_i then + i = org_i; + local radix0, radix1, radix2; + if codes[i] <= 0x37 then + radix0 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix1 = codes[i] - 0x30; + i += 1; + if codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 then + radix2 = codes[i] - 0x30; + else + i -= 1; + end; + else + i -= 1; + end; + end; + table.insert(outln, radix0 and (radix1 and (radix2 and 64 * radix0 + 8 * radix1 + radix2 or 8 * radix0 + radix1) or radix0) or codes[org_i]); + else + table.insert(outln, { "backref", escape_d }); + end; + elseif escape_c == 0x45 then + -- intentionally left blank, \E that's not preceded \Q is ignored + elseif escape_c == 0x51 then + local start_i = i + 1; + repeat + i = table.find(codes, 0x5C, i + 1); + until not i or codes[i + 1] == 0x45; + table.move(codes, start_i, i and i - 1 or #codes, #outln + 1, outln); + if not i then + break; + end; + i += 1; + elseif escape_c == 0x4E then + if codes[i + 1] == 0x7B and codes[i + 2] == 0x55 and codes[i + 3] == 0x2B and flags.unicode then + i += 4; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == start_i then + return "malformed Unicode code point"; + end; + local code_point = tonumber(utf8_sub(codes.s, start_i, i)); + table.insert(outln, code_point); + else + table.insert(outln, escape_chars[0x4E]); + end; + elseif escape_c == 0x50 or escape_c == 0x70 then + if not options.unicodeData then + return "options.unicodeData cannot be turned off when using \\p"; + end; + i += 1; + if codes[i] ~= 0x7B then + local c_name = utf8.char(codes[i] or 0); + if not valid_categories[c_name] then + return "unknown or malformed script name"; + end; + table.insert(outln, { "category", false, c_name }); + else + local negate = escape_c == 0x50; + i += 1; + if codes[i] == 0x5E then + i += 1; + negate = not negate; + end; + local start_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x5A + or codes[i] >= 0x61 and codes[i] <= 0x7A + or codes[i] == 0x5F) do + i += 1; + end; + if codes[i] ~= 0x7D then + return "unknown or malformed script name"; + end; + local c_name = utf8_sub(codes.s, start_i, i); + local script_set = chr_scripts[c_name]; + if script_set then + table.insert(outln, { "charset", negate, script_set }); + elseif not valid_categories[c_name] then + return "unknown or malformed script name"; + else + table.insert(outln, { "category", negate, c_name }); + end; + end; + elseif escape_c == 0x67 and (codes[i + 1] == 0x7B or codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39) then + local is_grouped = false; + i += 1; + if codes[i] == 0x7B then + i += 1; + is_grouped = true; + elseif codes[i] < 0x30 or codes[i] > 0x39 then + return "malformed reference code"; + end; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if is_grouped and codes[i] ~= 0x7D then + return "malformed reference code"; + end; + local ref_name = tonumber(utf8_sub(codes.s, org_i, i + (is_grouped and 0 or 1))); + table.insert(outln, { "backref", ref_name }); + if not is_grouped then + i -= 1; + end; + elseif escape_c == 0x6F then + i += 1; + if codes[i + 1] ~= 0x7B then + return "malformed octal code"; + end + i += 1; + local org_i = i; + while codes[i] and codes[i] >= 0x30 and codes[i] <= 0x37 do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed octal code"; + end; + local ret_chr = tonumber(utf8_sub(codes.s, org_i, i), 8); + if ret_chr > 0xFFFF then + return "character offset too large"; + end; + table.insert(outln, ret_chr); + elseif escape_c == 0x78 then + local radix0, radix1; + i += 1; + if codes[i] == 0x7B then + i += 1; + local org_i = i; + while codes[i] and + (codes[i] >= 0x30 and codes[i] <= 0x39 + or codes[i] >= 0x41 and codes[i] <= 0x46 + or codes[i] >= 0x61 and codes[i] <= 0x66) do + i += 1; + end; + if codes[i] ~= 0x7D or i == org_i then + return "malformed hexadecimal code"; + elseif i - org_i > 4 then + return "character offset too large"; + end; + table.insert(outln, tonumber(utf8_sub(codes.s, org_i, i), 16)); + else + if codes[i] and (codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66) then + radix0 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + i += 1; + if codes[i] and (codes[i] >= 0x30 and codes[i] <= 0x39 or codes[i] >= 0x41 and codes[i] <= 0x46 or codes[i] >= 0x61 and codes[i] <= 0x66) then + radix1 = codes[i] - ((codes[i] >= 0x41 and codes[i] <= 0x5A) and 0x37 or (codes[i] >= 0x61 and codes[i] <= 0x7A) and 0x57 or 0x30); + else + i -= 1; + end; + else + i -= 1; + end; + table.insert(outln, radix0 and (radix1 and 16 * radix0 + radix1 or radix0) or 0); + end; + else + local esc_char = b_escape_chars[escape_c] or escape_chars[escape_c]; + table.insert(outln, esc_char or escape_c); + end; + elseif c == 0x2A or c == 0x2B or c == 0x3F or c == 0x7B then + -- Quantifier + local start_q, end_q; + if c == 0x7B then + local org_i = i + 1; + local start_i; + while codes[i + 1] and (codes[i + 1] >= 0x30 and codes[i + 1] <= 0x39 or codes[i + 1] == 0x2C and not start_i and i + 1 ~= org_i) do + i += 1; + if codes[i] == 0x2C then + start_i = i; + end; + end; + if codes[i + 1] == 0x7D then + i += 1; + if not start_i then + start_q = tonumber(utf8_sub(codes.s, org_i, i)); + end_q = start_q; + else + start_q, end_q = tonumber(utf8_sub(codes.s, org_i, start_i)), start_i + 1 == i and math.huge or tonumber(utf8_sub(codes.s, start_i + 1, i)); + if end_q < start_q then + return "numbers out of order in {} quantifier"; + end; + end; + else + table.move(codes, org_i - 1, i, #outln + 1, outln); + end; + else + start_q, end_q = c == 0x2B and 1 or 0, c == 0x3F and 1 or math.huge; + end; + if start_q then + local quantifier_type = flags.ungreedy and "lazy" or "greedy"; + if codes[i + 1] == 0x2B or codes[i + 1] == 0x3F then + i += 1; + quantifier_type = codes[i] == 0x2B and "possessive" or flags.ungreedy and "greedy" or "lazy"; + end; + local outln_len = #outln; + local last_outln_value = outln[outln_len]; + if not last_outln_value or type(last_outln_value) == "table" and (last_outln_value[1] == "quantifier" or last_outln_value[1] == 0x28 or b_escape_chars[last_outln_value[1]]) + or last_outln_value == alternation or type(last_outln_value) == "string" then + return "quantifier doesn't follow a repeatable pattern"; + end; + if end_q == 0 then + table.remove(outln); + elseif start_q ~= 1 or end_q ~= 1 then + if type(last_outln_value) == "table" and last_outln_value[1] == 0x29 then + outln_len = last_outln_value[3]; + end; + outln[outln_len] = { "quantifier", start_q, end_q, quantifier_type, outln[outln_len] }; + end; + end; + elseif c == 0x7C then + -- Alternation + table.insert(outln, alternation); + local i1 = #outln; + repeat + i1 -= 1; + local v1, is_table = outln[i1], type(outln[i1]) == "table"; + if is_table and v1[1] == 0x29 then + i1 = outln[i1][3]; + elseif is_table and v1[1] == 0x28 then + if v1[4] == 0x7C then + group_n = v1[5]; + end; + break; + end; + until not v1; + elseif c == 0x24 or c == 0x5E then + table.insert(outln, c == 0x5E and beginning_str or end_str); + elseif flags.ignoreCase and c >= 0x61 and c <= 0x7A then + table.insert(outln, c - 0x20); + elseif flags.extended and (c >= 0x09 and c <= 0x0D or c == 0x20 or c == 0x23) then + if c == 0x23 then + repeat + i += 1; + until not codes[i] or codes[i] == 0x0A or codes[i] == 0x0D; + end; + else + table.insert(outln, c); + end; + i += 1; + end; + local max_group_n = 0; + for i, v in ipairs(outln) do + if type(v) == "table" and (v[1] == 0x28 or v[1] == "quantifier" and type(v[5]) == "table" and v[5][1] == 0x28) then + if v[1] == "quantifier" then + v = v[5]; + end; + if not v[3] then + return "unterminated parenthetical"; + elseif v[2] then + max_group_n = math.max(max_group_n, v[2]); + end; + elseif type(v) == "table" and (v[1] == "backref" or v[1] == "recurmatch") then + if not group_id[v[2]] and (type(v[2]) ~= "number" or v[2] > group_n) then + return "reference to a non-existent or invalid subpattern"; + elseif v[1] == "recurmatch" and v[2] ~= 0 then + for i1, v1 in ipairs(outln) do + if type(v1) == "table" and v1[1] == 0x28 and v1[2] == v[2] then + v[3] = i1; + break; + end; + end; + elseif type(v[2]) == "string" then + v[2] = group_id[v[2]]; + end; + end; + end; + outln.group_n = max_group_n; + return outln, group_id, verb_flags; +end; + +if not tonumber(options.cacheSize) then + error(string.format("expected number for options.cacheSize, got %s", typeof(options.cacheSize)), 2); +end; +local cacheSize = math.floor(options.cacheSize or 0) ~= 0 and tonumber(options.cacheSize); +local cache_pattern, cache_pattern_names; +if not cacheSize then +elseif cacheSize < 0 or cacheSize ~= cacheSize then + error("cache size cannot be a negative number or a NaN", 2); +elseif cacheSize == math.huge then + cache_pattern, cache_pattern_names = { nil }, { nil }; +elseif cacheSize >= 2 ^ 32 then + error("cache size too large", 2); +else + cache_pattern, cache_pattern_names = table.create(options.cacheSize), table.create(options.cacheSize); +end; +if cacheSize then + function re.pruge() + table.clear(cache_pattern_names); + table.clear(cache_pattern); + end; +end; + +local function new_re(str_arr, flags, flag_repr, pattern_repr) + local tokenized_ptn, group_id, verb_flags; + local cache_format = cacheSize and string.format("%s|%s", str_arr.s, flag_repr); + local cached_token = cacheSize and cache_pattern[table.find(cache_pattern_names, cache_format)]; + if cached_token then + tokenized_ptn, group_id, verb_flags = table.unpack(cached_token, 1, 3); + else + tokenized_ptn, group_id, verb_flags = tokenize_ptn(str_arr, flags); + if type(tokenized_ptn) == "string" then + error(tokenized_ptn, 2); + end; + if cacheSize and tokenized_ptn[1] then + table.insert(cache_pattern_names, 1, cache_format); + table.insert(cache_pattern, 1, { tokenized_ptn, group_id, verb_flags }); + if cacheSize ~= math.huge then + table.remove(cache_pattern_names, cacheSize + 1); + table.remove(cache_pattern, cacheSize + 1); + end; + end; + end; + + local object = newproxy(true); + proxy[object] = { name = "RegEx", flags = flags, flag_repr = flag_repr, pattern_repr = pattern_repr, token = tokenized_ptn, group_id = group_id, verb_flags = verb_flags }; + local object_mt = getmetatable(object); + object_mt.__index = setmetatable(flags, re_m); + object_mt.__tostring = re_tostr; + object_mt.__metatable = lockmsg; + + return object; +end; + +local function escape_fslash(pre) + return (#pre % 2 == 0 and '\\' or '') .. pre .. '.'; +end; + +local function sort_flag_chr(a, b) + return a:lower() < b:lower(); +end; + +function re.new(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local ptn, flags_str = ...; + if type(ptn) == "number" then + ptn ..= ''; + elseif type(ptn) ~= "string" then + error(string.format("invalid argument #1 (string expected, got %s)", typeof(ptn)), 2); + end; + if type(flags_str) ~= "string" and type(flags_str) ~= "number" and flags_str ~= nil then + error(string.format("invalid argument #2 (string expected, got %s)", typeof(flags_str)), 2); + end; + + local flags = { + anchored = false, caseless = false, multiline = false, dotall = false, unicode = false, ungreedy = false, extended = false, + }; + local flag_repr = { }; + for f in string.gmatch(flags_str or '', utf8.charpattern) do + if flags[flag_map[f]] ~= false then + error("invalid regular expression flag " .. f, 3); + end; + flags[flag_map[f]] = true; + table.insert(flag_repr, f); + end; + table.sort(flag_repr, sort_flag_chr); + flag_repr = table.concat(flag_repr); + return new_re(to_str_arr(ptn), flags, flag_repr, string.format("/%s/", ptn:gsub("(\\*)/", escape_fslash))); +end; + +function re.fromstring(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local ptn = ...; + if type(ptn) == "number" then + ptn ..= ''; + elseif type(ptn) ~= "string" then + error(string.format("invalid argument #1 (string expected, got %s)", typeof(ptn), 2)); + end; + local str_arr = to_str_arr(ptn); + local delimiter = str_arr[1]; + if not delimiter then + error("empty regex", 2); + elseif delimiter == 0x5C or (delimiter >= 0x30 and delimiter <= 0x39) or (delimiter >= 0x41 and delimiter <= 0x5A) or (delimiter >= 0x61 and delimiter <= 0x7A) then + error("delimiter must not be alphanumeric or a backslash", 2); + end; + + local i0 = 1; + repeat + i0 = table.find(str_arr, delimiter, i0 + 1); + if not i0 then + error(string.format("no ending delimiter ('%s') found", utf8.char(delimiter)), 2); + end; + local escape_count = 1; + while str_arr[i0 - escape_count] == 0x5C do + escape_count += 1; + end; + until escape_count % 2 == 1; + + local flags = { + anchored = false, caseless = false, multiline = false, dotall = false, unicode = false, ungreedy = false, extended = false, + }; + local flag_repr = { }; + while str_arr.n > i0 do + local f = utf8.char(table.remove(str_arr)); + str_arr.n -= 1; + if flags[flag_map[f]] ~= false then + error("invalid regular expression flag " .. f, 3); + end; + flags[flag_map[f]] = true; + table.insert(flag_repr, f); + end; + table.sort(flag_repr, sort_flag_chr); + flag_repr = table.concat(flag_repr); + table.remove(str_arr, 1); + table.remove(str_arr); + str_arr.n -= 2; + str_arr.s = string.sub(str_arr.s, 2, 1 + str_arr.n); + return new_re(str_arr, flags, flag_repr, string.sub(ptn, 1, 2 + str_arr.n)); +end; + +local re_escape_line_chrs = { + ['\0'] = '\\x00', ['\n'] = '\\n', ['\t'] = '\\t', ['\r'] = '\\r', ['\f'] = '\\f', +}; + +function re.escape(...) + if select('#', ...) == 0 then + error("missing argument #1 (string expected)", 2); + end; + local str, extended, delimiter = ...; + if type(str) == "number" then + str ..= ''; + elseif type(str) ~= "string" then + error(string.format("invalid argument #1 to 'escape' (string expected, got %s)", typeof(str)), 2); + end; + if delimiter == nil then + delimiter = ''; + elseif type(delimiter) == "number" then + delimiter ..= ''; + elseif type(delimiter) ~= "string" then + error(string.format("invalid argument #3 to 'escape' (string expected, got %s)", typeof(delimiter)), 2); + end; + if utf8.len(delimiter) > 1 or delimiter:match("^[%a\\]$") then + error("delimiter have not be alphanumeric", 2); + end; + return (string.gsub(str, "[\0\f\n\r\t]", re_escape_line_chrs):gsub(string.format("[\\%s#()%%%%*+.?[%%]^{|%s]", extended and '%s' or '', (delimiter:find'^[%%%]]$' and '%' or '') .. delimiter), "\\%1")); +end; + +function re.type(...) + if select('#', ...) == 0 then + error("missing argument #1", 2); + end; + return proxy[...] and proxy[...].name; +end; + +for k, f in pairs(re_m) do + re[k] = f; +end; + +re_m = { __index = re_m }; + +lockmsg = re.fromstring([[/The\s*metatable\s*is\s*(?:locked|inaccessible)(?#Nice try :])/i]]); +getmetatable(lockmsg).__metatable = lockmsg; + +local function readonly_table() + error("Attempt to modify a readonly table", 2); +end; + +match_m = { + __index = match_m, + __metatable = lockmsg, + __newindex = readonly_table, +}; + +re.Match = setmetatable({ }, match_m); + +return setmetatable({ }, { + __index = re, + __metatable = lockmsg, + __newindex = readonly_table, +}); From 6303ae30c4d9b939a781f2a2c6a33475f825b8c9 Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Tue, 5 Jul 2022 16:59:09 +0100 Subject: [PATCH 05/10] Fix op used when stringifying AstExprIndexName (#572) * Fix op used when stringifying AstExprIndexName * Remove visualizeWithSelf --- Analysis/src/Transpiler.cpp | 20 +++----------------- tests/Transpiler.test.cpp | 2 +- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 1577bd6..9feff1c 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -205,20 +205,6 @@ struct Printer } } - void visualizeWithSelf(AstExpr& expr, bool self) - { - if (!self) - return visualize(expr); - - AstExprIndexName* func = expr.as(); - LUAU_ASSERT(func); - - visualize(*func->expr); - writer.symbol(":"); - advance(func->indexLocation.begin); - writer.identifier(func->index.value); - } - void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) { advance(annotation.location.begin); @@ -366,7 +352,7 @@ struct Printer } else if (const auto& a = expr.as()) { - visualizeWithSelf(*a->func, a->self); + visualize(*a->func); writer.symbol("("); bool first = true; @@ -385,7 +371,7 @@ struct Printer else if (const auto& a = expr.as()) { visualize(*a->expr); - writer.symbol("."); + writer.symbol(std::string(1, a->op)); writer.write(a->index.value); } else if (const auto& a = expr.as()) @@ -766,7 +752,7 @@ struct Printer else if (const auto& a = program.as()) { writer.keyword("function"); - visualizeWithSelf(*a->name, a->func->self != nullptr); + visualize(*a->name); visualizeFunctionBody(*a->func); } else if (const auto& a = program.as()) diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index b02a52b..d2ed9ae 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -583,7 +583,7 @@ TEST_CASE_FIXTURE(Fixture, "transpile_error_expr") auto names = AstNameTable{allocator}; ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); - CHECK_EQ("local a = (error-expr: f.%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); + CHECK_EQ("local a = (error-expr: f:%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); } TEST_CASE_FIXTURE(Fixture, "transpile_error_stat") From 4d98a16ea8f4985ca757d2ceebc278a6db3f9761 Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Tue, 5 Jul 2022 17:08:05 +0100 Subject: [PATCH 06/10] Store resolved types in `astResolvedTypes` (#574) --- Analysis/include/Luau/TypeInfer.h | 1 + Analysis/src/Frontend.cpp | 2 ++ Analysis/src/TypeInfer.cpp | 27 ++++++++++++++++++++------- tests/Frontend.test.cpp | 2 ++ tests/TypeInfer.test.cpp | 23 +++++++++++++++++++++++ 5 files changed, 48 insertions(+), 7 deletions(-) diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 455654d..dac2902 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -345,6 +345,7 @@ private: TypePackId freshTypePack(TypeLevel level); TypeId resolveType(const ScopePtr& scope, const AstType& annotation); + TypeId resolveTypeWorker(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 4cfaa11..9195363 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -496,6 +496,8 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalastTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + module->astResolvedTypes.clear(); + module->astResolvedTypePacks.clear(); module->scopes.resize(1); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index d9486a4..4fafb50 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4884,6 +4884,13 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level) } TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) +{ + TypeId ty = resolveTypeWorker(scope, annotation); + currentModule->astResolvedTypes[&annotation] = ty; + return ty; +} + +TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& annotation) { if (const auto& lit = annotation.as()) { @@ -5200,9 +5207,10 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypeList TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation) { + TypePackId result; if (const AstTypePackVariadic* variadic = annotation.as()) { - return addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); + result = addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); } else if (const AstTypePackGeneric* generic = annotation.as()) { @@ -5216,10 +5224,12 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return errorRecoveryTypePack(scope); + result = errorRecoveryTypePack(scope); + } + else + { + result = *genericTy; } - - return *genericTy; } else if (const AstTypePackExplicit* explicitTp = annotation.as()) { @@ -5229,14 +5239,17 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack types.push_back(resolveType(scope, *type)); if (auto tailType = explicitTp->typeList.tailType) - return addTypePack(types, resolveTypePack(scope, *tailType)); - - return addTypePack(types); + result = addTypePack(types, resolveTypePack(scope, *tailType)); + else + result = addTypePack(types); } else { ice("Unknown AstTypePack kind"); } + + currentModule->astResolvedTypePacks[&annotation] = result; + return result; } bool ApplyTypeFunction::isDirty(TypeId ty) diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index b9c2470..4f33139 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -791,6 +791,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") CHECK_EQ(0, module->internalTypes.typeVars.size()); CHECK_EQ(0, module->internalTypes.typePacks.size()); CHECK_EQ(0, module->astTypes.size()); + CHECK_EQ(0, module->astResolvedTypes.size()); + CHECK_EQ(0, module->astResolvedTypePacks.size()); } TEST_CASE_FIXTURE(FrontendFixture, "it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index efdfe0b..a175b82 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1003,4 +1003,27 @@ TEST_CASE_FIXTURE(Fixture, "do_not_bind_a_free_table_to_a_union_containing_that_ )"); } +TEST_CASE_FIXTURE(Fixture, "types stored in astResolvedTypes") +{ + CheckResult result = check(R"( +type alias = typeof("hello") +local function foo(param: alias) +end + )"); + + auto node = findNodeAtPosition(*getMainSourceModule(), {2, 16}); + auto ty = lookupType("alias"); + REQUIRE(node); + REQUIRE(node->is()); + REQUIRE(ty); + + auto func = node->as(); + REQUIRE(func->args.size == 1); + + auto arg = *func->args.begin(); + auto annotation = arg->annotation; + + CHECK_EQ(*getMainModule()->astResolvedTypes.find(annotation), *ty); +} + TEST_SUITE_END(); From baa35117bd164a6a034c82835ffbe72edbdd2f24 Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Tue, 5 Jul 2022 17:19:54 +0100 Subject: [PATCH 07/10] Use syntheticName for stringified MetatableTypeVar (#563) --- Analysis/src/ToString.cpp | 6 ++++++ tests/ToString.test.cpp | 31 +++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index bbbad9d..5d54d14 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -700,6 +700,12 @@ struct TypeVarStringifier void operator()(TypeId, const MetatableTypeVar& mtv) { state.result.invalid = true; + if (!state.exhaustive && mtv.syntheticName) + { + state.emit(*mtv.syntheticName); + return; + } + state.emit("{ @metatable "); stringify(mtv.metatable); state.emit(","); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 387e07c..1601a15 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -96,6 +96,37 @@ TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") //clang-format on } +TEST_CASE_FIXTURE(Fixture, "metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable})}; + CHECK_EQ("{ @metatable { }, { } }", toString(&mtv)); +} + +TEST_CASE_FIXTURE(Fixture, "named_metatable") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TypeVar metatable{TypeVariant(TableTypeVar())}; + TypeVar mtv{TypeVariant(MetatableTypeVar{&table, &metatable, "NamedMetatable"})}; + CHECK_EQ("NamedMetatable", toString(&mtv)); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "named_metatable_toStringNamedFunction") +{ + CheckResult result = check(R"( + local function createTbl(): NamedMetatable + return setmetatable({}, {}) + end + type NamedMetatable = typeof(createTbl()) + )"); + + TypeId ty = requireType("createTbl"); + const FunctionTypeVar* ftv = get(follow(ty)); + REQUIRE(ftv); + CHECK_EQ("createTbl(): NamedMetatable", toStringNamedFunction("createTbl", *ftv)); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") { CheckResult result = check(R"( From 42510bb5c885dff16159ded78c95ab3f7ea83d4f Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Tue, 5 Jul 2022 11:36:33 -0700 Subject: [PATCH 08/10] Fix test after merging a stale PR (#577) --- tests/TypeInfer.provisional.test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 059aed2..018059f 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -518,7 +518,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_f )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Not all codepaths in this function return '{ @metatable T, {| |} }, a...'.", toString(result.errors[0])); + CHECK_EQ("Not all codepaths in this function return 'self, a...'.", toString(result.errors[0])); } TEST_SUITE_END(); From a7ae439b0f219000cd4d78982c13e44c6a46b2c1 Mon Sep 17 00:00:00 2001 From: Alan Jeffrey <403333+asajeffrey@users.noreply.github.com> Date: Tue, 5 Jul 2022 16:25:09 -0500 Subject: [PATCH 09/10] Document new table type features (#567) --- docs/_pages/typecheck.md | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/docs/_pages/typecheck.md b/docs/_pages/typecheck.md index 363056c..8e032da 100644 --- a/docs/_pages/typecheck.md +++ b/docs/_pages/typecheck.md @@ -104,15 +104,15 @@ From the type checker perspective, each table can be in one of three states. The ### Unsealed tables -An unsealed table is a table whose properties could still be tacked on. This occurs when the table constructor literal had zero expressions. This is one way to accumulate knowledge of the shape of this table. +An unsealed table is a table which supports adding new properties, which updates the tables type. Unsealed tables are created using table literals. This is one way to accumulate knowledge of the shape of this table. ```lua -local t = {} -- {} -t.x = 1 -- {x: number} -t.y = 2 -- {x: number, y: number} +local t = {x = 1} -- {x: number} +t.y = 2 -- {x: number, y: number} +t.z = 3 -- {x: number, y: number, z: number} ``` -However, if this local were written as `local t: {} = {}`, it ends up sealing the table, so the two assignments henceforth will not be ok. +However, if this local were written as `local t: { x: number } = { x = 1 }`, it ends up sealing the table, so the two assignments henceforth will not be ok. Furthermore, once we exit the scope where this unsealed table was created in, we seal it. @@ -128,16 +128,25 @@ local v2 = vec2(1, 2) v2.z = 3 -- not ok ``` -### Sealed tables - -A sealed table is a table that is now locked down. This occurs when the table constructor literal had 1 or more expression, or when the table type is spelt out explicitly via a type annotation. +Unsealed tables are *exact* in that any property of the table must be named by the type. Since Luau treats missing properties as having value `nil`, this means that we can treat an unsealed table which does not mention a property as if it mentioned the property, as long as that property is optional. ```lua -local t = {x = 1} -- {x: number} -t.y = 2 -- not ok +local t = {x = 1} +local u : { x : number, y : number? } = t -- ok because y is optional +local v : { x : number, z : number } = t -- not ok because z is not optional ``` -Sealed tables support *width subtyping*, which allows a table with more properties to be used as a table with fewer +### Sealed tables + +A sealed table is a table that is now locked down. This occurs when the table type is spelled out explicitly via a type annotation, or if it is returned from a function. + +```lua +local t : { x: number } = {x = 1} +t.y = 2 -- not ok +``` + +Sealed tables are *inexact* in that the table may have properties which are not mentioned in the type. +As a result, sealed tables support *width subtyping*, which allows a table with more properties to be used as a table with fewer ```lua type Point1D = { x : number } From dbcd5fb28ec1dd63fdcfb666d9034d67f06e8e13 Mon Sep 17 00:00:00 2001 From: Allan N Jeremy Date: Thu, 7 Jul 2022 18:21:40 +0300 Subject: [PATCH 10/10] build: changed data output to be different for each os (#581) - macos: data-macos-latest.json - ubuntu: data-ubuntu-latest.json - windows: data-windows-latest.json --- .github/workflows/benchmark-dev.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/benchmark-dev.yml b/.github/workflows/benchmark-dev.yml index 03ff9d6..2c6eae4 100644 --- a/.github/workflows/benchmark-dev.yml +++ b/.github/workflows/benchmark-dev.yml @@ -75,7 +75,7 @@ jobs: name: ${{ matrix.bench.title }} (Windows ${{matrix.arch}}) tool: "benchmarkluau" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.GITHUB_TOKEN }} - name: Push benchmark results @@ -85,7 +85,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json + git add ./dev/bench/data-${{ matrix.os }}.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd .. @@ -156,7 +156,7 @@ jobs: name: ${{ matrix.bench.title }} tool: "benchmarkluau" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Store ${{ matrix.bench.title }} result (CacheGrind) @@ -166,7 +166,7 @@ jobs: name: ${{ matrix.bench.title }} (CacheGrind) tool: "roblox" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Push benchmark results @@ -176,7 +176,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json + git add ./dev/bench/data-${{ matrix.os }}.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd .. @@ -244,7 +244,7 @@ jobs: gh-pages-branch: "main" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Store ${{ matrix.bench.title }} result (CacheGrind) @@ -254,7 +254,7 @@ jobs: tool: "roblox" gh-pages-branch: "main" output-file-path: ./${{ matrix.bench.script }}-output.txt - external-data-json-path: ./gh-pages/dev/bench/data.json + external-data-json-path: ./gh-pages/dev/bench/data-${{ matrix.os }}.json github-token: ${{ secrets.BENCH_GITHUB_TOKEN }} - name: Push benchmark results @@ -264,7 +264,7 @@ jobs: cd gh-pages git config user.name github-actions git config user.email github@users.noreply.github.com - git add ./dev/bench/data.json + git add ./dev/bench/data-${{ matrix.os }}.json git commit -m "Add benchmarks results for ${{ github.sha }}" git push cd ..