-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes
print "testing closures and coroutines"

local unpack = table.unpack

local A,B = 0,{g=10}

-- Builds 1000 closures; each one captures its own `y` and shares `x`, `A`, `B`.
function f(x)
  local a = {}
  for i=1,1000 do
    local y = 0
    do
      a[i] = function () B.g = B.g+1; y = y+x; return y+A end
    end
  end
  local dummy = function () return a[A] end
  collectgarbage()
  A = 1; assert(dummy() == a[1]); A = 0;
  assert(a[1]() == x)
  assert(a[3]() == x)
  collectgarbage()
  assert(B.g == 12)
  return a
end

a = f(10)
-- force a GC in this level
local x = {[1] = {}}   -- to detect a GC
setmetatable(x, {__mode = 'kv'})
while x[1] do   -- repeat until GC
  local a = A..A..A..A  -- create garbage
  A = A+1
end
assert(a[1]() == 20+A)
assert(a[1]() == 30+A)
assert(a[2]() == 10+A)
collectgarbage()
assert(a[2]() == 20+A)
assert(a[2]() == 30+A)
assert(a[3]() == 20+A)
assert(a[8]() == 10+A)
assert(getmetatable(x).__mode == 'kv')
assert(B.g == 19)


-- testing closures with 'for' control variable
a = {}
for i=1,10 do
  a[i] = {set = function(x) i=x end, get = function () return i end}
  if i == 3 then break end
end
assert(a[4] == nil)
a[1].set(10)
assert(a[2].get() == 2)
a[2].set('a')
assert(a[3].get() == 3)
assert(a[2].get() == 'a')

a = {}
for i, k in pairs{'a', 'b'} do
  a[i] = {set = function(x, y) i=x; k=y end,
          get = function () return i, k end}
  if i == 2 then break end
end
a[1].set(10, 20)
local r,s = a[2].get()
assert(r == 2 and s == 'b')
r,s = a[1].get()
assert(r == 10 and s == 20)
a[2].set('a', 'b')
r,s = a[2].get()
assert(r == "a" and s == "b")


-- testing closures with 'for' control variable x break
for i=1,3 do
  f = function () return i end
  break
end
assert(f() == 1)

for k, v in pairs{"a", "b"} do
  f = function () return k, v end
  break
end
assert(({f()})[1] == 1)
assert(({f()})[2] == "a")


-- testing closure x break x return x errors

local b
function f(x)
  local first = 1
  while 1 do
    if x == 3 and not first then return end
    local a = 'xuxu'
    b = function (op, y)
          if op == 'set' then
            a = x+y
          else
            return a
          end
        end
    if x == 1 then do break end
    elseif x == 2 then return
    else if x ~= 3 then error() end
    end
    first = nil
  end
end

for i=1,3 do
  f(i)
  assert(b('get') == 'xuxu')
  b('set', 10); assert(b('get') == 10+i)
  b = nil
end

pcall(f, 4);
assert(b('get') == 'xuxu')
b('set', 10); assert(b('get') == 14)


local w
-- testing multi-level closure
function f(x)
  return function (y)
    return function (z) return w+x+y+z end
  end
end

y = f(10)
w = 1.345
assert(y(20)(30) == 60+w)

-- testing closures x repeat-until

local a = {}
local i = 1
repeat
  local x = i
  a[i] = function () i = x+1; return x end
until i > 10 or a[i]() ~= x
assert(i == 11 and a[1]() == 1 and a[3]() == 3 and i == 4)

print'+'


-- test for correctly closing upvalues in tail calls of vararg functions
local function t ()
  local function c(a,b) assert(a=="test" and b=="OK") end
  local function v(f, ...) c("test", f() ~= 1 and "FAILED" or "OK") end
  local x = 1
  return v(function() return x end)
end
t()


-- coroutine tests

local f

-- assert(coroutine.running() == nil)


-- tests for global environment

local _G = getfenv()

