diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 7716805..c934094 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1659,6 +1659,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar TypeId propTy = resolveType(scope, *prop.ty); bool assignToMetatable = isMetamethod(propName); + Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; // Function types always take 'self', but this isn't reflected in the // parsed annotation. Add it here. @@ -1674,16 +1675,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } } - if (ctv->props.count(propName) == 0) + if (assignTo.count(propName) == 0) { - if (assignToMetatable) - metatable->props[propName] = {propTy}; - else - ctv->props[propName] = {propTy}; + assignTo[propName] = {propTy}; } else { - TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type; + TypeId currentTy = assignTo[propName].type; // We special-case this logic to keep the intersection flat; otherwise we // would create a ton of nested intersection types. @@ -1693,19 +1691,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar options.push_back(propTy); TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); - if (assignToMetatable) - metatable->props[propName] = {newItv}; - else - ctv->props[propName] = {newItv}; + assignTo[propName] = {newItv}; } else if (get(currentTy)) { TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); - if (assignToMetatable) - metatable->props[propName] = {intersection}; - else - ctv->props[propName] = {intersection}; + assignTo[propName] = {intersection}; } else { diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index fd6fb83..9fe0c6a 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -336,4 +336,30 @@ local s : Cls = GetCls() LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "class_definition_overload_metamethods") +{ + loadDefinition(R"( + declare class Vector3 + end + + declare class CFrame + function __mul(self, other: CFrame): CFrame + function __mul(self, other: Vector3): Vector3 + end + + declare function newVector3(): Vector3 + declare function newCFrame(): CFrame + )"); + + CheckResult result = check(R"( + local base = newCFrame() + local shouldBeCFrame = base * newCFrame() + local shouldBeVector = base * newVector3() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("shouldBeCFrame")), "CFrame"); + CHECK_EQ(toString(requireType("shouldBeVector")), "Vector3"); +} + TEST_SUITE_END();