Index: streamz-0.6.3/examples/river_kmeans.ipynb
===================================================================
--- /dev/null
+++ streamz-0.6.3/examples/river_kmeans.ipynb
@@ -0,0 +1,134 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "accbccab",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import random\n",
+    "\n",
+    "import pandas as pd\n",
+    "\n",
+    "from streamz import Stream\n",
+    "import hvplot.streamz\n",
+    "from streamz.river import RiverTrain\n",
+    "from river import cluster\n",
+    "import holoviews as hv\n",
+    "from panel.pane.holoviews import HoloViews\n",
+    "hv.extension('bokeh')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8a2ef27a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = cluster.KMeans(n_clusters=3, sigma=0.1, mu=0.5)\n",
+    "centres = [[random.random(), random.random()] for _ in range(3)]\n",
+    "\n",
+    "def gen(move_chance=0.05):\n",
+    "    centre = int(random.random() * 3)  # 3x faster than random.randint(0, 2)\n",
+    "    if random.random() < move_chance:\n",
+    "        centres[centre][0] += random.random() / 5 - 0.1\n",
+    "        centres[centre][1] += random.random() / 5 - 0.1\n",
+    "    value = {'x': random.random() / 20 + centres[centre][0],\n",
+    "             'y': random.random() / 20 + centres[centre][1]}\n",
+    "    return value\n",
+    "\n",
+    "\n",
+    "def get_clusters(model):\n",
+    "    # return [{\"x\": xcen, \"y\": ycen}, ...] for each centre\n",
+    "    data = [{'x': v['x'], 'y': v['y']} for k, v in model.centers.items()]\n",
+    "    return pd.DataFrame(data, index=range(3))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e6451048",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "s = Stream.from_periodic(gen, 0.03)\n",
+    "km = RiverTrain(model, pass_model=True)\n",
+    "s.map(lambda x: (x,)).connect(km)  # learn takes a tuple of (x,[ y[, w]])\n",
+    "ex = pd.DataFrame({'x': [0.5], 'y': [0.5]})\n",
+    "ooo = s.map(lambda x: pd.DataFrame([x])).to_dataframe(example=ex)\n",
+    "out = km.map(get_clusters)\n",
+    "\n",
+    "# start things\n",
+    "s.emit(gen())  # set initial model\n",
+    "for i, (x, y) in enumerate(centres):\n",
+    "    model.centers[i]['x'] = x\n",
+    "    model.centers[i]['y'] = y\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b4de451",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pout = out.to_dataframe(example=ex)\n",
+    "pl = (ooo.hvplot.scatter('x', 'y', color=\"blue\", backlog=50) *\n",
+    "      pout.hvplot.scatter('x', 'y', color=\"red\", backlog=3))\n",
+    "pl.opts(xlim=(-0.2, 1.2), ylim=(-0.2, 1.2), height=600, width=600)\n",
+    "pl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c24d2363",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "s.start()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "18cfd94e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "s.stop()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4537495c",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
Index: streamz-0.6.3/examples/river_kmeans.py
===================================================================
--- /dev/null
+++ streamz-0.6.3/examples/river_kmeans.py
@@ -0,0 +1,70 @@
+import random
+
+import pandas as pd
+
+from streamz import Stream
+import hvplot.streamz
+from streamz.river import RiverTrain
+from river import cluster
+import holoviews as hv
+from panel.pane.holoviews import HoloViews
+hv.extension('bokeh')
+
+model = cluster.KMeans(n_clusters=3, sigma=0.1, mu=0.5)
+centres = [[random.random(), random.random()] for _ in range(3)]
+count = [0]
+
+def gen(move_chance=0.05):
+    centre = int(random.random() * 3)  # 3x faster than random.randint(0, 2)
+    if random.random() < move_chance:
+        centres[centre][0] += random.random() / 5 - 0.1
+        centres[centre][1] += random.random() / 5 - 0.1
+    value = {'x': random.random() / 20 + centres[centre][0],
+             'y': random.random() / 20 + centres[centre][1]}
+    count[0] += 1
+    return value
+
+
+def get_clusters(model):
+    # return [{"x": xcen, "y": ycen}, ...] for each centre
+    data = [{'x': v['x'], 'y': v['y']} for k, v in model.centers.items()]
+    return pd.DataFrame(data, index=range(3))
+
+
+def main(viz=True):
+    # setup pipes
+    cadence = 0.16 if viz else 0.01
+    s = Stream.from_periodic(gen, cadence)
+    km = RiverTrain(model, pass_model=True)
+    s.map(lambda x: (x,)).connect(km)  # learn takes a tuple of (x[, y[, w]])
+    ex = pd.DataFrame({'x': [0.5], 'y': [0.5]})
+    ooo = s.map(lambda x: pd.DataFrame([x])).to_dataframe(example=ex)
+    out = km.map(get_clusters)
+
+    # start things
+    s.emit(gen())  # emit one point so the model initialises its centres
+    for i, (x, y) in enumerate(centres):
+        model.centers[i]['x'] = x
+        model.centers[i]['y'] = y
+
+    print("starting")
+    s.start()
+
+    if viz:
+        # plot
+        pout = out.to_dataframe(example=ex)
+        pl = (ooo.hvplot.scatter('x', 'y', color="blue", backlog=50) *
+              pout.hvplot.scatter('x', 'y', color="red", backlog=3))
+        pl.opts(xlim=(-0.2, 1.2), ylim=(-0.2, 1.2), height=600, width=600)
+        pan = HoloViews(pl)
+        pan.show()
+    else:
+        import time
+        time.sleep(5)
+        print(count, "events")
+        print("Current centres", centres)
+        print("Output centres", [list(c.values()) for c in model.centers.values()])
+    s.stop()
+
+if __name__ == "__main__":
+    main(viz=True)
Index: streamz-0.6.3/streamz/core.py
===================================================================
--- streamz-0.6.3.orig/streamz/core.py
+++ streamz-0.6.3/streamz/core.py
@@ -1902,89 +1902,6 @@ class latest(Stream):
             yield self._emit(x, self.next_metadata)
 
 