local function foo (a)
  setfenv(0, a)
  coroutine.yield(getfenv())
  assert(getfenv(0) == a)
  assert(getfenv(1) == _G)
  assert(getfenv(loadstring"") == a)
  return getfenv()
end

f = coroutine.wrap(foo)
local a = {}
assert(f(a) == _G)
local a,b = pcall(f)
assert(a and b == _G)


-- tests for multiple yield/resume arguments

-- Asserts that two array-like tables have the same length and equal elements.
local function eqtab (t1, t2)
  assert(table.getn(t1) == table.getn(t2))
  for i,v in ipairs(t1) do
    assert(t2[i] == v)
  end
end

_G.x = nil   -- declare x
-- Yields the contents of each extra argument in turn, storing what each
-- resume passes back in _G.x, and finally returns the contents of `a`.
function foo (a, ...)
  assert(coroutine.running() == f)
  assert(coroutine.status(f) == "running")
  local arg = {...}
  for i=1,table.getn(arg) do
    _G.x = {coroutine.yield(unpack(arg[i]))}
  end
  return unpack(a)
end

f = coroutine.create(foo)
assert(type(f) == "thread" and coroutine.status(f) == "suspended")
assert(string.find(tostring(f), "thread"))
local s,a,b,c,d
s,a,b,c,d = coroutine.resume(f, {1,2,3}, {}, {1}, {'a', 'b', 'c'})
assert(s and a == nil and coroutine.status(f) == "suspended")
s,a,b,c,d = coroutine.resume(f)
eqtab(_G.x, {})
assert(s and a == 1 and b == nil)
s,a,b,c,d = coroutine.resume(f, 1, 2, 3)
eqtab(_G.x, {1, 2, 3})
assert(s and a == 'a' and b == 'b' and c == 'c' and d == nil)
s,a,b,c,d = coroutine.resume(f, "xuxu")
eqtab(_G.x, {"xuxu"})
assert(s and a == 1 and b == 2 and c == 3 and d == nil)
assert(coroutine.status(f) == "dead")
s, a = coroutine.resume(f, "xuxu")
assert(not s and string.find(a, "dead") and coroutine.status(f) == "dead")


-- yields in tail calls
local function foo (i) return coroutine.yield(i) end
f = coroutine.wrap(function ()
  for i=1,10 do
    assert(foo(i) == _G.x)
  end
  return 'a'
end)
for i=1,10 do _G.x = i; assert(f(i) == i) end
_G.x = 'xuxu'; assert(f('xuxu') == 'a')

-- recursive
function pf (n, i)
  coroutine.yield(n)
  pf(n*i, i+1)
end

f = coroutine.wrap(pf)
local s=1
for i=1,10 do
  assert(f(1, 1) == s)
  s = s*i
end

-- sieve
-- Returns a generator that yields the integers 2..n.
function gen (n)
  return coroutine.wrap(function ()
    for i=2,n do coroutine.yield(i) end
  end)
end


-- Returns a generator that passes on the values of `g` not divisible by `p`.
function filter (p, g)
  return coroutine.wrap(function ()
    while 1 do
      local n = g()
      if n == nil then return end
      if n%p ~= 0 then coroutine.yield(n) end
    end
  end)
end

local x = gen(100)
local a = {}
while 1 do
  local n = x()
  if n == nil then break end
  table.insert(a, n)
  x = filter(n, x)
end

assert(table.getn(a) == 25 and a[table.getn(a)] == 97)


-- errors in coroutines
function foo ()
  -- assert(debug.getinfo(1).currentline == debug.getinfo(foo).linedefined + 1)
  -- assert(debug.getinfo(2).currentline == debug.getinfo(goo).linedefined)
  coroutine.yield(3)
  error("foo")
end

-- error() prefixes the message with "<chunk>:<line>:". Both parts depend on
-- how this file is named and laid out, so the expected message is matched
-- with an anchored pattern rather than a hard-coded file name and line.
local fooerr = ":%d+: foo$"

