-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes
--
-- Self-checking script: every section asserts its own expectations and the
-- chunk returns 'OK' when all of them hold. Many variables are deliberately
-- global or deliberately shadowed; later assertions depend on that, so do not
-- "localize" them.
print "testing closures and coroutines"

local unpack = table.unpack

local A,B = 0,{g=10}

-- Builds 1000 closures that share the upvalue `x` (and the file-level A and B)
-- while each one owns a private `y`. Returns the array of closures.
function f(x)
  local a = {}
  for i=1,1000 do
    local y = 0
    do
      a[i] = function () B.g = B.g+1; y = y+x; return y+A end
    end
  end
  local dummy = function () return a[A] end
  collectgarbage()
  A = 1; assert(dummy() == a[1]); A = 0;
  assert(a[1]() == x)
  assert(a[3]() == x)
  collectgarbage()
  assert(B.g == 12)
  return a
end

a = f(10)
-- force a GC in this level
local x = {[1] = {}}   -- to detect a GC
setmetatable(x, {__mode = 'kv'})
while x[1] do   -- repeat until GC
  local a = A..A..A..A  -- create garbage
  A = A+1
end
assert(a[1]() == 20+A)
assert(a[1]() == 30+A)
assert(a[2]() == 10+A)
collectgarbage()
assert(a[2]() == 20+A)
assert(a[2]() == 30+A)
assert(a[3]() == 20+A)
assert(a[8]() == 10+A)
assert(getmetatable(x).__mode == 'kv')
assert(B.g == 19)


-- testing closures with 'for' control variable
a = {}
for i=1,10 do
  a[i] = {set = function(x) i=x end, get = function () return i end}
  if i == 3 then break end
end
assert(a[4] == nil)
a[1].set(10)
assert(a[2].get() == 2)
a[2].set('a')
assert(a[3].get() == 3)
assert(a[2].get() == 'a')

a = {}
for i, k in pairs{'a', 'b'} do
  a[i] = {set = function(x, y) i=x; k=y end,
          get = function () return i, k end}
  if i == 2 then break end
end
a[1].set(10, 20)
local r,s = a[2].get()
assert(r == 2 and s == 'b')
r,s = a[1].get()
assert(r == 10 and s == 20)
a[2].set('a', 'b')
r,s = a[2].get()
assert(r == "a" and s == "b")


-- testing closures with 'for' control variable x break
for i=1,3 do
  f = function () return i end
  break
end
assert(f() == 1)

for k, v in pairs{"a", "b"} do
  f = function () return k, v end
  break
end
assert(({f()})[1] == 1)
assert(({f()})[2] == "a")


-- testing closure x break x return x errors

local b

-- Stores in `b` a closure over the loop-local `a`, then leaves the loop by
-- break (x == 1), return (x == 2), a second iteration followed by return
-- (x == 3), or an error (any other x). `b` must stay usable in every case.
function f(x)
  local first = 1
  while 1 do
    if x == 3 and not first then return end
    local a = 'xuxu'
    b = function (op, y)
          if op == 'set' then
            a = x+y
          else
            return a
          end
        end
    if x == 1 then do break end
    elseif x == 2 then return
    else if x ~= 3 then error() end
    end
    first = nil
  end
end

for i=1,3 do
  f(i)
  assert(b('get') == 'xuxu')
  b('set', 10); assert(b('get') == 10+i)
  b = nil
end

pcall(f, 4);
assert(b('get') == 'xuxu')
b('set', 10); assert(b('get') == 14)


local w
-- testing multi-level closure
function f(x)
  return function (y)
    return function (z) return w+x+y+z end
  end
end

y = f(10)
w = 1.345
assert(y(20)(30) == 60+w)

-- testing closures x repeat-until

local a = {}
local i = 1
repeat
  local x = i
  a[i] = function () i = x+1; return x end
until i > 10 or a[i]() ~= x
assert(i == 11 and a[1]() == 1 and a[3]() == 3 and i == 4)

print'+'

-- test for correctly closing upvalues in tail calls of vararg functions
local function t ()
  local function c(a,b) assert(a=="test" and b=="OK") end
  local function v(f, ...) c("test", f() ~= 1 and "FAILED" or "OK") end
  local x = 1
  return v(function() return x end)
end
t()


-- coroutine tests

local f

-- assert(coroutine.running() == nil)


-- tests for global environment

local _G = getfenv()

local function foo (a)
  setfenv(0, a)
  coroutine.yield(getfenv())
  assert(getfenv(0) == a)
  assert(getfenv(1) == _G)
  assert(getfenv(loadstring"") == a)
  return getfenv()
end

f = coroutine.wrap(foo)
local a = {}
assert(f(a) == _G)
local a,b = pcall(f)
assert(a and b == _G)


-- tests for multiple yield/resume arguments

-- Asserts that two array-like tables have the same length and equal elements.
local function eqtab (t1, t2)
  assert(table.getn(t1) == table.getn(t2))
  for i,v in ipairs(t1) do
    assert(t2[i] == v)
  end
end

_G.x = nil   -- declare x
function foo (a, ...)
  assert(coroutine.running() == f)
  assert(coroutine.status(f) == "running")
  local arg = {...}
  for i=1,table.getn(arg) do
    _G.x = {coroutine.yield(unpack(arg[i]))}
  end
  return unpack(a)
end

