From 9095fc4b83ea40c72980a539b6c9c06136d67895 Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Mon, 28 Nov 2022 18:02:41 +0000 Subject: [PATCH] Support `__call` on class type vars (#762) Currently, the metatable of a class type var is not correctly checked to see if it is callable (i.e. has a `__call` metatable). We resolve this issue by checking the metatable in `checkCallOverload` Fixes #756 Co-authored-by: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> --- Analysis/src/TypeInfer.cpp | 36 +++++++++++++++++++------------- tests/TypeInfer.classes.test.cpp | 20 ++++++++++++++++++ 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 0ca54e6..dfd08e5 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -50,6 +50,7 @@ LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) +LUAU_FASTFLAGVARIABLE(LuauCallableClasses, false) namespace Luau { @@ -4257,26 +4258,33 @@ std::optional> TypeChecker::checkCallOverload(const Sc std::vector metaArgLocations; - // Might be a callable table + // Might be a callable table or class + std::optional callTy = std::nullopt; if (const MetatableTypeVar* mttv = get(fn)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false)) - { - // Construct arguments with 'self' added in front - TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + callTy = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false); + } + else if (const ClassTypeVar* ctv = get(fn); FFlag::LuauCallableClasses && ctv && ctv->metatable) + { + callTy = getIndexTypeFromType(scope, *ctv->metatable, "__call", expr.func->location, /* addErrors= */ false); + } - TypePack* metaCallArgs = getMutable(metaCallArgPack); - metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + if (callTy) + { + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); - metaArgLocations = *argLocations; - metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); - fn = instantiate(scope, *ty, expr.func->location); + metaArgLocations = *argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); - argPack = metaCallArgPack; - args = metaCallArgs; - argLocations = &metaArgLocations; - } + fn = instantiate(scope, *callTy, expr.func->location); + + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; } const FunctionTypeVar* ftv = get(fn); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index d00f1d8..d1a24e5 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -91,6 +91,13 @@ struct ClassFixture : BuiltinsFixture typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); + TypeId callableClassMetaType = arena.addType(TableTypeVar{}); + TypeId callableClassType = arena.addType(ClassTypeVar{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); + getMutable(callableClassMetaType)->props = { + {"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; + for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) persist(tf.type); @@ -514,4 +521,17 @@ TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(ClassFixture, "callable_classes") +{ + ScopedFastFlag luauCallableClasses{"LuauCallableClasses", true}; + + CheckResult result = check(R"( + local x : CallableClass + local y = x("testing") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("y"))); +} + TEST_SUITE_END();