Fix table traversal used in recursion detection.

This fixes serializing same table multiple times within a parent table.
This commit is contained in:
Alex Orlenko 2021-09-28 16:10:09 +01:00
parent d586eef0f5
commit bdd3c923ba
No known key found for this signature in database
GPG Key ID: 4C150C250863B96D
3 changed files with 63 additions and 14 deletions

View File

@ -1721,6 +1721,7 @@ impl Lua {
}
#[cfg(feature = "serialize")]
#[inline]
pub(crate) unsafe fn get_ref_ptr(&self, lref: &LuaRef) -> *const c_void {
ffi::lua_topointer((*self.extra.get()).ref_thread, lref.index)
}

View File

@ -7,7 +7,7 @@ use std::string::String as StdString;
use serde::de::{self, IntoDeserializer};
use crate::error::{Error, Result};
use crate::table::{TablePairs, TableSequence};
use crate::table::{Table, TablePairs, TableSequence};
use crate::value::Value;
/// A struct for deserializing Lua values into Rust values.
@ -158,11 +158,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
where
V: de::Visitor<'de>,
{
let (variant, value) = match self.value {
let (variant, value, _guard) = match self.value {
Value::Table(table) => {
let lua = table.0.lua;
let ptr = unsafe { lua.get_ref_ptr(&table.0) };
self.visited.borrow_mut().insert(ptr);
let _guard = RecursionGuard::new(&table, &self.visited);
let mut iter = table.pairs::<StdString, Value>();
let (variant, value) = match iter.next() {
@ -185,9 +183,9 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
return Err(de::Error::custom("bad enum value"));
}
(variant, Some(value))
(variant, Some(value), Some(_guard))
}
Value::String(variant) => (variant.to_str()?.to_owned(), None),
Value::String(variant) => (variant.to_str()?.to_owned(), None, None),
_ => return Err(de::Error::custom("bad enum value")),
};
@ -206,9 +204,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
{
match self.value {
Value::Table(t) => {
let lua = t.0.lua;
let ptr = unsafe { lua.get_ref_ptr(&t.0) };
self.visited.borrow_mut().insert(ptr);
let _guard = RecursionGuard::new(&t, &self.visited);
let len = t.raw_len() as usize;
let mut deserializer = SeqDeserializer {
@ -261,9 +257,7 @@ impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
{
match self.value {
Value::Table(t) => {
let lua = t.0.lua;
let ptr = unsafe { lua.get_ref_ptr(&t.0) };
self.visited.borrow_mut().insert(ptr);
let _guard = RecursionGuard::new(&t, &self.visited);
let mut deserializer = MapDeserializer {
pairs: t.pairs(),
@ -495,10 +489,34 @@ impl<'lua, 'de> de::VariantAccess<'de> for VariantDeserializer<'lua> {
}
}
// Adds `ptr` to the `visited` map and removes on drop
// Used to track recursive tables but allow to traverse same tables multiple times
struct RecursionGuard {
ptr: *const c_void,
visited: Rc<RefCell<HashSet<*const c_void>>>,
}
impl RecursionGuard {
#[inline]
fn new(table: &Table, visited: &Rc<RefCell<HashSet<*const c_void>>>) -> Self {
let visited = Rc::clone(visited);
let ptr = unsafe { table.0.lua.get_ref_ptr(&table.0) };
visited.borrow_mut().insert(ptr);
RecursionGuard { ptr, visited }
}
}
impl Drop for RecursionGuard {
fn drop(&mut self) {
self.visited.borrow_mut().remove(&self.ptr);
}
}
// Checks `options` and decides should we emit an error or skip next element
fn check_value_if_skip(
value: &Value,
options: Options,
visited: &Rc<RefCell<HashSet<*const c_void>>>,
visited: &RefCell<HashSet<*const c_void>>,
) -> Result<bool> {
match value {
Value::Table(table) => {

View File

@ -305,6 +305,36 @@ fn test_to_value_with_options() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[test]
fn test_from_value_nested_tables() -> Result<(), Box<dyn std::error::Error>> {
let lua = Lua::new();
let value = lua
.load(
r#"
local table_a = {a = "a"}
local table_b = {"b"}
return {
a = table_a,
b = {table_b, table_b},
ab = {a = table_a, b = table_b}
}
"#,
)
.eval::<Value>()?;
let got = lua.from_value::<serde_json::Value>(value)?;
assert_eq!(
got,
serde_json::json!({
"a": {"a": "a"},
"b": [["b"], ["b"]],
"ab": {"a": {"a": "a"}, "b": ["b"]},
})
);
Ok(())
}
#[test]
fn test_from_value_struct() -> Result<(), Box<dyn std::error::Error>> {
let lua = Lua::new();