Conrad Ludgate

Async Iterators

This blog post will not be very rigorous, but I want to highlight an example that I often reach for when discussing complicated async iterators.

First, I should provide context.

Currently, we have an AsyncIterator trait defined as:

rust
pub trait AsyncIterator {
    type Item;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<Option<Self::Item>>;
}

This is a mashup of the Future and Iterator traits:

rust
pub trait Future {
    type Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>;
}

pub trait Iterator {
    type Item;

    fn next(&mut self) -> Option<Self::Item>;
}
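
To make that shape concrete, here is a minimal sketch of a hand-written poll_next implementation. The std AsyncIterator trait is still unstable, so the trait is redeclared locally. `Countdown`, `NoopWake`, and `collect_ready` are hypothetical names used only for illustration.

```rust
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Local copy of the trait shape shown above (the std version is unstable).
pub trait AsyncIterator {
    type Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
}

// Hypothetical example type: yields n, n-1, ..., 1 and then finishes.
pub struct Countdown {
    pub remaining: u32,
}

impl AsyncIterator for Countdown {
    type Item = u32;

    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<u32>> {
        // `Countdown` is `Unpin`, so we can mutate through the `Pin` directly.
        if self.remaining == 0 {
            return Poll::Ready(None);
        }
        let item = self.remaining;
        self.remaining -= 1;
        Poll::Ready(Some(item))
    }
}

// A waker that does nothing; enough to build a `Context` for polling by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Polls by hand and collects items until the source ends or reports `Pending`.
pub fn collect_ready<S: AsyncIterator + Unpin>(mut stream: S) -> Vec<S::Item> {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut items = Vec::new();
    loop {
        match Pin::new(&mut stream).poll_next(&mut cx) {
            Poll::Ready(Some(item)) => items.push(item),
            Poll::Ready(None) | Poll::Pending => return items,
        }
    }
}
```

Calling `collect_ready(Countdown { remaining: 3 })` collects 3, 2, 1. All of the state lives in the struct, and every call to poll_next picks up where the previous one left off.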

This AsyncIterator trait has worked fine for many Rust developers in the form of the equivalent futures::Stream trait.

However, some people would like AsyncIterator to be implemented using async fn next(&mut self). This sounds like it simplifies some things. Of course, we couldn't have a trait implemented with async functions before, but we can now. So it is worth considering.
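
For comparison, here is a rough sketch of what the async fn next shape could look like on stable Rust (1.75 or later, which allows async fn in traits). `AsyncFnIterator`, `Countdown`, `block_on`, and `collect_all` are hypothetical names for illustration, not a proposed std API, and the tiny `block_on` is only suitable for futures that never actually wait.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical trait: the `async fn next` shape some people would prefer.
trait AsyncFnIterator {
    type Item;

    async fn next(&mut self) -> Option<Self::Item>;
}

// Hypothetical example type: yields n, n-1, ..., 1 and then finishes.
struct Countdown {
    remaining: u32,
}

impl AsyncFnIterator for Countdown {
    type Item = u32;

    async fn next(&mut self) -> Option<u32> {
        if self.remaining == 0 {
            return None;
        }
        let item = self.remaining;
        self.remaining -= 1;
        Some(item)
    }
}

// A waker that does nothing; enough to poll a future by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

// Minimal busy-polling executor, assumed adequate only for this demo.
fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

// Consumers call `.next().await` in a loop.
async fn collect_all<I: AsyncFnIterator>(mut iter: I) -> Vec<I::Item> {
    let mut items = Vec::new();
    while let Some(item) = iter.next().await {
        items.push(item);
    }
    items
}
```

Here `block_on(collect_all(Countdown { remaining: 3 }))` collects 3, 2, 1. Note that the iterator's state still lives in the struct between calls to next; only the body of a single call gets to use .await.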

There is already a blog post by withoutboats explaining why this is a bad idea, so I won't rehash those arguments here.

The reason you might want an async fn next() is that it supposedly makes the trait easier to implement, but I disagree. To demonstrate, let me show you an example.

Framed

In the tokio-util crate, there is an AsyncIterator called FramedRead. It reads bytes from an AsyncRead and outputs decoded frames. The implementation of this is a sight to behold:

rust
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loops implements a state machine with each state corresponding
        // to a combination of the `is_readable` and `eof` flags. States persist across
        // loop entries and most state transitions occur with a return.
        //
        // The initial state is `reading`.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //                                                       `decode_eof` returns Err
        //                                          ┌────────────────────────────────────────────────────────┐
        //                   `decode_eof` returns   │                                                        │
        //                             `Ok(Some)`   │                                                        │
        //                                 ┌─────┐  │     `decode_eof` returns               After returning │
        //                Read 0 bytes     ├─────▼──┴┐    `Ok(None)`          ┌────────┐ ◄───┐ `None`    ┌───▼─────┐
        //               ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐   └───────────┤ Errored │
        //               │                 └─────────┘                        └─┬──▲───┘ │               └───▲───▲─┘
        // Pending read  │                                                      │  │     │                   │   │
        //     ┌──────┐  │            `decode` returns `Some`                   │  └─────┘                   │   │
        //     │      │  │                   ┌──────┐                           │  Pending                   │   │
        //     │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐     read n>0 bytes      │  read                      │   │
        //     └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘                            │   │
        //       └──┬─▲────┘                └─────┬──┬┘                                                      │   │
        //          │ │                           │  │                 `decode` returns Err                  │   │
        //          │ └───decode` returns `None`──┘  └───────────────────────────────────────────────────────┘   │
        //          │                             read returns Err                                               │
        //          └────────────────────────────────────────────────────────────────────────────────────────────┘
        loop {
            // Return `None` if we have encountered an error from the underlying decoder
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // preparing has_errored -> paused
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of it's associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> paused
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}

Wow! All it does is read some bytes, decode the frames, and then repeat until there is an error or no more bytes. Why is it so complicated? Your first instinct might be to blame the fact that we must use polling and not .await.

However, if you look closely at this implementation, we only poll the reader in one place. If we were to replace that with await, the implementation would look almost identical. Here is a diff:

diff
-impl<T, U, R> Stream for FramedImpl<T, U, R>
+impl<T, U, R> AsyncIterator for FramedImpl<T, U, R>
 where
-    T: AsyncRead,
+    T: AsyncRead + Unpin,
     U: Decoder,
     R: BorrowMut<ReadFrame>,
 {
     type Item = Result<U::Item, U::Error>;

-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+    async fn next(&mut self) -> Option<Self::Item> {
         use crate::util::poll_read_buf;

-        let mut pinned = self.project();
-        let state: &mut ReadFrame = pinned.state.borrow_mut();
+        let state: &mut ReadFrame = self.state.borrow_mut();
         // The following loops implements a state machine with each state corresponding
         // to a combination of the `is_readable` and `eof` flags. States persist across
         // loop entries and most state transitions occur with a return.
         //
         // The initial state is `reading`.
         //
         // | state   | eof   | is_readable | has_errored |
         // |---------|-------|-------------|-------------|
         // | reading | false | false       | false       |
         // | framing | false | true        | false       |
         // | pausing | true  | true        | false       |
         // | paused  | true  | false       | false       |
         // | errored | <any> | <any>       | true        |
         //                                                       `decode_eof` returns Err
         //                                          ┌────────────────────────────────────────────────────────┐
         //                   `decode_eof` returns   │                                                        │
         //                             `Ok(Some)`   │                                                        │
         //                                 ┌─────┐  │     `decode_eof` returns               After returning │
         //                Read 0 bytes     ├─────▼──┴┐    `Ok(None)`          ┌────────┐ ◄───┐ `None`    ┌───▼─────┐
         //               ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐   └───────────┤ Errored │
         //               │                 └─────────┘                        └─┬──▲───┘ │               └───▲───▲─┘
         // Pending read  │                                                      │  │     │                   │   │
         //     ┌──────┐  │            `decode` returns `Some`                   │  └─────┘                   │   │
         //     │      │  │                   ┌──────┐                           │  Pending                   │   │
         //     │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐     read n>0 bytes      │  read                      │   │
         //     └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘                            │   │
         //       └──┬─▲────┘                └─────┬──┬┘                                                      │   │
         //          │ │                           │  │                 `decode` returns Err                  │   │
         //          │ └───decode` returns `None`──┘  └───────────────────────────────────────────────────────┘   │
         //          │                             read returns Err                                               │
         //          └────────────────────────────────────────────────────────────────────────────────────────────┘
         loop {
             // Return `None` if we have encountered an error from the underlying decoder
             // See: https://github.com/tokio-rs/tokio/issues/3976
             if state.has_errored {
                 // preparing has_errored -> paused
                 trace!("Returning None and setting paused");
                 state.is_readable = false;
                 state.has_errored = false;
-                return Poll::Ready(None);
+                return None;
             }

             // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
             // i.e. it _might_ contain data consumable as a frame or closing frame.
             // Both signal that there is no such data by returning `None`.
             //
             // If `decode` couldn't read a frame and the upstream source has returned eof,
             // `decode_eof` will attempt to decode the remaining bytes as closing frames.
             //
             // If the underlying AsyncRead is resumable, we may continue after an EOF,
             // but must finish emitting all of it's associated `decode_eof` frames.
             // Furthermore, we don't want to emit any `decode_eof` frames on retried
             // reads after an EOF unless we've actually read more data.
             if state.is_readable {
                 // pausing or framing
                 if state.eof {
                     // pausing
-                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
+                    let frame = self.codec.decode_eof(&mut state.buffer).map_err(|err| {
                         trace!("Got an error, going to errored state");
                         state.has_errored = true;
                         err
                     })?;
                     if frame.is_none() {
                         state.is_readable = false; // prepare pausing -> paused
                     }
                     // implicit pausing -> pausing or pausing -> paused
-                    return Poll::Ready(frame.map(Ok));
+                    return frame.map(Ok);
                 }

                 // framing
                 trace!("attempting to decode a frame");

-                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
+                if let Some(frame) = self.codec.decode(&mut state.buffer).map_err(|op| {
                     trace!("Got an error, going to errored state");
                     state.has_errored = true;
                     op
                 })? {
                     trace!("frame decoded from buffer");
                     // implicit framing -> framing
-                    return Poll::Ready(Some(Ok(frame)));
+                    return Some(Ok(frame));
                 }

                 // framing -> reading
                 state.is_readable = false;
             }
             // reading or paused
             // If we can't build a frame yet, try to read more data and try again.
             // Make sure we've got room for at least one byte to read to ensure
             // that we don't get a spurious 0 that looks like EOF.
             state.buffer.reserve(1);
-            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
+            let bytect = self.inner.read_buf(&mut state.buffer).await.map_err(
                 |err| {
                     trace!("Got an error, going to errored state");
                     state.has_errored = true;
                     err
                 },
-            )? {
-                Poll::Ready(ct) => ct,
-                // implicit reading -> reading or implicit paused -> paused
-                Poll::Pending => return Poll::Pending,
-            };
+            )?;
             if bytect == 0 {
                 if state.eof {
                     // We're already at an EOF, and since we've reached this path
                     // we're also not readable. This implies that we've already finished
                     // our `decode_eof` handling, so we can simply return `None`.
                     // implicit paused -> paused
-                    return Poll::Ready(None);
+                    return None;
                 }
                 // prepare reading -> paused
                 state.eof = true;
             } else {
                 // prepare paused -> framing or noop reading -> framing
                 state.eof = false;
             }

             // paused -> framing or reading -> framing or reading -> pausing
             state.is_readable = true;
         }
     }
 }

Is that really all we can do? Unfortunately, yes. Reading bytes and decoding them is complex. Sometimes we must call read several times to get a single frame. Sometimes a single call to read returns multiple frames. So the state machine is still necessary, with or without .await.
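
The mismatch between reads and frames is easy to see even without any async code. Here is a synchronous sketch using a hypothetical newline-delimited decoder; `decode_line` and `frames_per_read` are illustrative names and not part of tokio-util, and each input chunk stands in for one successful read.

```rust
// Hypothetical decoder: removes and returns one newline-terminated frame from
// the front of `buffer`, or returns `None` if no complete frame is buffered.
fn decode_line(buffer: &mut Vec<u8>) -> Option<String> {
    let end = buffer.iter().position(|&b| b == b'\n')?;
    // Drain the frame including its trailing newline.
    let line: Vec<u8> = buffer.drain(..=end).collect();
    // Drop the newline when building the frame.
    Some(String::from_utf8_lossy(&line[..end]).into_owned())
}

// Appends each chunk to the buffer and records which frames became
// available after that particular read.
fn frames_per_read(chunks: &[&str]) -> Vec<Vec<String>> {
    let mut buffer = Vec::new();
    let mut per_read = Vec::new();
    for chunk in chunks {
        buffer.extend_from_slice(chunk.as_bytes());
        let mut frames = Vec::new();
        // Keep decoding until the buffer holds no complete frame.
        while let Some(frame) = decode_line(&mut buffer) {
            frames.push(frame);
        }
        per_read.push(frames);
    }
    per_read
}
```

Feeding it the chunks "he", "llo\nwor", and "ld\na\nb\n" produces zero frames after the first read, one frame ("hello") after the second, and three frames ("world", "a", "b") after the third. An iterator that hands out one frame per call has to remember whether it is in the middle of that inner loop, and that is exactly what the is_readable flag does.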

We've also introduced a regression by requiring T: Unpin. The proposed solution here is to use a PinnedAsyncIterator trait, but I don't think that really makes anything simpler.

Generators

Instead of worrying about what the AsyncIterator trait looks like, we should provide a convenient syntax for implementing it.

Here's how I imagine the implementation above could look if it were written as a generator.

rust
#[yield(U::Item)]
async fn framed_impl<T, U>(t: T, mut codec: U) -> Result<(), U::Error>
where
    T: AsyncRead,
    U: Decoder,
{
    let mut inner = std::pin::pin!(t);
    let mut buffer = BytesMut::with_capacity(INITIAL_CAPACITY);
    loop {
        buffer.reserve(1);
        let eof = inner.read_buf(&mut buffer).await? == 0;

        'framing: loop {
            trace!("attempting to decode a frame");

            let frame = if eof {
                codec.decode_eof(&mut buffer)?
            } else {
                codec.decode(&mut buffer)?
            };
            match frame {
                Some(frame) => {
                    trace!("frame decoded from buffer");
                    yield Ok(frame);
                    continue 'framing;
                }
                // nothing left to read. goodbye
                None if eof => return Ok(()),
                // go back to reading
                None => break 'framing,
            };
        }
    }
}

Do I need to say more?