f = coroutine.create(foo)
assert(type(f) == "thread" and coroutine.status(f) == "suspended")
assert(string.find(tostring(f), "thread"))
local s,a,b,c,d
s,a,b,c,d = coroutine.resume(f, {1,2,3}, {}, {1}, {'a', 'b', 'c'})
assert(s and a == nil and coroutine.status(f) == "suspended")
s,a,b,c,d = coroutine.resume(f)
eqtab(_G.x, {})
assert(s and a == 1 and b == nil)
s,a,b,c,d = coroutine.resume(f, 1, 2, 3)
eqtab(_G.x, {1, 2, 3})
assert(s and a == 'a' and b == 'b' and c == 'c' and d == nil)
s,a,b,c,d = coroutine.resume(f, "xuxu")
eqtab(_G.x, {"xuxu"})
assert(s and a == 1 and b == 2 and c == 3 and d == nil)
assert(coroutine.status(f) == "dead")
s, a = coroutine.resume(f, "xuxu")
assert(not s and string.find(a, "dead") and coroutine.status(f) == "dead")


-- yields in tail calls
local function foo (i) return coroutine.yield(i) end
f = coroutine.wrap(function ()
  for i=1,10 do
    assert(foo(i) == _G.x)
  end
  return 'a'
end)
for i=1,10 do _G.x = i; assert(f(i) == i) end
_G.x = 'xuxu'; assert(f('xuxu') == 'a')

-- recursive
function pf (n, i)
  coroutine.yield(n)
  pf(n*i, i+1)
end

f = coroutine.wrap(pf)
local s=1
for i=1,10 do
  assert(f(1, 1) == s)
  s = s*i
end

-- sieve
function gen (n)
  return coroutine.wrap(function ()
    for i=2,n do coroutine.yield(i) end
  end)
end


function filter (p, g)
  return coroutine.wrap(function ()
    while 1 do
      local n = g()
      if n == nil then return end
      if n%p ~= 0 then coroutine.yield(n) end
    end
  end)
end

local x = gen(100)
local a = {}
while 1 do
  local n = x()
  if n == nil then break end
  table.insert(a, n)
  x = filter(n, x)
end

assert(table.getn(a) == 25 and a[table.getn(a)] == 97)


-- errors in coroutines
function foo ()
  -- assert(debug.getinfo(1).currentline == debug.getinfo(foo).linedefined + 1)
  -- assert(debug.getinfo(2).currentline == debug.getinfo(goo).linedefined)
  coroutine.yield(3)
  error("foo")
end

-- Pattern for the message raised by foo(). The line number is matched with
-- %d+ rather than spelled out, so the check does not depend on where
-- error("foo") happens to sit in this file. The chunk is expected to be loaded
-- under the name "closure.lua", as the previously hard-coded message assumed.
local fooerr = "^closure%.lua:%d+: foo$"

function goo() foo() end
x = coroutine.wrap(goo)
assert(x() == 3)
local a,b = pcall(x)
assert(not a and string.find(b, fooerr))

x = coroutine.create(goo)
a,b = coroutine.resume(x)
assert(a and b == 3)
a,b = coroutine.resume(x)
assert(not a and string.find(b, fooerr) and coroutine.status(x) == "dead")
a,b = coroutine.resume(x)
assert(not a and string.find(b, "dead") and coroutine.status(x) == "dead")


-- co-routines x for loop
function all (a, n, k)
  if k == 0 then coroutine.yield(a)
  else
    for i=1,n do
      a[k] = i
      all(a, n, k-1)
    end
  end
end

local a = 0
for t in coroutine.wrap(function () all({}, 5, 4) end) do
  a = a+1
end
assert(a == 5^4)


-- access to locals of collected coroutines
local C = {}; setmetatable(C, {__mode = "kv"})
local x = coroutine.wrap (function ()
            local a = 10
            local function f () a = a+10; return a end
            while true do
              a = a+1
              coroutine.yield(f)
            end
          end)

C[1] = x;

local f = x()
assert(f() == 21 and x()() == 32 and x() == f)
x = nil
collectgarbage()
-- assert(C[1] == nil)
assert(f() == 43 and f() == 53)


-- old bug: attempt to resume itself

function co_func (current_co)
  assert(coroutine.running() == current_co)
  assert(coroutine.resume(current_co) == false)
  assert(coroutine.resume(current_co) == false)
  return 10
end

local co = coroutine.create(co_func)
local a,b = coroutine.resume(co, co)
assert(a == true and b == 10)
assert(coroutine.resume(co, co) == false)
assert(coroutine.resume(co, co) == false)

-- access to locals of erroneous coroutines
local x = coroutine.create (function ()
            local a = 10
            _G.f = function () a=a+1; return a end
            error('x')
          end)

assert(not coroutine.resume(x))
-- overwrite previous position of local `a'
assert(not coroutine.resume(x, 1, 1, 1, 1, 1, 1, 1))
assert(_G.f() == 11)
assert(_G.f() == 12)


-- The yield/hook tests below only run when the global `T` is defined;
-- otherwise a notice is emitted and they are skipped.
if not T then
  (Message or print)('\a\n >>> testC not active: skipping yield/hook tests <<<\n\a')
else

  local turn

  function fact (t, x)
    assert(turn == t)
    if x == 0 then return 1
    else return x*fact(t, x-1)
    end
  end

  local A,B,a,b = 0,0,0,0

  local x = coroutine.create(function ()
    T.setyhook("", 2)
    A = fact("A", 10)
  end)

  local y = coroutine.create(function ()
    T.setyhook("", 3)
    B = fact("B", 11)
  end)

  while A==0 or B==0 do
    if A==0 then turn = "A"; T.resume(x) end
    if B==0 then turn = "B"; T.resume(y) end
  end

  assert(B/A == 11)
end


-- leaving a pending coroutine open
_X = coroutine.wrap(function ()
      local a = 10
      local x = function () a = a+1 end
      coroutine.yield()
    end)

_X()


-- coroutine environments
co = coroutine.create(function ()
       coroutine.yield(getfenv(0))
       return loadstring("return a")()
     end)

return 'OK'