diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 0ca54e6..dfd08e5 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -50,6 +50,7 @@ LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) +LUAU_FASTFLAGVARIABLE(LuauCallableClasses, false) namespace Luau { @@ -4257,26 +4258,33 @@ std::optional> TypeChecker::checkCallOverload(const Sc std::vector metaArgLocations; - // Might be a callable table + // Might be a callable table or class + std::optional callTy = std::nullopt; if (const MetatableTypeVar* mttv = get(fn)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false)) - { - // Construct arguments with 'self' added in front - TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + callTy = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false); + } + else if (const ClassTypeVar* ctv = get(fn); FFlag::LuauCallableClasses && ctv && ctv->metatable) + { + callTy = getIndexTypeFromType(scope, *ctv->metatable, "__call", expr.func->location, /* addErrors= */ false); + } - TypePack* metaCallArgs = getMutable(metaCallArgPack); - metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + if (callTy) + { + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); - metaArgLocations = *argLocations; - metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); - fn = instantiate(scope, *ty, expr.func->location); + metaArgLocations = *argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); - argPack = metaCallArgPack; - args = metaCallArgs; - argLocations = &metaArgLocations; - } + fn = instantiate(scope, *callTy, expr.func->location); + + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; } const FunctionTypeVar* ftv = get(fn); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index d00f1d8..d1a24e5 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -91,6 +91,13 @@ struct ClassFixture : BuiltinsFixture typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); + TypeId callableClassMetaType = arena.addType(TableTypeVar{}); + TypeId callableClassType = arena.addType(ClassTypeVar{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); + getMutable(callableClassMetaType)->props = { + {"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; + for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) persist(tf.type); @@ -514,4 +521,17 @@ TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(ClassFixture, "callable_classes") +{ + ScopedFastFlag luauCallableClasses{"LuauCallableClasses", true}; + + CheckResult result = check(R"( + local x : CallableClass + local y = x("testing") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("y"))); +} + TEST_SUITE_END();