From c6a3985dfd0995a816795ec4880de56c35a0c6a5 Mon Sep 17 00:00:00 2001 From: Kogia-sima Date: Sat, 19 Dec 2020 16:50:28 +0900 Subject: [PATCH] perf: optimize simd escaping for small strings --- sailfish/src/runtime/escape/avx2.rs | 41 +++--------------------- sailfish/src/runtime/escape/fallback.rs | 2 +- sailfish/src/runtime/escape/sse2.rs | 42 +++---------------------- 3 files changed, 10 insertions(+), 75 deletions(-) diff --git a/sailfish/src/runtime/escape/avx2.rs b/sailfish/src/runtime/escape/avx2.rs index b7ee822..9a40bb0 100644 --- a/sailfish/src/runtime/escape/avx2.rs +++ b/sailfish/src/runtime/escape/avx2.rs @@ -10,20 +10,19 @@ use super::super::Buffer; use super::{ESCAPED, ESCAPED_LEN, ESCAPE_LUT}; const VECTOR_BYTES: usize = std::mem::size_of::<__m256i>(); -const VECTOR_ALIGN: usize = VECTOR_BYTES - 1; #[target_feature(enable = "avx2")] pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { debug_assert!(feed.len() >= 16); let len = feed.len(); - if len < VECTOR_BYTES { escape_small(feed, buffer); return; } let mut start_ptr = feed.as_ptr(); + let mut ptr = start_ptr; let end_ptr = start_ptr.add(len); let v_independent1 = _mm256_set1_epi8(5); @@ -38,39 +37,8 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { )) as u32 }; - let mut ptr = start_ptr; - let aligned_ptr = ptr.add(VECTOR_BYTES - (start_ptr as usize & VECTOR_ALIGN)); - - { + while ptr <= end_ptr.sub(VECTOR_BYTES) { let mut mask = maskgen(_mm256_loadu_si256(ptr as *const __m256i)); - loop { - let trailing_zeros = mask.trailing_zeros() as usize; - let ptr2 = ptr.add(trailing_zeros); - if ptr2 >= aligned_ptr { - break; - } - - let c = ESCAPE_LUT[*ptr2 as usize] as usize; - if c < ESCAPED_LEN { - if start_ptr < ptr2 { - let slc = slice::from_raw_parts( - start_ptr, - ptr2 as usize - start_ptr as usize, - ); - buffer.push_str(std::str::from_utf8_unchecked(slc)); - } - buffer.push_str(*ESCAPED.get_unchecked(c)); - start_ptr = ptr2.add(1); - } - mask ^= 1 << trailing_zeros; - } - } - - ptr = aligned_ptr; - let mut next_ptr = ptr.add(VECTOR_BYTES); - - while next_ptr <= end_ptr { - let mut mask = maskgen(_mm256_load_si256(ptr as *const __m256i)); while mask != 0 { let trailing_zeros = mask.trailing_zeros() as usize; let ptr2 = ptr.add(trailing_zeros); @@ -89,11 +57,10 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { mask ^= 1 << trailing_zeros; } - ptr = next_ptr; - next_ptr = next_ptr.add(VECTOR_BYTES); + ptr = ptr.add(VECTOR_BYTES); } - debug_assert!(next_ptr > end_ptr); + debug_assert!(ptr.add(VECTOR_BYTES) > end_ptr); if ptr < end_ptr { debug_assert!((end_ptr as usize - ptr as usize) < VECTOR_BYTES); diff --git a/sailfish/src/runtime/escape/fallback.rs b/sailfish/src/runtime/escape/fallback.rs index 1b13720..8580ef8 100644 --- a/sailfish/src/runtime/escape/fallback.rs +++ b/sailfish/src/runtime/escape/fallback.rs @@ -58,7 +58,7 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { ptr = aligned_ptr; - while ptr.add(USIZE_BYTES) <= end_ptr { + while ptr <= end_ptr.sub(USIZE_BYTES) { debug_assert_eq!((ptr as usize) % USIZE_BYTES, 0); let chunk = *(ptr as *const usize); diff --git a/sailfish/src/runtime/escape/sse2.rs b/sailfish/src/runtime/escape/sse2.rs index 7287692..4f605e3 100644 --- a/sailfish/src/runtime/escape/sse2.rs +++ b/sailfish/src/runtime/escape/sse2.rs @@ -10,12 +10,12 @@ use super::super::Buffer; use super::{ESCAPED, ESCAPED_LEN, ESCAPE_LUT}; const VECTOR_BYTES: usize = std::mem::size_of::<__m128i>(); -const VECTOR_ALIGN: usize = VECTOR_BYTES - 1; #[target_feature(enable = "sse2")] pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { let len = feed.len(); let mut start_ptr = feed.as_ptr(); + let mut ptr = start_ptr; let end_ptr = start_ptr.add(len); let v_independent1 = _mm_set1_epi8(5); @@ -30,40 +30,9 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { )) as u32 }; - let mut ptr = start_ptr; - let aligned_ptr = ptr.add(VECTOR_BYTES - (start_ptr as usize & VECTOR_ALIGN)); - - { - let mut mask = maskgen(_mm_loadu_si128(ptr as *const __m128i)); - loop { - let trailing_zeros = mask.trailing_zeros() as usize; - let ptr2 = ptr.add(trailing_zeros); - if ptr2 >= aligned_ptr { - break; - } - - let c = ESCAPE_LUT[*ptr2 as usize] as usize; - if c < ESCAPED_LEN { - if start_ptr < ptr2 { - let slc = slice::from_raw_parts( - start_ptr, - ptr2 as usize - start_ptr as usize, - ); - buffer.push_str(std::str::from_utf8_unchecked(slc)); - } - buffer.push_str(*ESCAPED.get_unchecked(c)); - start_ptr = ptr2.add(1); - } - mask ^= 1 << trailing_zeros; - } - } - - ptr = aligned_ptr; - let mut next_ptr = ptr.add(VECTOR_BYTES); - - while next_ptr <= end_ptr { + while ptr <= end_ptr.sub(VECTOR_BYTES) { debug_assert_eq!((ptr as usize) % VECTOR_BYTES, 0); - let mut mask = maskgen(_mm_load_si128(ptr as *const __m128i)); + let mut mask = maskgen(_mm_loadu_si128(ptr as *const __m128i)); while mask != 0 { let trailing_zeros = mask.trailing_zeros() as usize; let ptr2 = ptr.add(trailing_zeros); @@ -82,11 +51,10 @@ pub unsafe fn escape(feed: &str, buffer: &mut Buffer) { mask ^= 1 << trailing_zeros; } - ptr = next_ptr; - next_ptr = next_ptr.add(VECTOR_BYTES); + ptr = ptr.add(VECTOR_BYTES); } - debug_assert!(next_ptr > end_ptr); + debug_assert!(ptr.add(VECTOR_BYTES) > end_ptr); if ptr < end_ptr { debug_assert!((end_ptr as usize - ptr as usize) < VECTOR_BYTES);