Fix table traversal used in recursion detection.
This fixes serializing same table multiple times within a parent table.
This commit is contained in:
parent
d586eef0f5
commit
bdd3c923ba
|
@ -1721,6 +1721,7 @@ impl Lua {
|
|||
}
|
||||
|
||||
#[cfg(feature = "serialize")]
|
||||
#[inline]
|
||||
pub(crate) unsafe fn get_ref_ptr(&self, lref: &LuaRef) -> *const c_void {
|
||||
ffi::lua_topointer((*self.extra.get()).ref_thread, lref.index)
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use std::string::String as StdString;
|
|||
use serde::de::{self, IntoDeserializer};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::table::{TablePairs, TableSequence};
|
||||
use crate::table::{Table, TablePairs, TableSequence};
|
||||
use crate::value::Value;
|
||||
|
||||
/// A struct for deserializing Lua values into Rust values.
|
||||
|
@ -158,11 +158,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
|
|||
where
|
||||
V: de::Visitor<'de>,
|
||||
{
|
||||
let (variant, value) = match self.value {
|
||||
let (variant, value, _guard) = match self.value {
|
||||
Value::Table(table) => {
|
||||
let lua = table.0.lua;
|
||||
let ptr = unsafe { lua.get_ref_ptr(&table.0) };
|
||||
self.visited.borrow_mut().insert(ptr);
|
||||
let _guard = RecursionGuard::new(&table, &self.visited);
|
||||
|
||||
let mut iter = table.pairs::<StdString, Value>();
|
||||
let (variant, value) = match iter.next() {
|
||||
|
@ -185,9 +183,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
|
|||
return Err(de::Error::custom("bad enum value"));
|
||||
}
|
||||
|
||||
(variant, Some(value))
|
||||
(variant, Some(value), Some(_guard))
|
||||
}
|
||||
Value::String(variant) => (variant.to_str()?.to_owned(), None),
|
||||
Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
|
||||
_ => return Err(de::Error::custom("bad enum value")),
|
||||
};
|
||||
|
||||
|
@ -206,9 +204,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
|
|||
{
|
||||
match self.value {
|
||||
Value::Table(t) => {
|
||||
let lua = t.0.lua;
|
||||
let ptr = unsafe { lua.get_ref_ptr(&t.0) };
|
||||
self.visited.borrow_mut().insert(ptr);
|
||||
let _guard = RecursionGuard::new(&t, &self.visited);
|
||||
|
||||
let len = t.raw_len() as usize;
|
||||
let mut deserializer = SeqDeserializer {
|
||||
|
@ -261,9 +257,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
|
|||
{
|
||||
match self.value {
|
||||
Value::Table(t) => {
|
||||
let lua = t.0.lua;
|
||||
let ptr = unsafe { lua.get_ref_ptr(&t.0) };
|
||||
self.visited.borrow_mut().insert(ptr);
|
||||
let _guard = RecursionGuard::new(&t, &self.visited);
|
||||
|
||||
let mut deserializer = MapDeserializer {
|
||||
pairs: t.pairs(),
|
||||
|
@ -495,10 +489,34 @@ impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> {
|
|||
}
|
||||
}
|
||||
|
||||
// Adds `ptr` to the `visited` map and removes on drop
|
||||
// Used to track recursive tables but allow to traverse same tables multiple times
|
||||
struct RecursionGuard {
|
||||
ptr: *const c_void,
|
||||
visited: Rc<RefCell<HashSet<*const c_void>>>,
|
||||
}
|
||||
|
||||
impl RecursionGuard {
|
||||
#[inline]
|
||||
fn new(table: &Table, visited: &Rc<RefCell<HashSet<*const c_void>>>) -> Self {
|
||||
let visited = Rc::clone(visited);
|
||||
let ptr = unsafe { table.0.lua.get_ref_ptr(&table.0) };
|
||||
visited.borrow_mut().insert(ptr);
|
||||
RecursionGuard { ptr, visited }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for RecursionGuard {
|
||||
fn drop(&mut self) {
|
||||
self.visited.borrow_mut().remove(&self.ptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks `options` and decides should we emit an error or skip next element
|
||||
fn check_value_if_skip(
|
||||
value: &Value,
|
||||
options: Options,
|
||||
visited: &Rc<RefCell<HashSet<*const c_void>>>,
|
||||
visited: &RefCell<HashSet<*const c_void>>,
|
||||
) -> Result<bool> {
|
||||
match value {
|
||||
Value::Table(table) => {
|
||||
|
|
|
@ -305,6 +305,36 @@ fn test_to_value_with_options() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_value_nested_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let lua = Lua::new();
|
||||
|
||||
let value = lua
|
||||
.load(
|
||||
r#"
|
||||
local table_a = {a = "a"}
|
||||
local table_b = {"b"}
|
||||
return {
|
||||
a = table_a,
|
||||
b = {table_b, table_b},
|
||||
ab = {a = table_a, b = table_b}
|
||||
}
|
||||
"#,
|
||||
)
|
||||
.eval::<Value>()?;
|
||||
let got = lua.from_value::<serde_json::Value>(value)?;
|
||||
assert_eq!(
|
||||
got,
|
||||
serde_json::json!({
|
||||
"a": {"a": "a"},
|
||||
"b": [["b"], ["b"]],
|
||||
"ab": {"a": {"a": "a"}, "b": ["b"]},
|
||||
})
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_value_struct() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let lua = Lua::new();
|
||||
|
|
Loading…
Reference in New Issue