-@Stream.register_api()
-class to_kafka(Stream):
-    """ Writes data in the stream to Kafka
-
-    This stream accepts a string or bytes object. Call ``flush`` to ensure all
-    messages are pushed. Responses from Kafka are pushed downstream.
-
-    Parameters
-    ----------
-    topic : string
-        The topic which to write
-    producer_config : dict
-        Settings to set up the stream, see
-        https://docs.confluent.io/current/clients/confluent-kafka-python/#configuration
-        https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md
-        Examples:
-        bootstrap.servers: Connection string (host:port) to Kafka
-
-    Examples
-    --------
-    >>> from streamz import Stream
-    >>> ARGS = {'bootstrap.servers': 'localhost:9092'}
-    >>> source = Stream()
-    >>> kafka = source.map(lambda x: str(x)).to_kafka('test', ARGS)
-    <to_kafka>
-    >>> for i in range(10):
-    ...     source.emit(i)
-    >>> kafka.flush()
-    """
-    def __init__(self, upstream, topic, producer_config, **kwargs):
-        import confluent_kafka as ck
-
-        self.topic = topic
-        self.producer = ck.Producer(producer_config)
-
-        kwargs["ensure_io_loop"] = True
-        Stream.__init__(self, upstream, **kwargs)
-        self.stopped = False
-        self.polltime = 0.2
-        self.loop.add_callback(self.poll)
-        self.futures = []
-
-    @gen.coroutine
-    def poll(self):
-        while not self.stopped:
-            # executes callbacks for any delivered data, in this thread
-            # if no messages were sent, nothing happens
-            self.producer.poll(0)
-            yield gen.sleep(self.polltime)
-
-    def update(self, x, who=None, metadata=None):
-        future = gen.Future()
-        self.futures.append(future)
-
-        @gen.coroutine
-        def _():
-            while True:
-                try:
-                    # this runs asynchronously, in C-K's thread
-                    self.producer.produce(self.topic, x, callback=self.cb)
-                    return
-                except BufferError:
-                    yield gen.sleep(self.polltime)
-                except Exception as e:
-                    future.set_exception(e)
-                    return
-
-        self.loop.add_callback(_)
-        return future
-
-    @gen.coroutine
-    def cb(self, err, msg):
-        future = self.futures.pop(0)
-        if msg is not None and msg.value() is not None:
-            future.set_result(None)
-            yield self._emit(msg.value())
-        else:
-            future.set_exception(err or msg.error())
-
-    def flush(self, timeout=-1):
-        self.producer.flush(timeout)
-
-
 def sync(loop, func, *args, **kwargs):
     """
     Run coroutine in loop running in separate thread.
Index: streamz-0.6.3/streamz/river.py
===================================================================
--- /dev/null
+++ streamz-0.6.3/streamz/river.py
@@ -0,0 +1,62 @@
+from . import Stream
+
+
+# TODO: most River classes also support batch methods (e.g., learn_many), which would be more efficient
+
+
+class RiverTransform(Stream):
+    """Pass data through one or more River transforms"""
+
+    def __init__(self, model, **kwargs):
+        super().__init__(**kwargs)
+        self.model = model
+
+    def update(self, x, who=None, metadata=None):
+        out = self.model.transform_one(*x)
+        self.emit(out, metadata=metadata)
+
+
+class RiverTrain(Stream):
+
+    def __init__(self, model, metric=None, pass_model=False, **kwargs):
+        """
+
+        If metric and pass_model are both left at their defaults, this node
+        is effectively a sink.
+
+        :param model: river model or pipeline
+        :param metric: river metric
+            If given, the updated metric value is emitted for each sample
+        :param pass_model: bool
+            If True, the (updated) model is emitted for each sample
+        """
+        super().__init__(**kwargs)
+        self.model = model
+        if pass_model and metric is not None:
+            raise TypeError("Specify at most one of metric and pass_model")
+        self.pass_model = pass_model
+        self.metric = metric
+
+    def update(self, x, who=None, metadata=None):
+        """
+        :param x: tuple
+            (x[, y[, w]]) for a single sample; include y and w only if the model expects them
+        """
+        self.model.learn_one(*x)
+        if self.metric:
+            yp = self.model.predict_one(x[0])
+            weights = x[2] if len(x) > 2 else 1.0
+            self.emit(self.metric.update(x[1], yp, weights).get(), metadata=metadata)
+        if self.pass_model:
+            self.emit(self.model, metadata=metadata)
+
+
+class RiverPredict(Stream):
+
+    def __init__(self, model, **kwargs):
+        super().__init__(**kwargs)
+        self.model = model
+
+    def update(self, x, who=None, metadata=None):
+        out = self.model.predict_one(x)
+        self.emit(out, metadata=metadata)
Index: streamz-0.6.3/streamz/sinks.py
===================================================================
--- streamz-0.6.3.orig/streamz/sinks.py
+++ streamz-0.6.3/streamz/sinks.py
@@ -73,6 +73,89 @@ class sink(Sink):
 
 
 @Stream.register_api()
+class to_kafka(Stream):
+    """ Writes data in the stream to Kafka
+
+    This stream accepts a string or bytes object. Call ``flush`` to ensure all
+    messages are pushed. Responses from Kafka are pushed downstream.
+
+    Parameters
+    ----------
+    topic : string
+        The topic to write to
+    producer_config : dict
+        Settings to set up the stream, see
+        https://docs.confluent.io/current/clients/confluent-kafka-python/#configuration
+        https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md
+        Examples:
+        bootstrap.servers: Connection string (host:port) to Kafka
+
+    Examples
+    --------
+    >>> from streamz import Stream
+    >>> ARGS = {'bootstrap.servers': 'localhost:9092'}
+    >>> source = Stream()
+    >>> kafka = source.map(lambda x: str(x)).to_kafka('test', ARGS)
+    <to_kafka>
+    >>> for i in range(10):
+    ...     source.emit(i)
+    >>> kafka.flush()
+    """
+    def __init__(self, upstream, topic, producer_config, **kwargs):
+        import confluent_kafka as ck
+
+        self.topic = topic
+        self.producer = ck.Producer(producer_config)
+
+        kwargs["ensure_io_loop"] = True
+        Stream.__init__(self, upstream, **kwargs)
+        self.stopped = False
+        self.polltime = 0.2
+        self.loop.add_callback(self.poll)
+        self.futures = []
+
+    @gen.coroutine
+    def poll(self):
+        while not self.stopped:
+            # executes callbacks for any delivered data, in this thread
+            # if no messages were sent, nothing happens
+            self.producer.poll(0)
+            yield gen.sleep(self.polltime)
+
+    def update(self, x, who=None, metadata=None):
+        future = gen.Future()
+        self.futures.append(future)
+
+        @gen.coroutine
+        def _():
+            while True:
+                try:
+                    # this runs asynchronously, in C-K's thread
+                    self.producer.produce(self.topic, x, callback=self.cb)
+                    return
+                except BufferError:
+                    yield gen.sleep(self.polltime)
+                except Exception as e:
+                    future.set_exception(e)
+                    return
+
+        self.loop.add_callback(_)
+        return future
+
+    @gen.coroutine
+    def cb(self, err, msg):
+        future = self.futures.pop(0)
+        if msg is not None and msg.value() is not None:
+            future.set_result(None)
+            yield self._emit(msg.value())
+        else:
+            future.set_exception(err or msg.error())
+
+    def flush(self, timeout=-1):
+        self.producer.flush(timeout)
+
+
+@Stream.register_api()
 class sink_to_textfile(Sink):
     """ Write elements to a plain text file, one element per line.
 
