Two major API changes:

Allow load to return values, allows reimplementing require() like functions
properly.

Make globals table explicit in Lua, remove Lua::get / Lua::set in favor of
Lua::globals.  Allows obeying globals metatable, using other Table functions on
the globals table.

Also added "has" method as shorthand for checking whether a table entry is not
nil.
This commit is contained in:
kyren 2017-06-11 01:12:25 -04:00
parent 8203414b76
commit 5c8aa19b8d
4 changed files with 188 additions and 148 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "rlua"
version = "0.4.6"
version = "0.5.0-pre"
authors = ["kyren <catherine@chucklefish.org>"]
description = "High level bindings to Lua 5.3"
repository = "https://github.com/chucklefish/rlua"

View File

@ -10,23 +10,26 @@ fn examples() -> LuaResult<()> {
// functionality.
let lua = Lua::new();
let globals = lua.globals()?;
// You can get and set global variables
lua.set("string_var", "hello")?;
lua.set("int_var", 42)?;
globals.set("string_var", "hello")?;
globals.set("int_var", 42)?;
assert_eq!(lua.get::<_, String>("string_var")?, "hello");
assert_eq!(lua.get::<_, i64>("int_var")?, 42);
assert_eq!(globals.get::<_, String>("string_var")?, "hello");
assert_eq!(globals.get::<_, i64>("int_var")?, 42);
// You can load and evaluate lua code. The second parameter here gives the chunk a better name
// when lua error messages are printed.
lua.load(r#"
lua.load::<()>(
r#"
global = 'foo'..'bar'
"#,
Some("example code"))?;
assert_eq!(lua.get::<_, String>("global")?, "foobar");
Some("example code"),
)?;
assert_eq!(globals.get::<_, String>("global")?, "foobar");
assert_eq!(lua.eval::<i32>("1 + 1")?, 2);
assert_eq!(lua.eval::<bool>("false == false")?, true);
@ -49,10 +52,11 @@ fn examples() -> LuaResult<()> {
// You can pass values like LuaTable back into Lua
lua.set("array_table", array_table)?;
lua.set("map_table", map_table)?;
globals.set("array_table", array_table)?;
globals.set("map_table", map_table)?;
lua.eval::<()>(r#"
lua.eval::<()>(
r#"
for k, v in pairs(array_table) do
print(k, v)
end
@ -60,11 +64,12 @@ fn examples() -> LuaResult<()> {
for k, v in pairs(map_table) do
print(k, v)
end
"#)?;
"#,
)?;
// You can load lua functions
let print: LuaFunction = lua.get("print")?;
let print: LuaFunction = globals.get("print")?;
print.call::<_, ()>("hello from rust")?;
// There is a specific method for handling variadics that involves Heterogeneous Lists. This
@ -90,7 +95,7 @@ fn examples() -> LuaResult<()> {
// signature limitations.
lua.pack(list1 == list2)
})?;
lua.set("check_equal", check_equal)?;
globals.set("check_equal", check_equal)?;
// You can also accept variadic arguments to rust functions
let join = lua.create_function(|lua, args| {
@ -98,7 +103,7 @@ fn examples() -> LuaResult<()> {
// (This is quadratic!, it's just an example!)
lua.pack(strings.iter().fold("".to_owned(), |a, b| a + b))
})?;
lua.set("join", join)?;
globals.set("join", join)?;
assert_eq!(lua.eval::<bool>(r#"check_equal({"a", "b", "c"}, {"a", "b", "c"})"#)?,
true);
@ -127,10 +132,10 @@ fn examples() -> LuaResult<()> {
}
let vec2_constructor = lua.create_function(|lua, args| {
let hlist_pat![x, y] = lua.unpack::<HList![f32, f32]>(args)?;
lua.pack(Vec2(x, y))
})?;
lua.set("vec2", vec2_constructor)?;
let hlist_pat![x, y] = lua.unpack::<HList![f32, f32]>(args)?;
lua.pack(Vec2(x, y))
})?;
globals.set("vec2", vec2_constructor)?;
assert_eq!(lua.eval::<f32>("(vec2(1, 2) + vec2(2, 2)):magnitude()")?,
5.0);

View File

@ -183,6 +183,23 @@ impl<'lua> LuaTable<'lua> {
}
}
/// Shorthand for checking whether get(key) is nil
pub fn has<K: ToLua<'lua>>(&self, key: K) -> LuaResult<bool> {
let lua = self.0.lua;
let key = key.to_lua(lua)?;
unsafe {
error_guard(lua.state, 0, 0, |state| {
check_stack(state, 2)?;
lua.push_ref(state, &self.0);
lua.push_value(state, key)?;
ffi::lua_gettable(state, -2);
let has = ffi::lua_isnil(state, -1) == 0;
ffi::lua_pop(state, 2);
Ok(has)
})
}
}
/// Set a field in the table, without invoking metamethods
pub fn raw_set<K: ToLua<'lua>, V: ToLua<'lua>>(&self, key: K, value: V) -> LuaResult<()> {
let lua = self.0.lua;
@ -681,8 +698,7 @@ impl Lua {
ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX);
Ok(())
})
.unwrap();
}).unwrap();
stack_guard(state, 0, || {
ffi::lua_pushlightuserdata(state,
@ -701,8 +717,7 @@ impl Lua {
ffi::lua_rawset(state, ffi::LUA_REGISTRYINDEX);
Ok(())
})
.unwrap();
}).unwrap();
stack_guard(state, 0, || {
ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS);
@ -717,8 +732,7 @@ impl Lua {
ffi::lua_pop(state, 1);
Ok(())
})
.unwrap();
}).unwrap();
Lua {
state,
@ -728,9 +742,13 @@ impl Lua {
}
}
pub fn load(&self, source: &str, name: Option<&str>) -> LuaResult<()> {
pub fn load<'lua, R: FromLuaMulti<'lua>>(&'lua self,
source: &str,
name: Option<&str>)
-> LuaResult<R> {
unsafe {
stack_guard(self.state, 0, || {
let stack_start = ffi::lua_gettop(self.state);
handle_error(self.state,
if let Some(name) = name {
let name = CString::new(name.to_owned())?;
@ -746,7 +764,15 @@ impl Lua {
})?;
check_stack(self.state, 2)?;
handle_error(self.state, pcall_with_traceback(self.state, 0, 0))
handle_error(self.state,
pcall_with_traceback(self.state, 0, ffi::LUA_MULTRET))?;
let nresults = ffi::lua_gettop(self.state) - stack_start;
let mut results = LuaMultiValue::new();
for _ in 0..nresults {
results.push_front(self.pop_value(self.state)?);
}
R::from_lua_multi(results, self)
})
}
}
@ -878,37 +904,11 @@ impl Lua {
}
}
pub fn set<'lua, K, V>(&'lua self, key: K, value: V) -> LuaResult<()>
where K: ToLua<'lua>,
V: ToLua<'lua>
{
pub fn globals<'lua>(&'lua self) -> LuaResult<LuaTable<'lua>> {
unsafe {
stack_guard(self.state, 0, || {
check_stack(self.state, 3)?;
ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS);
self.push_value(self.state, key.to_lua(self)?)?;
self.push_value(self.state, value.to_lua(self)?)?;
ffi::lua_rawset(self.state, -3);
ffi::lua_pop(self.state, 1);
Ok(())
})
}
}
pub fn get<'lua, K, V>(&'lua self, key: K) -> LuaResult<V>
where K: ToLua<'lua>,
V: FromLua<'lua>
{
unsafe {
stack_guard(self.state, 0, || {
check_stack(self.state, 2)?;
ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS);
self.push_value(self.state, key.to_lua(self)?)?;
ffi::lua_gettable(self.state, -2);
let ret = self.pop_value(self.state)?;
ffi::lua_pop(self.state, 1);
V::from_lua(ret, self)
})
check_stack(self.state, 1)?;
ffi::lua_rawgeti(self.state, ffi::LUA_REGISTRYINDEX, ffi::LUA_RIDX_GLOBALS);
Ok(LuaTable(self.pop_ref(self.state)))
}
}

