Written on: 2019-09-26

Loops in Rust

Let's implement a Base64 encoder. We will write a function that takes an array of arbitrary bytes as input and returns an array of symbols from the Base64 alphabet.

As the name implies, there are 64 symbols in the Base64 alphabet. That means that each symbol encodes 6 bits. So we would like to iterate through our input 6 bits at a time. But our input is provided as an array of bytes and a byte is 8 bits. The least common multiple of 6 and 8 is 24, so we will have to step through the input 3 bytes at a time, and write out 4 symbols at each step.

What actually is a Base64 symbol? The Base64 encoding is designed to allow arbitrary binary data to be sent through systems that expect to deal with text. Therefore, each symbol is a character: usually an ASCII character. If we use the UTF-8 encoding for our strings, each character will be one byte.

Thus our function will read 3 bytes at a time and write 4 bytes. This 3:4 ratio is the encoding overhead for Base64: encoded data is 4/3 of its unencoded size (about 133%), plus any padding needed for the final group. This is the price we pay to send binary data as text.

We'll start with the simplest case, in which our input is an exact multiple of 3 bytes long.

An implementation in C might look like this:

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

/* Base64 alphabet: entry i (0..63) is the symbol for the 6-bit value i. */
char encoding[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*
 * Base64-encode src_len bytes from src into a newly malloc'd buffer.
 *
 * Precondition: src_len is a multiple of 3 (checked by assert, which is
 * compiled out when NDEBUG is defined), so no '=' padding is produced.
 *
 * Returns a buffer of exactly (src_len / 3) * 4 bytes. The buffer is NOT
 * NUL-terminated; the caller must compute its length and free() it.
 * The malloc result is deliberately not checked for NULL in this example.
 */
char *encode(uint8_t *src, size_t src_len) {
    assert(src_len % 3 == 0);
    size_t dst_len = (src_len / 3) * 4;
    char *dst = malloc(dst_len);
    size_t si = 0;  /* index into src, advances by 3 */
    size_t di = 0;  /* index into dst, advances by 4 */

    while (si < src_len) {
        /* Pack three input bytes into the low 24 bits of val. */
        uint32_t val = (src[si] << 16) | (src[si+1] << 8) | src[si+2];
        /* Emit the four 6-bit groups, most significant first.
           ('>>' binds tighter than '&', so each is (val >> k) & 0x3F.) */
        dst[di+0] = encoding[val >> 18 & 0x3F];
        dst[di+1] = encoding[val >> 12 & 0x3F];
        dst[di+2] = encoding[val >> 6 & 0x3F];
        dst[di+3] = encoding[val & 0x3F];
        si += 3;
        di += 4;
    }

    return dst;
}

The algorithm is as follows:

  1. Allocate a suitably sized output array
  2. Initialize two indices to zero
  3. Read three bytes from the source array using the source index
  4. Split the three bytes into four 6-bit values
  5. Encode the four values and write them out using the destination index
  6. Increment the source index by three
  7. Increment the destination index by four
  8. Repeat from step 3 until we've encoded all the source bytes

Let's take a closer look at the loop. I will use the Compiler Explorer to view the resulting assembly: I'll include a link so that you can look up the complete assembly if you're interested but I'll only include relevant excerpts here.

I will target x86-64; please follow the links to see the exact compiler versions and options used.

The above C implementation gives us this [1]:

        ; Setup and then...
        test    r14, r14                           ; if (src_len == 0)
        je      .LBB0_4                            ;     goto .LBB0_4
        mov     rcx, rax                           ; dst is result of malloc
        add     rcx, 3                             ; dst += 3
        xor     edx, edx                           ; si = 0
.LBB0_3:
        movzx   ebx, byte ptr [r15 + rdx]          ; load src[si]
        movzx   edi, byte ptr [r15 + rdx + 1]      ; load src[si+1]
        shl     edi, 8
        mov     esi, ebx
        shl     esi, 16
        or      esi, edi
        movzx   r8d, byte ptr [r15 + rdx + 2]      ; load src[si+2]
        or      edi, r8d
        shr     rbx, 2
        movzx   ebx, byte ptr [rbx + encoding]     ; load 1st symbol
        mov     byte ptr [rcx - 3], bl             ; store at dst-3
        shr     esi, 12
        and     esi, 63
        movzx   ebx, byte ptr [rsi + encoding]     ; load 2nd symbol
        mov     byte ptr [rcx - 2], bl             ; store at dst-2
        shr     edi, 6
        and     edi, 63
        movzx   ebx, byte ptr [rdi + encoding]     ; load 3rd symbol
        mov     byte ptr [rcx - 1], bl             ; store at dst-1
        and     r8d, 63
        movzx   ebx, byte ptr [r8 + encoding]      ; load 4th symbol
        mov     byte ptr [rcx], bl                 ; store at dst
        add     rdx, 3                             ; si += 3
        add     rcx, 4                             ; dst += 4
        cmp     rdx, r14                           ; if (si < src_len)
        jb      .LBB0_3                            ;     goto .LBB0_3
.LBB0_4:
        ; The end

It's pretty much what you'd expect. There's a one-off check to make sure that src_len isn't zero and a jump past the whole loop if it is. Inside the loop we have an index in rdx that corresponds to the si index in the C source, and we have the dst pointer in rcx that we increment directly. We increment one register by 3 and the other by 4 on each iteration through the loop. It's a fairly direct translation of the C code.

We can write the same algorithm in Rust:

/// The standard Base64 alphabet.
///
/// Entry `i` is the ASCII symbol for the 6-bit value `i`: `A`-`Z` for 0-25,
/// `a`-`z` for 26-51, `0`-`9` for 52-61, then `+` and `/`.
const ENCODING: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Base64-encodes `src`, whose length must be a multiple of 3.
///
/// Writes four symbols from `ENCODING` for every three input bytes, walking
/// the input and output with explicit indices (`si` and `di`), mirroring the
/// C version. Every `src[..]` / `dst[..]` access is bounds-checked.
///
/// # Panics
///
/// Panics if `src.len()` is not a multiple of 3.
pub fn encode(src: &[u8]) -> Vec<u8> {
    assert!(src.len() % 3 == 0);
    // Exactly 4 output symbols per 3 input bytes; no padding is needed.
    let dst_len = (src.len() / 3) * 4;
    let mut dst = vec![0 as u8; dst_len];

    let mut si = 0; // index into `src`, advances by 3
    let mut di = 0; // index into `dst`, advances by 4
    let n = src.len();

    loop {
        if si >= n {
            break;
        }
        // Pack three input bytes into the low 24 bits of `val`.
        let val = (src[si+0] as u32) << 16 |
                  (src[si+1] as u32) << 8 |
                  (src[si+2] as u32);
        // Emit the four 6-bit groups, most significant first
        // (`>>` binds tighter than `&`).
        dst[di+0] = ENCODING[(val >> 18 & 0x3F) as usize];
        dst[di+1] = ENCODING[(val >> 12 & 0x3F) as usize];
        dst[di+2] = ENCODING[(val >> 6 & 0x3F) as usize];
        dst[di+3] = ENCODING[(val & 0x3F) as usize];
        si += 3;
        di += 4;
    }

    dst
}

The first striking thing about this Rust version's assembly [2] is the increase in complexity. Where the C version had only three branches — the assertion, the loop condition and a check that src_len isn't zero — the Rust version has eleven!

This is because the Rust compiler has automatically inserted some error checking. This is generally good: for example, the compiler has inserted a check that the memory allocation for the dst vector succeeds. I did not bother to check what malloc returns in the C version, so if the allocation failed, the program would write through a null pointer, which is undefined behavior and will typically crash it.

It's great that Rust enforces such checks automatically... although you can't currently make it do anything other than abort the program if the check fails, so there's no chance to gracefully handle running out of memory. (There is a proposed API to let the programmer set their own Out-Of-Memory handler but it is currently experimental and not included in stable Rust releases [3].)

The Rust compiler has also inserted bounds checking for all of the array accesses. This is a nice thing in the general case — buffer overflows are the cause of many bugs in C code — but in this case it is a little annoying. We know the sizes of both arrays and we iterate through the arrays based on those sizes: there is no need for run-time bounds checking as we can prove at compile-time that we will never read or write outside the bounds of the arrays.

I had hoped that the bounds checks would be elided, since static analysis can prove that they are unnecessary, but it appears that in this case they are not. (The compiler options used are, at the time of writing, the same as those used by Cargo's default "release" profile.)

This is what we get:

        ; Setup and then...
        test    r14, r14                      ; if (src_len == 0)
        je      .LBB8_18                      ;     goto .LBB8_18
        mov     ebx, 2                        ; si = 2
        mov     esi, 3                        ; di = 3
        lea     r8, [rip + .L__unnamed_7]     ; ENCODING array
.LBB8_6:
        lea     rax, [rbx - 1]
        cmp     rax, r14                      ; if (si - 1 >= src_len)
        jae     .LBB8_20                      ;     goto OUT_OF_BOUNDS
        cmp     rbx, r14                      ; if (si >= src_len)
        jae     .LBB8_21                      ;     goto OUT_OF_BOUNDS
        lea     rax, [rsi - 3]
        cmp     r13, rax                      ; if (dst_len <= di - 3)
        jbe     .LBB8_9                       ;     goto OUT_OF_BOUNDS
        movzx   ecx, byte ptr [r12 + rbx - 2] ; load src[si-2]
        movzx   edx, byte ptr [r12 + rbx - 1] ; load src[si-1]
        movzx   edi, byte ptr [r12 + rbx]     ; load src[si]
        mov     rax, rcx
        shr     rax, 2
        movzx   eax, byte ptr [rax + r8]      ; load 1st symbol
        mov     byte ptr [r9 + rsi - 3], al   ; store at dst[di-3]
        lea     rax, [rsi - 2]
        cmp     r13, rax                      ; if (dst_len <= di - 2)
        jbe     .LBB8_12                      ;     goto OUT_OF_BOUNDS
        shl     rdx, 8
        shl     ecx, 16
        or      ecx, edx
        shr     ecx, 12
        and     ecx, 63
        movzx   eax, byte ptr [rcx + r8]      ; load 2nd symbol
        mov     byte ptr [r9 + rsi - 2], al   ; store at dst[di-2]
        lea     rax, [rsi - 1]
        cmp     r13, rax                      ; if (dst_len <= di - 1)
        jbe     .LBB8_14                      ;     goto OUT_OF_BOUNDS
        or      rdx, rdi
        shr     edx, 6
        and     edx, 63
        movzx   eax, byte ptr [rdx + r8]      ; load 3rd symbol
        mov     byte ptr [r9 + rsi - 1], al   ; store at dst[di-1]
        cmp     r13, rsi                      ; if (dst_len <= di)
        jbe     .LBB8_16                      ;     goto OUT_OF_BOUNDS
        and     edi, 63
        movzx   eax, byte ptr [rdi + r8]      ; load 4th symbol
        mov     byte ptr [r9 + rsi], al       ; store at dst[di]
        lea     rax, [rbx + 3]
        add     rbx, 1
        add     rsi, 4                        ; di += 4
        cmp     rbx, r14                      ; if (si + 1 < src_len)
        mov     rbx, rax                      ;     si = si + 3
        jb      .LBB8_6                       ;     goto .LBB8_6
.LBB8_18:
        ; The end

Note that the bit-twiddling parts are identical to the C version. I compiled the C version with Clang, which uses the same LLVM backend as the Rust compiler, so this shouldn't be too much of a surprise.

The main difference is that it doesn't increment a destination pointer directly, as the C version did, but instead keeps a di index. This is because it needs to use di when bounds checking accesses to the dst array. So, apart from the bounds checking, the Rust version and the C version are almost identical.

The bounds checking probably won't make much difference in this case: the calculations are all register-only operations and hardware branch prediction will minimize any effects of the branches themselves. However, it does have an effect on register allocation and code size... and it just kind of irks me that they are there when we can see clearly at compile time that they aren't necessary: it would be nice if we could get rid of them.

Our Rust version also doesn't seem very idiomatic. One of the great promises of Rust is that it allows us to write fast, low-level code in a high-level, functional programming style. We shouldn't be manually incrementing array indices, we should be using iterators!

Here's another version:

/// Base64-encodes `src`, whose length must be a multiple of 3.
///
/// Pairs each 3-byte input chunk with the corresponding 4-byte output chunk,
/// so no manual index bookkeeping is needed.
///
/// # Panics
///
/// Panics if `src.len()` is not a multiple of 3.
pub fn encode(src: &[u8]) -> Vec<u8> {
    assert!(src.len() % 3 == 0);
    // Exactly 4 output symbols per 3 input bytes; no padding is needed.
    let dst_len = (src.len() / 3) * 4;
    let mut dst = vec![0 as u8; dst_len];

    // `chunks_exact` guarantees every `s` has length 3 and every `d` has
    // length 4, so the fixed indices below always stay in range.
    for (s, d) in src.chunks_exact(3).zip(dst.chunks_exact_mut(4)) {
        // Pack three input bytes into the low 24 bits of `val`.
        let val = (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32);
        // Emit the four 6-bit groups, most significant first.
        d[0] = ENCODING[(val >> 18 & 0x3F) as usize];
        d[1] = ENCODING[(val >> 12 & 0x3F) as usize];
        d[2] = ENCODING[(val >> 6 & 0x3F) as usize];
        d[3] = ENCODING[(val & 0x3F) as usize];
    }

    dst
}

That one gives us this [4]:

        ; Setup and then...
        test    r8, r8                         ; if (src_len/3 == 0)
        je      .LBB7_7                        ;     goto .LBB7_7
        add     r13, 2                         ; src += 2
        xor     edx, edx                       ; i = 0
        lea     r9, [rip + .L__unnamed_7]      ; ENCODING array
.LBB7_6:
        movzx   esi, byte ptr [r13 - 2]        ; load src-2
        movzx   ecx, byte ptr [r13 - 1]        ; load src-1
        shl     ecx, 8
        mov     edi, esi
        shl     edi, 16
        or      edi, ecx
        movzx   r10d, byte ptr [r13]           ; load src
        or      ecx, r10d
        shr     rsi, 2
        movzx   ebx, byte ptr [rsi + r9]       ; load 1st symbol
        mov     byte ptr [rax + 4*rdx], bl     ; store at dst[4*i]
        shr     edi, 12
        and     edi, 63
        movzx   ebx, byte ptr [rdi + r9]       ; load 2nd symbol
        mov     byte ptr [rax + 4*rdx + 1], bl ; store at dst[4*i+1]
        shr     ecx, 6
        and     ecx, 63
        movzx   ecx, byte ptr [rcx + r9]       ; load 3rd symbol
        mov     byte ptr [rax + 4*rdx + 2], cl ; store at dst[4*i+2]
        and     r10d, 63
        movzx   ecx, byte ptr [r10 + r9]       ; load 4th symbol
        mov     byte ptr [rax + 4*rdx + 3], cl ; store at dst[4*i+3]
        lea     rdx, [rdx + 1]                 ; i += 1
        add     r13, 3                         ; src += 3
        cmp     rdx, r8                        ; if (i < src_len/3)
        jb      .LBB7_6                        ;     goto .LBB7_6
.LBB7_7:
        ; The end

Hooray! No bounds checking! This one is now almost identical to the C version. We have one index and one array pointer that we increment directly.

Whether this version is as readable as the first version is perhaps a matter of taste — but we can definitely do better. We're still allocating an array and then iterating through it: if we're really going to do justice to Rust's OCaml heritage, we should be more functional still:

/// Base64-encodes `src` with an iterator pipeline.
///
/// Unlike the indexed versions there is no length assertion: `chunks_exact(3)`
/// only yields complete 3-byte chunks, so any 1-2 trailing bytes are silently
/// ignored.
///
/// NOTE: the `flat_map` closure builds a fresh `Vec` for every chunk, i.e. one
/// heap allocation and deallocation per 3 input bytes.
pub fn encode(src: &[u8]) -> Vec<u8> {
    src.chunks_exact(3)
       // Pack each 3-byte chunk into the low 24 bits of a `u32`.
       .map(|s| (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32))
       // Expand each 24-bit value into its four symbols, most significant first.
       .flat_map(|v| vec![ENCODING[(v >> 18 & 0x3F) as usize],
                          ENCODING[(v >> 12 & 0x3F) as usize],
                          ENCODING[(v >> 6 & 0x3F) as usize],
                          ENCODING[(v & 0x3F) as usize]])
       .collect()
}

Unfortunately, that compiles to something far more convoluted [5]. The inner loop has gone from 26 instructions to 115! Worse, the closures that we passed to map and flat_map have been compiled to a separate function which we call on every iteration.

And there's an even bigger problem caused by that vec! macro in the flat_map closure. It's there because we need to return an iterator, and Rust does not provide by-value iterators over arrays or tuples. We could return a slice, but then we'd have an iterator over references. That's no good because they would be references to temporary variables created in that closure, so they couldn't outlive the closure itself, but we need them later to collect them into the output. The only convenient option provided by the standard library is a Vec... but that results in a heap allocation and deallocation on every iteration!

We can fix this by making our own iterator that works with arrays. We need something that will take an array of four elements and will return a copy of each element on request:

/// A by-value iterator over a four-element array.
///
/// Yields copies of `a[0]`, `a[1]`, `a[2]` and `a[3]` in order, then `None`
/// forever. It lets a `flat_map` closure return four items without
/// allocating a `Vec` on the heap.
#[derive(Debug, Clone)]
struct ArrayIter<T> {
    /// The elements to yield.
    a: [T; 4],
    /// Index of the next element to yield; equals `a.len()` once exhausted.
    i: usize,
}

impl<T> ArrayIter<T> {
    /// Creates an iterator positioned at the first element of `a`.
    fn new(a: [T; 4]) -> Self {
        Self { a, i: 0 }
    }
}

impl<T> Iterator for ArrayIter<T>
where T: Copy,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.i < self.a.len() {
            let value = self.a[self.i];
            self.i += 1;
            Some(value)
        } else {
            None
        }
    }

    /// Reports the exact number of elements still to be yielded, so that
    /// consumers (e.g. `collect`, `ExactSizeIterator::len`) can rely on it
    /// instead of the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `saturating_sub` keeps this panic-free even if `i` were ever set
        // past the end by code constructing the struct directly.
        let remaining = self.a.len().saturating_sub(self.i);
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so the iterator can advertise an exact length.
impl<T: Copy> ExactSizeIterator for ArrayIter<T> {}

// Once exhausted, `i` is never incremented again, so `next` keeps
// returning `None`.
impl<T: Copy> std::iter::FusedIterator for ArrayIter<T> {}

/// Base64-encodes `src` with an iterator pipeline that uses `ArrayIter`.
///
/// `chunks_exact(3)` only yields complete 3-byte chunks, so any 1-2 trailing
/// bytes are silently ignored. Each chunk's four symbols are yielded from a
/// stack array via `ArrayIter`, avoiding a per-chunk `Vec` allocation.
pub fn encode(src: &[u8]) -> Vec<u8> {
    src.chunks_exact(3)
       // Pack each 3-byte chunk into the low 24 bits of a `u32`.
       .map(|s| (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32))
       // Expand each 24-bit value into its four symbols, most significant first.
       .flat_map(|v| ArrayIter::new(
               [ENCODING[(v >> 18 & 0x3F) as usize],
                ENCODING[(v >> 12 & 0x3F) as usize],
                ENCODING[(v >> 6 & 0x3F) as usize],
                ENCODING[(v & 0x3F) as usize]]))
       .collect()
}

This is better [6]: we've eliminated the heap allocations and deallocations in the flat_map closure. But we've reintroduced bounds checking!

Also, we still have another memory-related problem: as we loop through the src array, calling our map/flat_map closure on every iteration, we incrementally populate a result Vec. In our earlier version we allocated dst once at the beginning and then wrote our results into it. This version instead writes into the result Vec until it is full, then reallocates it as needed. We know exactly how many output symbols we are going to create, but collect does not, so it must dynamically reallocate memory inside our loop.

Maybe going fully functional wasn't such a good idea. It's definitely not as efficient as just writing the loops explicitly.

Finally, let's handle the case when the input is not a multiple of 3 bytes. In that case, the Base64 output should still be a multiple of 4 symbols: we have to add padding symbols, usually =, to make up the difference.

Here's a version in C:

/*
 * Base64-encode src_len bytes from src into a newly malloc'd buffer,
 * padding the final group with '=' when src_len is not a multiple of 3.
 *
 * Returns a buffer of exactly ((src_len + 2) / 3) * 4 bytes, i.e. rounded up
 * to a whole 4-symbol group. The buffer is NOT NUL-terminated; the caller
 * must compute its length and free() it. The malloc result is not checked
 * for NULL in this example.
 */
char *encode(uint8_t *src, size_t src_len) {
    size_t dst_len = (src_len + 2) / 3 * 4;
    char *dst = malloc(dst_len);
    size_t si = 0;
    size_t di = 0;
    /* Length of the prefix made up of whole 3-byte groups. */
    size_t n = (src_len / 3) * 3;

    while (si < n) {
        uint32_t val = (src[si] << 16) | (src[si+1] << 8) | src[si+2];
        dst[di+0] = encoding[val >> 18 & 0x3F];
        dst[di+1] = encoding[val >> 12 & 0x3F];
        dst[di+2] = encoding[val >> 6 & 0x3F];
        dst[di+3] = encoding[val & 0x3F];
        si += 3;
        di += 4;
    }

    /* 0, 1 or 2 input bytes are left over after the whole groups. */
    size_t remainder = src_len - si;
    if (remainder == 0) {
        return dst;
    }

    /* Pack the leftover byte(s) into val; missing low bytes stay zero. */
    uint32_t val = src[si] << 16;
    if (remainder == 2) {
        val |= src[si+1] << 8;
    }

    /* The first two symbols always carry real data. */
    dst[di+0] = encoding[val >> 18 & 0x3F];
    dst[di+1] = encoding[val >> 12 & 0x3F];

    /* Third symbol: data if two bytes were left over, padding if only one. */
    switch (remainder) {
    case 2:
        dst[di+2] = encoding[val >> 6 & 0x3F];
        break;
    case 1:
        dst[di+2] = '=';
        break;
    }

    /* The fourth symbol of a partial group is always padding. */
    dst[di+3] = '=';

    return dst;
}

It's the same as before but now we handle the remaining bytes when the input isn't a multiple of 3 bytes long.

We can do something similar in Rust, but we have to be careful to clone and borrow the iterators as appropriate:

/// Base64-encodes `src` of any length, padding the final group with `=`
/// when `src.len()` is not a multiple of 3.
///
/// The output always has `(src.len() + 2) / 3 * 4` bytes.
pub fn encode(src: &[u8]) -> Vec<u8> {
    let dst_len = (src.len() + 2) / 3 * 4;
    let mut dst = vec![0 as u8; dst_len];

    let src_iter = src.chunks_exact(3);
    // The 0-2 trailing input bytes that do not form a complete chunk.
    // NOTE(review): `ChunksExact::remainder` takes `&self` and returns a slice
    // borrowed from `src`, so the `.clone()` here looks unnecessary — confirm.
    let src_remainder = src_iter.clone().remainder();
    // Iterated through `&mut` in the loop below so that the output chunk left
    // over for the padded group is still reachable afterwards.
    let mut dst_iter = dst.chunks_exact_mut(4);

    // Encode all complete 3-byte chunks.
    for (s, d) in src_iter.zip(&mut dst_iter) {
        let val = (s[0] as u32) << 16 | (s[1] as u32) << 8 | (s[2] as u32);
        d[0] = ENCODING[(val >> 18 & 0x3F) as usize];
        d[1] = ENCODING[(val >> 12 & 0x3F) as usize];
        d[2] = ENCODING[(val >> 6 & 0x3F) as usize];
        d[3] = ENCODING[(val & 0x3F) as usize];
    }

    if src_remainder.len() > 0 {
        // Pack the leftover byte(s); missing low bytes stay zero.
        let mut val = (src_remainder[0] as u32) << 16;
        if src_remainder.len() > 1 {
            val |= (src_remainder[1] as u32) << 8;
        }
        // `dst_len` reserves one extra output chunk for the partial group;
        // the loop above did not consume it, so `last()` yields it.
        let dst_remainder = dst_iter.last().unwrap();
        // The first two symbols always carry real data.
        dst_remainder[0] = ENCODING[(val >> 18 & 0x3F) as usize];
        dst_remainder[1] = ENCODING[(val >> 12 & 0x3F) as usize];
        // Third symbol: data if two bytes were left over, padding if only one.
        match src_remainder.len() {
            2 => {
                dst_remainder[2] = ENCODING[(val >> 6 & 0x3F) as usize];
            },
            1 => {
                dst_remainder[2] = b'=';
            },
            // Unreachable: the remainder is 1 or 2 bytes inside this `if`.
            _ => (),
        }
        // The fourth symbol of a partial group is always padding.
        dst_remainder[3] = b'=';
    }

    dst
}

As before, the C and the Rust produce pretty much equivalent assembly [7][8]. The Rust version has a couple of extra checks (such as the unwrap) but it's very similar.

The Rust Book says that Rust strives to provide "higher-level features that compile to lower-level code as fast as code written manually" [9]. This appears to be the case... as long as you pick the right features.