Index: streamz-0.6.3/streamz/tests/test_dask.py
===================================================================
--- streamz-0.6.3.orig/streamz/tests/test_dask.py
+++ streamz-0.6.3/streamz/tests/test_dask.py
@@ -1,3 +1,4 @@
+import asyncio
 from operator import add
 import random
 import time
@@ -16,21 +17,21 @@ from distributed.utils_test import gen_c
 
 
 @gen_cluster(client=True)
-def test_map(c, s, a, b):
+async def test_map(c, s, a, b):
     source = Stream(asynchronous=True)
     futures = scatter(source).map(inc)
     futures_L = futures.sink_to_list()
     L = futures.gather().sink_to_list()
 
     for i in range(5):
-        yield source.emit(i)
+        await source.emit(i)
 
     assert L == [1, 2, 3, 4, 5]
     assert all(isinstance(f, Future) for f in futures_L)
 
 
 @gen_cluster(client=True)
-def test_map_on_dict(c, s, a, b):
+async def test_map_on_dict(c, s, a, b):
     # dask treats dicts differently, so we have to make sure
     # the user sees no difference in the streamz api.
     # Regression test against #336
@@ -43,7 +44,7 @@ def test_map_on_dict(c, s, a, b):
     L = futures.gather().sink_to_list()
 
     for i in range(5):