function goo() foo() end
x = coroutine.wrap(goo)
assert(x() == 3)
local a,b = pcall(x)
assert(not a and string.find(b, fooerr))

x = coroutine.create(goo)
a,b = coroutine.resume(x)
assert(a and b == 3)
a,b = coroutine.resume(x)
assert(not a and string.find(b, fooerr) and coroutine.status(x) == "dead")
a,b = coroutine.resume(x)
assert(not a and string.find(b, "dead") and coroutine.status(x) == "dead")


-- co-routines x for loop
-- Yields `a` once for every assignment of values 1..n to slots 1..k.
function all (a, n, k)
  if k == 0 then coroutine.yield(a)
  else
    for i=1,n do
      a[k] = i
      all(a, n, k-1)
    end
  end
end

local a = 0
for t in coroutine.wrap(function () all({}, 5, 4) end) do
  a = a+1
end
assert(a == 5^4)


-- access to locals of collected coroutines
local C = {}; setmetatable(C, {__mode = "kv"})
local x = coroutine.wrap (function ()
            local a = 10
            local function f () a = a+10; return a end
            while true do
              a = a+1
              coroutine.yield(f)
            end
          end)

C[1] = x;

local f = x()
assert(f() == 21 and x()() == 32 and x() == f)
x = nil
collectgarbage()
-- assert(C[1] == nil)
assert(f() == 43 and f() == 53)


-- old bug: attempt to resume itself

function co_func (current_co)
  assert(coroutine.running() == current_co)
  assert(coroutine.resume(current_co) == false)
  assert(coroutine.resume(current_co) == false)
  return 10
end

local co = coroutine.create(co_func)
local a,b = coroutine.resume(co, co)
assert(a == true and b == 10)
assert(coroutine.resume(co, co) == false)
assert(coroutine.resume(co, co) == false)

-- access to locals of erroneous coroutines
local x = coroutine.create (function ()
            local a = 10
            _G.f = function () a=a+1; return a end
            error('x')
          end)

assert(not coroutine.resume(x))
-- overwrite previous position of local `a'
assert(not coroutine.resume(x, 1, 1, 1, 1, 1, 1, 1))
assert(_G.f() == 11)
assert(_G.f() == 12)


if not T then
  (Message or print)('\a\n >>> testC not active: skipping yield/hook tests <<<\n\a')
else

  local turn

  function fact (t, x)
    assert(turn == t)
    if x == 0 then return 1
    else return x*fact(t, x-1)
    end
  end

  local A,B,a,b = 0,0,0,0

  local x = coroutine.create(function ()
    T.setyhook("", 2)
    A = fact("A", 10)
  end)

  local y = coroutine.create(function ()
    T.setyhook("", 3)
    B = fact("B", 11)
  end)

  while A==0 or B==0 do
    if A==0 then turn = "A"; T.resume(x) end
    if B==0 then turn = "B"; T.resume(y) end
  end

  assert(B/A == 11)
end


-- leaving a pending coroutine open
_X = coroutine.wrap(function ()
       local a = 10
       local x = function () a = a+1 end
       coroutine.yield()
     end)

_X()


-- coroutine environments
co = coroutine.create(function ()
       coroutine.yield(getfenv(0))
       return loadstring("return a")()
     end)


-- large closure size
do
  local a1, a2, a3, a4, a5, a6, a7, a8, a9, a0
  local b1, b2, b3, b4, b5, b6, b7, b8, b9, b0
  local c1, c2, c3, c4, c5, c6, c7, c8, c9, c0
  local d1, d2, d3, d4, d5, d6, d7, d8, d9, d0

  local f = function()
    return
      a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a0 +
      b1 + b2 + b3 + b4 + b5 + b6 + b7 + b8 + b9 + b0 +
      c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c0 +
      d1 + d2 + d3 + d4 + d5 + d6 + d7 + d8 + d9 + d0
  end
end

return 'OK'