View File

@ -9,23 +9,44 @@ use super::*;
#[test]
fn test_set_get() {
let lua = Lua::new();
lua.set("foo", "bar").unwrap();
lua.set("baz", "baf").unwrap();
assert_eq!(lua.get::<_, String>("foo").unwrap(), "bar");
assert_eq!(lua.get::<_, String>("baz").unwrap(), "baf");
let globals = lua.globals().unwrap();
globals.set("foo", "bar").unwrap();
globals.set("baz", "baf").unwrap();
assert_eq!(globals.get::<_, String>("foo").unwrap(), "bar");
assert_eq!(globals.get::<_, String>("baz").unwrap(), "baf");
}
#[test]
fn test_load() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
res = 'foo'..'bar'
"#,
None,
)
.unwrap();
assert_eq!(lua.get::<_, String>("res").unwrap(), "foobar");
).unwrap();
assert_eq!(globals.get::<_, String>("res").unwrap(), "foobar");
let module: LuaTable = lua.load(
r#"
local module = {}
function module.func()
return "hello"
end
return module
"#,
None,
).unwrap();
assert!(module.has("func").unwrap());
assert_eq!(module
.get::<_, LuaFunction>("func")
.unwrap()
.call::<_, String>(())
.unwrap(),
"hello");
}
#[test]
@ -43,10 +64,13 @@ fn test_eval() {
#[test]
fn test_table() {
let lua = Lua::new();
let globals = lua.globals().unwrap();
lua.set("table", lua.create_empty_table().unwrap()).unwrap();
let table1: LuaTable = lua.get("table").unwrap();
let table2: LuaTable = lua.get("table").unwrap();
globals
.set("table", lua.create_empty_table().unwrap())
.unwrap();
let table1: LuaTable = globals.get("table").unwrap();
let table2: LuaTable = globals.get("table").unwrap();
table1.set("foo", "bar").unwrap();
table2.set("baz", "baf").unwrap();
@ -54,19 +78,18 @@ fn test_table() {
assert_eq!(table2.get::<_, String>("foo").unwrap(), "bar");
assert_eq!(table1.get::<_, String>("baz").unwrap(), "baf");
lua.load(
lua.load::<()>(
r#"
table1 = {1, 2, 3, 4, 5}
table2 = {}
table3 = {1, 2, nil, 4, 5}
"#,
None,
)
.unwrap();
).unwrap();
let table1 = lua.get::<_, LuaTable>("table1").unwrap();
let table2 = lua.get::<_, LuaTable>("table2").unwrap();
let table3 = lua.get::<_, LuaTable>("table3").unwrap();
let table1 = globals.get::<_, LuaTable>("table1").unwrap();
let table2 = globals.get::<_, LuaTable>("table2").unwrap();
let table3 = globals.get::<_, LuaTable>("table3").unwrap();
assert_eq!(table1.length().unwrap(), 5);
assert_eq!(table1.pairs::<i64, i64>().unwrap(),
@ -78,10 +101,11 @@ fn test_table() {
assert_eq!(table3.array_values::<Option<i64>>().unwrap(),
vec![Some(1), Some(2), None, Some(4), Some(5)]);
lua.set("table4",
lua.create_array_table(vec![1, 2, 3, 4, 5]).unwrap())
globals
.set("table4",
lua.create_array_table(vec![1, 2, 3, 4, 5]).unwrap())
.unwrap();
let table4 = lua.get::<_, LuaTable>("table4").unwrap();
let table4 = globals.get::<_, LuaTable>("table4").unwrap();
assert_eq!(table4.pairs::<i64, i64>().unwrap(),
vec![(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]);
}
@ -89,17 +113,17 @@ fn test_table() {
#[test]
fn test_function() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
function concat(arg1, arg2)
return arg1 .. arg2
end
"#,
None,
)
.unwrap();
).unwrap();
let concat = lua.get::<_, LuaFunction>("concat").unwrap();
let concat = globals.get::<_, LuaFunction>("concat").unwrap();
assert_eq!(concat.call::<_, String>(hlist!["foo", "bar"]).unwrap(),
"foobar");
}
@ -107,7 +131,8 @@ fn test_function() {
#[test]
fn test_bind() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
function concat(...)
local res = ""
@ -118,10 +143,9 @@ fn test_bind() {
end
"#,
None,
)
.unwrap();
).unwrap();
let mut concat = lua.get::<_, LuaFunction>("concat").unwrap();
let mut concat = globals.get::<_, LuaFunction>("concat").unwrap();
concat = concat.bind("foo").unwrap();
concat = concat.bind("bar").unwrap();
concat = concat.bind(hlist!["baz", "baf"]).unwrap();
@ -132,7 +156,8 @@ fn test_bind() {
#[test]
fn test_rust_function() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
function lua_function()
return rust_function()
@ -142,13 +167,12 @@ fn test_rust_function() {
return 1
"#,
None,
)
.unwrap();
).unwrap();
let lua_function = lua.get::<_, LuaFunction>("lua_function").unwrap();
let lua_function = globals.get::<_, LuaFunction>("lua_function").unwrap();
let rust_function = lua.create_function(|lua, _| lua.pack("hello")).unwrap();
lua.set("rust_function", rust_function).unwrap();
globals.set("rust_function", rust_function).unwrap();
assert_eq!(lua_function.call::<_, String>(()).unwrap(), "hello");
}
@ -189,9 +213,10 @@ fn test_methods() {
}
let lua = Lua::new();
let globals = lua.globals().unwrap();
let userdata = lua.create_userdata(UserData(42)).unwrap();
lua.set("userdata", userdata.clone()).unwrap();
lua.load(
globals.set("userdata", userdata.clone()).unwrap();
lua.load::<()>(
r#"
function get_it()
return userdata:get_value()
@ -202,10 +227,9 @@ fn test_methods() {
end
"#,
None,
)
.unwrap();
let get = lua.get::<_, LuaFunction>("get_it").unwrap();
let set = lua.get::<_, LuaFunction>("set_it").unwrap();
).unwrap();
let get = globals.get::<_, LuaFunction>("get_it").unwrap();
let set = globals.get::<_, LuaFunction>("set_it").unwrap();
assert_eq!(get.call::<_, i64>(()).unwrap(), 42);
userdata.borrow_mut::<UserData>().unwrap().0 = 64;
assert_eq!(get.call::<_, i64>(()).unwrap(), 64);
@ -241,8 +265,9 @@ fn test_metamethods() {
}
let lua = Lua::new();
lua.set("userdata1", UserData(7)).unwrap();
lua.set("userdata2", UserData(3)).unwrap();
let globals = lua.globals().unwrap();
globals.set("userdata1", UserData(7)).unwrap();
globals.set("userdata2", UserData(3)).unwrap();
assert_eq!(lua.eval::<UserData>("userdata1 + userdata2").unwrap().0, 10);
assert_eq!(lua.eval::<UserData>("userdata1 - userdata2").unwrap().0, 4);
assert_eq!(lua.eval::<i64>("userdata1:get()").unwrap(), 7);
@ -253,20 +278,20 @@ fn test_metamethods() {
#[test]
fn test_scope() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
touter = {
tin = {1, 2, 3}
}
"#,
None,
)
.unwrap();
).unwrap();
// Make sure that table gets do not borrow the table, but instead just borrow lua.
let tin;
{
let touter = lua.get::<_, LuaTable>("touter").unwrap();
let touter = globals.get::<_, LuaTable>("touter").unwrap();
tin = touter.get::<_, LuaTable>("tin").unwrap();
}
@ -279,7 +304,7 @@ fn test_scope() {
// impl LuaUserDataType for UserData {};
// let userdata_ref;
// {
// let touter = lua.get::<_, LuaTable>("touter").unwrap();
// let touter = globals.get::<_, LuaTable>("touter").unwrap();
// touter.set("userdata", lua.create_userdata(UserData).unwrap()).unwrap();
// let userdata = touter.get::<_, LuaUserData>("userdata").unwrap();
// userdata_ref = userdata.borrow::<UserData>();
@ -289,7 +314,8 @@ fn test_scope() {
#[test]
fn test_lua_multi() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
function concat(arg1, arg2)
return arg1 .. arg2
@ -300,11 +326,10 @@ fn test_lua_multi() {
end
"#,
None,
)
.unwrap();
).unwrap();
let concat = lua.get::<_, LuaFunction>("concat").unwrap();
let mreturn = lua.get::<_, LuaFunction>("mreturn").unwrap();
let concat = globals.get::<_, LuaFunction>("concat").unwrap();
let mreturn = globals.get::<_, LuaFunction>("mreturn").unwrap();
assert_eq!(concat.call::<_, String>(hlist!["foo", "bar"]).unwrap(),
"foobar");
@ -318,19 +343,19 @@ fn test_lua_multi() {
#[test]
fn test_coercion() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
int = 123
str = "123"
num = 123.0
"#,
None,
)
.unwrap();
).unwrap();
assert_eq!(lua.get::<_, String>("int").unwrap(), "123");
assert_eq!(lua.get::<_, i32>("str").unwrap(), 123);
assert_eq!(lua.get::<_, i32>("num").unwrap(), 123);
assert_eq!(globals.get::<_, String>("int").unwrap(), "123");
assert_eq!(globals.get::<_, i32>("str").unwrap(), 123);
assert_eq!(globals.get::<_, i32>("num").unwrap(), 123);
}
#[test]
@ -355,7 +380,8 @@ fn test_error() {
}
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
function no_error()
end
@ -395,19 +421,22 @@ fn test_error() {
end
"#,
None,
)
.unwrap();
).unwrap();
let rust_error_function =
lua.create_function(|_, _| Err(LuaExternalError(Box::new(TestError)).into()))
.unwrap();
lua.set("rust_error_function", rust_error_function).unwrap();
globals
.set("rust_error_function", rust_error_function)
.unwrap();
let no_error = lua.get::<_, LuaFunction>("no_error").unwrap();
let lua_error = lua.get::<_, LuaFunction>("lua_error").unwrap();
let rust_error = lua.get::<_, LuaFunction>("rust_error").unwrap();
let test_pcall = lua.get::<_, LuaFunction>("test_pcall").unwrap();
let understand_recursion = lua.get::<_, LuaFunction>("understand_recursion").unwrap();
let no_error = globals.get::<_, LuaFunction>("no_error").unwrap();
let lua_error = globals.get::<_, LuaFunction>("lua_error").unwrap();
let rust_error = globals.get::<_, LuaFunction>("rust_error").unwrap();
let test_pcall = globals.get::<_, LuaFunction>("test_pcall").unwrap();
let understand_recursion = globals
.get::<_, LuaFunction>("understand_recursion")
.unwrap();
assert!(no_error.call::<_, ()>(()).is_ok());
match lua_error.call::<_, ()>(()) {
@ -427,7 +456,7 @@ fn test_error() {
match catch_unwind(|| -> LuaResult<()> {
let lua = Lua::new();
lua.load(
lua.load::<()>(
r#"
function rust_panic()
pcall(function () rust_panic_function() end)
@ -438,9 +467,9 @@ fn test_error() {
let rust_panic_function = lua.create_function(|_, _| {
panic!("expected panic, this panic should be caught in rust")
})?;
lua.set("rust_panic_function", rust_panic_function)?;
globals.set("rust_panic_function", rust_panic_function)?;
let rust_panic = lua.get::<_, LuaFunction>("rust_panic")?;
let rust_panic = globals.get::<_, LuaFunction>("rust_panic")?;
rust_panic.call::<_, ()>(())
}) {
@ -451,7 +480,7 @@ fn test_error() {
match catch_unwind(|| -> LuaResult<()> {
let lua = Lua::new();
lua.load(
lua.load::<()>(
r#"
function rust_panic()
xpcall(function() rust_panic_function() end, function() end)
@ -462,9 +491,9 @@ fn test_error() {
let rust_panic_function = lua.create_function(|_, _| {
panic!("expected panic, this panic should be caught in rust")
})?;
lua.set("rust_panic_function", rust_panic_function)?;
globals.set("rust_panic_function", rust_panic_function)?;
let rust_panic = lua.get::<_, LuaFunction>("rust_panic")?;
let rust_panic = globals.get::<_, LuaFunction>("rust_panic")?;
rust_panic.call::<_, ()>(())
}) {
@ -477,15 +506,17 @@ fn test_error() {
#[test]
fn test_thread() {
let lua = Lua::new();
let thread = lua.create_thread(lua.eval::<LuaFunction>(r#"function (s)
let thread = lua.create_thread(
lua.eval::<LuaFunction>(
r#"function (s)
local sum = s
for i = 1,4 do
sum = sum + coroutine.yield(sum)
end
return sum
end"#)
.unwrap())
.unwrap();
end"#,
).unwrap(),
).unwrap();
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(0).unwrap(), Some(0));
@ -499,13 +530,15 @@ fn test_thread() {
assert_eq!(thread.resume::<_, i64>(4).unwrap(), Some(10));
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Dead);
let accumulate = lua.create_thread(lua.eval::<LuaFunction>(r#"function (sum)
let accumulate = lua.create_thread(
lua.eval::<LuaFunction>(
r#"function (sum)
while true do
sum = sum + coroutine.yield(sum)
end
end"#)
.unwrap())
.unwrap();
end"#,
).unwrap(),
).unwrap();
for i in 0..4 {
accumulate.resume::<_, ()>(i).unwrap();
@ -515,12 +548,13 @@ fn test_thread() {
assert!(accumulate.resume::<_, ()>("error").is_err());
assert_eq!(accumulate.status().unwrap(), LuaThreadStatus::Error);
let thread = lua.eval::<LuaThread>(r#"coroutine.create(function ()
let thread = lua.eval::<LuaThread>(
r#"coroutine.create(function ()
while true do
coroutine.yield(42)
end
end)"#)
.unwrap();
end)"#,
).unwrap();
assert_eq!(thread.status().unwrap(), LuaThreadStatus::Active);
assert_eq!(thread.resume::<_, i64>(()).unwrap(), Some(42));
}
@ -528,14 +562,15 @@ fn test_thread() {
#[test]
fn test_lightuserdata() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"function id(a)
return a
end"#,
None,
)
.unwrap();
let res = lua.get::<_, LuaFunction>("id")
).unwrap();
let res = globals
.get::<_, LuaFunction>("id")
.unwrap()
.call::<_, LightUserData>(LightUserData(42 as *mut c_void))
.unwrap();
@ -545,7 +580,8 @@ fn test_lightuserdata() {
#[test]
fn test_table_error() {
let lua = Lua::new();
lua.load(
let globals = lua.globals().unwrap();
lua.load::<()>(
r#"
table = {}
setmetatable(table, {
@ -561,10 +597,9 @@ fn test_table_error() {
})
"#,
None,
)
.unwrap();
).unwrap();
let bad_table: LuaTable = lua.get("table").unwrap();
let bad_table: LuaTable = globals.get("table").unwrap();
assert!(bad_table.set(1, 1).is_err());
assert!(bad_table.get::<_, i32>(1).is_err());
assert!(bad_table.length().is_err());