-        yield source.emit({"i": i})
+        await source.emit({"i": i})
 
     assert len(L) == 5
     for i, item in enumerate(sorted(L, key=lambda x: x["x"])):
@@ -52,7 +53,7 @@ def test_map_on_dict(c, s, a, b):
 
 
 @gen_cluster(client=True)
-def test_partition_then_scatter_async(c, s, a, b):
+async def test_partition_then_scatter_async(c, s, a, b):
     # Ensure partition w/ timeout before scatter works correctly for
     # asynchronous
     start = time.monotonic()
@@ -63,10 +64,10 @@ def test_partition_then_scatter_async(c,
 
     rc = RefCounter(loop=source.loop)
     for i in range(3):
-        yield source.emit(i, metadata=[{'ref': rc}])
+        await source.emit(i, metadata=[{'ref': rc}])
 
     while rc.count != 0 and time.monotonic() - start < 1.:
-        yield gen.sleep(1e-2)
+        await gen.sleep(1e-2)
 
     assert L == [1, 2, 3]
 
@@ -92,7 +93,7 @@ def test_partition_then_scatter_sync(loo
 
 
 @gen_cluster(client=True)
-def test_non_unique_emit(c, s, a, b):
+async def test_non_unique_emit(c, s, a, b):
     """Regression for https://github.com/python-streamz/streams/issues/397
 
     Non-unique stream entries still need to each be processed.
@@ -103,28 +104,28 @@ def test_non_unique_emit(c, s, a, b):
 
     for _ in range(3):
         # Emit non-unique values
-        yield source.emit(0)
+        await source.emit(0)
 
     assert len(L) == 3
     assert L[0] != L[1] or L[0] != L[2]
 
 
 @gen_cluster(client=True)
-def test_scan(c, s, a, b):
+async def test_scan(c, s, a, b):
     source = Stream(asynchronous=True)
     futures = scatter(source).map(inc).scan(add)
     futures_L = futures.sink_to_list()
     L = futures.gather().sink_to_list()
 
     for i in range(5):
-        yield source.emit(i)
+        await source.emit(i)
 
     assert L == [1, 3, 6, 10, 15]
     assert all(isinstance(f, Future) for f in futures_L)
 
 
 @gen_cluster(client=True)
-def test_scan_state(c, s, a, b):
+async def test_scan_state(c, s, a, b):
     source = Stream(asynchronous=True)
 
     def f(acc, i):
@@ -133,33 +134,33 @@ def test_scan_state(c, s, a, b):
 
     L = scatter(source).scan(f, returns_state=True).gather().sink_to_list()
     for i in range(3):
-        yield source.emit(i)
+        await source.emit(i)
 
     assert L == [0, 1, 3]
 
 
 @gen_cluster(client=True)
-def test_zip(c, s, a, b):
+async def test_zip(c, s, a, b):
     a = Stream(asynchronous=True)
     b = Stream(asynchronous=True)
     c = scatter(a).zip(scatter(b))
 
     L = c.gather().sink_to_list()
 
-    yield a.emit(1)
-    yield b.emit('a')
-    yield a.emit(2)
-    yield b.emit('b')
+    await a.emit(1)
+    await b.emit('a')
+    await a.emit(2)
+    await b.emit('b')
 
     assert L == [(1, 'a'), (2, 'b')]
 
 
 @gen_cluster(client=True)
-def test_accumulate(c, s, a, b):
+async def test_accumulate(c, s, a, b):
     source = Stream(asynchronous=True)
     L = source.scatter().accumulate(lambda acc, x: acc + x, with_state=True).gather().sink_to_list()
     for i in range(3):
-        yield source.emit(i)
+        await source.emit(i)
     assert L[-1][1] == 3
 
 
@@ -169,10 +170,9 @@ def test_sync(loop):  # noqa: F811
             source = Stream()
             L = source.scatter().map(inc).gather().sink_to_list()
 
-            @gen.coroutine
-            def f():
+            async def f():
                 for i in range(10):
-                    yield source.emit(i, asynchronous=True)
+                    await source.emit(i, asynchronous=True)
 
             sync(loop, f)
 
@@ -193,24 +193,24 @@ def test_sync_2(loop):  # noqa: F811
 
 
 @gen_cluster(client=True, nthreads=[('127.0.0.1', 1)] * 2)
-def test_buffer(c, s, a, b):
+async def test_buffer(c, s, a, b):
     source = Stream(asynchronous=True)
     L = source.scatter().map(slowinc, delay=0.5).buffer(5).gather().sink_to_list()
 
     start = time.time()
     for i in range(5):
-        yield source.emit(i)
+        await source.emit(i)
     end = time.time()
     assert end - start < 0.5
 
     for i in range(5, 10):
-        yield source.emit(i)
+        await source.emit(i)
 
     end2 = time.time()
     assert end2 - start > (0.5 / 3)
 
     while len(L) < 10:
-        yield gen.sleep(0.01)
+        await gen.sleep(0.01)
         assert time.time() - start < 5
 
     assert L == list(map(inc, range(10)))
@@ -242,7 +242,7 @@ def test_buffer_sync(loop):  # noqa: F81
 
 
 @pytest.mark.xfail(reason='')
-def test_stream_shares_client_loop(loop):  # noqa: F811
+async def test_stream_shares_client_loop(loop):  # noqa: F811
     with cluster() as (s, [a, b]):
         with Client(s['address'], loop=loop) as client:  # noqa: F841
             source = Stream()
@@ -251,7 +251,7 @@ def test_stream_shares_client_loop(loop)
 
 
 @gen_cluster(client=True)
-def test_starmap(c, s, a, b):
+async def test_starmap(c, s, a, b):
     def add(x, y, z=0):
         return x + y + z
 
@@ -259,6 +259,6 @@ def test_starmap(c, s, a, b):
     L = source.scatter().starmap(add, z=10).gather().sink_to_list()
 
     for i in range(5):
-        yield source.emit((i, i))
+        await source.emit((i, i))
 
     assert L == [10, 12, 14, 16, 18]
Index: streamz-0.6.3/streamz/tests/test_kafka.py
===================================================================
--- streamz-0.6.3.orig/streamz/tests/test_kafka.py
+++ streamz-0.6.3/streamz/tests/test_kafka.py
@@ -1,3 +1,4 @@
+import asyncio
 import atexit
 from contextlib import contextmanager
 from flaky import flaky
@@ -217,7 +218,7 @@ def test_kafka_batch():
 
 
 @gen_cluster(client=True, timeout=60)
-def test_kafka_dask_batch(c, s, w1, w2):
+async def test_kafka_dask_batch(c, s, w1, w2):
     j = random.randint(0, 10000)
     ARGS = {'bootstrap.servers': 'localhost:9092',
             'group.id': 'streamz-test%i' % j}
@@ -227,15 +228,15 @@ def test_kafka_dask_batch(c, s, w1, w2):
                                            asynchronous=True, dask=True)
         out = stream.gather().sink_to_list()
         stream.start()
-        yield gen.sleep(5)  # this frees the loop while dask workers report in
+        await asyncio.sleep(5)  # this frees the loop while dask workers report in
         assert isinstance(stream, DaskStream)
         for i in range(10):
             kafka.produce(TOPIC, b'value-%d' % i)
         kafka.flush()
-        yield await_for(lambda: any(out), 10, period=0.2)
+        await await_for(lambda: any(out), 10, period=0.2)
         assert {'key': None, 'value': b'value-1'} in out[0]
         stream.stop()
-        yield gen.sleep(0)
+        await asyncio.sleep(0)
         stream.upstream.upstream.consumer.close()
 
 
@@ -382,7 +383,7 @@ def test_kafka_batch_checkpointing_sync_
 
 
 @gen_cluster(client=True, timeout=60)
-def test_kafka_dask_checkpointing_sync_nodes(c, s, w1, w2):
+async def test_kafka_dask_checkpointing_sync_nodes(c, s, w1, w2):
     '''
     Testing whether Dask's scatter and gather works in conformity with
     the reference counting checkpointing implementation.
@@ -403,23 +404,23 @@ def test_kafka_dask_checkpointing_sync_n
             kafka.produce(TOPIC, b'value-%d' % i)
         kafka.flush()
         stream1 = Stream.from_kafka_batched(TOPIC, ARGS1, asynchronous=True,
-                                           dask=True)
+                                            dask=True)
         out1 = stream1.map(split).gather().filter(lambda x: x[-1] % 2 == 1).sink_to_list()
         stream1.start()
-        yield await_for(lambda: any(out1) and out1[-1][-1] == 9, 10, period=0.2)
+        await await_for(lambda: any(out1) and out1[-1][-1] == 9, 10, period=0.2)
         stream1.upstream.stopped = True
         stream2 = Stream.from_kafka_batched(TOPIC, ARGS1, asynchronous=True,
-                                           dask=True)
+                                            dask=True)
         out2 = stream2.map(split).gather().filter(lambda x: x[-1] % 2 == 1).sink_to_list()
         stream2.start()
         time.sleep(5)
         assert len(out2) == 0
         stream2.upstream.stopped = True
         stream3 = Stream.from_kafka_batched(TOPIC, ARGS2, asynchronous=True,
-                                           dask=True)
+                                            dask=True)
         out3 = stream3.map(split).gather().filter(lambda x: x[-1] % 2 == 1).sink_to_list()
         stream3.start()
-        yield await_for(lambda: any(out3) and out3[-1][-1] == 9, 10, period=0.2)
+        await await_for(lambda: any(out3) and out3[-1][-1] == 9, 10, period=0.2)
         stream3.upstream.stopped = True