Blob Blame History Raw
From: Boris Pismenny <borisp@mellanox.com>
Date: Fri, 13 Jul 2018 14:33:43 +0300
Subject: tls: Add rx inline crypto offload
Patch-mainline: v4.19-rc1
Git-commit: 4799ac81e52a72a6404827bf2738337bb581a174
References: bsc#1109837

This patch completes the generic infrastructure to offload TLS crypto to a
network device. It enables the kernel to skip decryption and
authentication of some skbs marked as decrypted by the NIC. In the fast
path, all packets received are decrypted by the NIC and the performance
is comparable to plain TCP.

This infrastructure doesn't require a TCP offload engine. Instead, the
NIC only decrypts packets that contain the expected TCP sequence number.
Out-Of-Order TCP packets are provided unmodified. As a result, at the
worst case a received TLS record consists of both plaintext and ciphertext
packets. These partially decrypted records must be reencrypted,
only to be decrypted.

The notable differences between SW KTLS Rx and this offload are as
follows:
1. Partial decryption - Software must handle the case of a TLS record
that was only partially decrypted by HW. This can happen due to packet
reordering.
2. Resynchronization - tls_read_size calls the device driver to
resynchronize HW after HW lost track of TLS record framing in
the TCP stream.

Signed-off-by: Boris Pismenny <borisp@mellanox.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 include/net/tls.h             |   63 ++++++++-
 net/tls/tls_device.c          |  278 ++++++++++++++++++++++++++++++++++++++----
 net/tls/tls_device_fallback.c |    1 
 net/tls/tls_main.c            |   32 +++-
 net/tls/tls_sw.c              |   24 ++-
 5 files changed, 355 insertions(+), 43 deletions(-)

--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -83,6 +83,16 @@ struct tls_device {
 	void (*unhash)(struct tls_device *device, struct sock *sk);
 };
 
+enum {
+	TLS_BASE,
+	TLS_SW,
+#ifdef CONFIG_TLS_DEVICE
+	TLS_HW,
+#endif
+	TLS_HW_RECORD,
+	TLS_NUM_CONFIG,
+};
+
 struct tls_sw_context_tx {
 	struct crypto_aead *aead_send;
 	struct crypto_wait async_wait;
@@ -197,6 +207,7 @@ struct tls_context {
 	int (*push_pending_record)(struct sock *sk, int flags);
 
 	void (*sk_write_space)(struct sock *sk);
+	void (*sk_destruct)(struct sock *sk);
 	void (*sk_proto_close)(struct sock *sk, long timeout);
 
 	int  (*setsockopt)(struct sock *sk, int level,
@@ -209,13 +220,27 @@ struct tls_context {
 	void (*unhash)(struct sock *sk);
 };
 
+struct tls_offload_context_rx {
+	/* sw must be the first member of tls_offload_context_rx */
+	struct tls_sw_context_rx sw;
+	atomic64_t resync_req;
+	u8 driver_state[];
+	/* The TLS layer reserves room for driver specific state
+	 * Currently the belief is that there is not enough
+	 * driver specific state to justify another layer of indirection
+	 */
+};
+
+#define TLS_OFFLOAD_CONTEXT_SIZE_RX					\
+	(ALIGN(sizeof(struct tls_offload_context_rx), sizeof(void *)) + \
+	 TLS_DRIVER_STATE_SIZE)
+
 int wait_on_pending_writer(struct sock *sk, long *timeo);
 int tls_sk_query(struct sock *sk, int optname, char __user *optval,
 		int __user *optlen);
 int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 		  unsigned int optlen);
 
-
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_sw_sendpage(struct sock *sk, struct page *page,
@@ -290,11 +315,19 @@ static inline bool tls_is_pending_open_r
 	return tls_ctx->pending_open_record_frags;
 }
 
+struct sk_buff *
+tls_validate_xmit_skb(struct sock *sk, struct net_device *dev,
+		      struct sk_buff *skb);
+
 static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
 {
-	return sk_fullsock(sk) &&
-	       /* matches smp_store_release in tls_set_device_offload */
-	       smp_load_acquire(&sk->sk_destruct) == &tls_device_sk_destruct;
+#ifdef CONFIG_SOCK_VALIDATE_XMIT
+	return sk_fullsock(sk) &
+	       (smp_load_acquire(&sk->sk_validate_xmit_skb) ==
+	       &tls_validate_xmit_skb);
+#else
+	return false;
+#endif
 }
 
 static inline void tls_err_abort(struct sock *sk, int err)
@@ -387,10 +420,27 @@ tls_offload_ctx_tx(const struct tls_cont
 	return (struct tls_offload_context_tx *)tls_ctx->priv_ctx_tx;
 }
 
+static inline struct tls_offload_context_rx *
+tls_offload_ctx_rx(const struct tls_context *tls_ctx)
+{
+	return (struct tls_offload_context_rx *)tls_ctx->priv_ctx_rx;
+}
+
+/* The TLS context is valid until sk_destruct is called */
+static inline void tls_offload_rx_resync_request(struct sock *sk, __be32 seq)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
+
+	atomic64_set(&rx_ctx->resync_req, ((((uint64_t)seq) << 32) | 1));
+}
+
+
 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
 		      unsigned char *record_type);
 void tls_register_device(struct tls_device *device);
 void tls_unregister_device(struct tls_device *device);
+int tls_device_decrypted(struct sock *sk, struct sk_buff *skb);
 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
 		struct scatterlist *sgout);
 
@@ -402,4 +452,9 @@ int tls_sw_fallback_init(struct sock *sk
 			 struct tls_offload_context_tx *offload_ctx,
 			 struct tls_crypto_info *crypto_info);
 
+int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
+
+void tls_device_offload_cleanup_rx(struct sock *sk);
+void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn);
+
 #endif /* _TLS_OFFLOAD_H */
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -52,7 +52,11 @@ static DEFINE_SPINLOCK(tls_device_lock);
 
 static void tls_device_free_ctx(struct tls_context *ctx)
 {
-	kfree(tls_offload_ctx_tx(ctx));
+	if (ctx->tx_conf == TLS_HW)
+		kfree(tls_offload_ctx_tx(ctx));
+
+	if (ctx->rx_conf == TLS_HW)
+		kfree(tls_offload_ctx_rx(ctx));
 
 	kfree(ctx);
 }
@@ -70,10 +74,11 @@ static void tls_device_gc_task(struct wo
 	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
 		struct net_device *netdev = ctx->netdev;
 
-		if (netdev) {
+		if (netdev && ctx->tx_conf == TLS_HW) {
 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
 							TLS_OFFLOAD_CTX_DIR_TX);
 			dev_put(netdev);
+			ctx->netdev = NULL;
 		}
 
 		list_del(&ctx->list);
@@ -81,6 +86,22 @@ static void tls_device_gc_task(struct wo
 	}
 }
 
+static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
+			      struct net_device *netdev)
+{
+	if (sk->sk_destruct != tls_device_sk_destruct) {
+		refcount_set(&ctx->refcount, 1);
+		dev_hold(netdev);
+		ctx->netdev = netdev;
+		spin_lock_irq(&tls_device_lock);
+		list_add_tail(&ctx->list, &tls_device_list);
+		spin_unlock_irq(&tls_device_lock);
+
+		ctx->sk_destruct = sk->sk_destruct;
+		sk->sk_destruct = tls_device_sk_destruct;
+	}
+}
+
 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
 {
 	unsigned long flags;
@@ -180,13 +201,15 @@ void tls_device_sk_destruct(struct sock
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
 
-	if (ctx->open_record)
-		destroy_record(ctx->open_record);
+	tls_ctx->sk_destruct(sk);
 
-	delete_all_records(ctx);
-	crypto_free_aead(ctx->aead_send);
-	ctx->sk_destruct(sk);
-	clean_acked_data_disable(inet_csk(sk));
+	if (tls_ctx->tx_conf == TLS_HW) {
+		if (ctx->open_record)
+			destroy_record(ctx->open_record);
+		delete_all_records(ctx);
+		crypto_free_aead(ctx->aead_send);
+		clean_acked_data_disable(inet_csk(sk));
+	}
 
 	if (refcount_dec_and_test(&tls_ctx->refcount))
 		tls_device_queue_ctx_destruction(tls_ctx);
@@ -519,6 +542,118 @@ static int tls_device_push_pending_recor
 	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
 }
 
+void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct net_device *netdev = tls_ctx->netdev;
+	struct tls_offload_context_rx *rx_ctx;
+	u32 is_req_pending;
+	s64 resync_req;
+	u32 req_seq;
+
+	if (tls_ctx->rx_conf != TLS_HW)
+		return;
+
+	rx_ctx = tls_offload_ctx_rx(tls_ctx);
+	resync_req = atomic64_read(&rx_ctx->resync_req);
+	req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
+	is_req_pending = resync_req;
+
+	if (unlikely(is_req_pending) && req_seq == seq &&
+	    atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
+		netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk,
+						      seq + TLS_HEADER_SIZE - 1,
+						      rcd_sn);
+}
+
+static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
+{
+	struct strp_msg *rxm = strp_msg(skb);
+	int err = 0, offset = rxm->offset, copy, nsg;
+	struct sk_buff *skb_iter, *unused;
+	struct scatterlist sg[1];
+	char *orig_buf, *buf;
+
+	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
+			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
+	if (!orig_buf)
+		return -ENOMEM;
+	buf = orig_buf;
+
+	nsg = skb_cow_data(skb, 0, &unused);
+	if (unlikely(nsg < 0)) {
+		err = nsg;
+		goto free_buf;
+	}
+
+	sg_init_table(sg, 1);
+	sg_set_buf(&sg[0], buf,
+		   rxm->full_len + TLS_HEADER_SIZE +
+		   TLS_CIPHER_AES_GCM_128_IV_SIZE);
+	skb_copy_bits(skb, offset, buf,
+		      TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+
+	/* We are interested only in the decrypted data not the auth */
+	err = decrypt_skb(sk, skb, sg);
+	if (err != -EBADMSG)
+		goto free_buf;
+	else
+		err = 0;
+
+	copy = min_t(int, skb_pagelen(skb) - offset,
+		     rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+
+	if (skb->decrypted)
+		skb_store_bits(skb, offset, buf, copy);
+
+	offset += copy;
+	buf += copy;
+
+	skb_walk_frags(skb, skb_iter) {
+		copy = min_t(int, skb_iter->len,
+			     rxm->full_len - offset + rxm->offset -
+			     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+
+		if (skb_iter->decrypted)
+			skb_store_bits(skb, offset, buf, copy);
+
+		offset += copy;
+		buf += copy;
+	}
+
+free_buf:
+	kfree(orig_buf);
+	return err;
+}
+
+int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
+	int is_decrypted = skb->decrypted;
+	int is_encrypted = !is_decrypted;
+	struct sk_buff *skb_iter;
+
+	/* Skip if it is already decrypted */
+	if (ctx->sw.decrypted)
+		return 0;
+
+	/* Check if all the data is decrypted already */
+	skb_walk_frags(skb, skb_iter) {
+		is_decrypted &= skb_iter->decrypted;
+		is_encrypted &= !skb_iter->decrypted;
+	}
+
+	ctx->sw.decrypted |= is_decrypted;
+
+	/* Return immedeatly if the record is either entirely plaintext or
+	 * entirely ciphertext. Otherwise handle reencrypt partially decrypted
+	 * record.
+	 */
+	return (is_encrypted || is_decrypted) ? 0 :
+		tls_device_reencrypt(sk, skb);
+}
+
 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 {
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
@@ -608,7 +743,6 @@ int tls_set_device_offload(struct sock *
 
 	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
 	ctx->push_pending_record = tls_device_push_pending_record;
-	offload_ctx->sk_destruct = sk->sk_destruct;
 
 	/* TLS offload is greatly simplified if we don't send
 	 * SKBs where only part of the payload needs to be encrypted.
@@ -618,8 +752,6 @@ int tls_set_device_offload(struct sock *
 	if (skb)
 		TCP_SKB_CB(skb)->eor = 1;
 
-	refcount_set(&ctx->refcount, 1);
-
 	/* We support starting offload on multiple sockets
 	 * concurrently, so we only need a read lock here.
 	 * This lock must precede get_netdev_for_sock to prevent races between
@@ -654,19 +786,14 @@ int tls_set_device_offload(struct sock *
 	if (rc)
 		goto release_netdev;
 
-	ctx->netdev = netdev;
+	tls_device_attach(ctx, sk, netdev);
 
-	spin_lock_irq(&tls_device_lock);
-	list_add_tail(&ctx->list, &tls_device_list);
-	spin_unlock_irq(&tls_device_lock);
-
-	sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
 	/* following this assignment tls_is_sk_tx_device_offloaded
 	 * will return true and the context might be accessed
 	 * by the netdev's xmit function.
 	 */
-	smp_store_release(&sk->sk_destruct,
-			  &tls_device_sk_destruct);
+	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
+	dev_put(netdev);
 	up_read(&device_offload_lock);
 	goto out;
 
@@ -689,6 +816,105 @@ out:
 	return rc;
 }
 
+int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
+{
+	struct tls_offload_context_rx *context;
+	struct net_device *netdev;
+	int rc = 0;
+
+	/* We support starting offload on multiple sockets
+	 * concurrently, so we only need a read lock here.
+	 * This lock must precede get_netdev_for_sock to prevent races between
+	 * NETDEV_DOWN and setsockopt.
+	 */
+	down_read(&device_offload_lock);
+	netdev = get_netdev_for_sock(sk);
+	if (!netdev) {
+		pr_err_ratelimited("%s: netdev not found\n", __func__);
+		rc = -EINVAL;
+		goto release_lock;
+	}
+
+	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
+		pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
+				   __func__, netdev->name);
+		rc = -ENOTSUPP;
+		goto release_netdev;
+	}
+
+	/* Avoid offloading if the device is down
+	 * We don't want to offload new flows after
+	 * the NETDEV_DOWN event
+	 */
+	if (!(netdev->flags & IFF_UP)) {
+		rc = -EINVAL;
+		goto release_netdev;
+	}
+
+	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
+	if (!context) {
+		rc = -ENOMEM;
+		goto release_netdev;
+	}
+
+	ctx->priv_ctx_rx = context;
+	rc = tls_set_sw_offload(sk, ctx, 0);
+	if (rc)
+		goto release_ctx;
+
+	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
+					     &ctx->crypto_recv,
+					     tcp_sk(sk)->copied_seq);
+	if (rc) {
+		pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
+				   __func__);
+		goto free_sw_resources;
+	}
+
+	tls_device_attach(ctx, sk, netdev);
+	goto release_netdev;
+
+free_sw_resources:
+	tls_sw_free_resources_rx(sk);
+release_ctx:
+	ctx->priv_ctx_rx = NULL;
+release_netdev:
+	dev_put(netdev);
+release_lock:
+	up_read(&device_offload_lock);
+	return rc;
+}
+
+void tls_device_offload_cleanup_rx(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct net_device *netdev;
+
+	down_read(&device_offload_lock);
+	netdev = tls_ctx->netdev;
+	if (!netdev)
+		goto out;
+
+	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
+		pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n",
+				   __func__);
+		goto out;
+	}
+
+	netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
+					TLS_OFFLOAD_CTX_DIR_RX);
+
+	if (tls_ctx->tx_conf != TLS_HW) {
+		dev_put(netdev);
+		tls_ctx->netdev = NULL;
+	}
+out:
+	up_read(&device_offload_lock);
+	kfree(tls_ctx->rx.rec_seq);
+	kfree(tls_ctx->rx.iv);
+	tls_sw_release_resources_rx(sk);
+}
+
 static int tls_device_down(struct net_device *netdev)
 {
 	struct tls_context *ctx, *tmp;
@@ -709,8 +935,12 @@ static int tls_device_down(struct net_de
 	spin_unlock_irqrestore(&tls_device_lock, flags);
 
 	list_for_each_entry_safe(ctx, tmp, &list, list)	{
-		netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
-						TLS_OFFLOAD_CTX_DIR_TX);
+		if (ctx->tx_conf == TLS_HW)
+			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+							TLS_OFFLOAD_CTX_DIR_TX);
+		if (ctx->rx_conf == TLS_HW)
+			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+							TLS_OFFLOAD_CTX_DIR_RX);
 		ctx->netdev = NULL;
 		dev_put(netdev);
 		list_del_init(&ctx->list);
@@ -731,12 +961,16 @@ static int tls_dev_event(struct notifier
 {
 	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
 
-	if (!(dev->features & NETIF_F_HW_TLS_TX))
+	if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
 		return NOTIFY_DONE;
 
 	switch (event) {
 	case NETDEV_REGISTER:
 	case NETDEV_FEAT_CHANGE:
+		if ((dev->features & NETIF_F_HW_TLS_RX) &&
+		    !dev->tlsdev_ops->tls_dev_resync_rx)
+			return NOTIFY_BAD;
+
 		if  (dev->tlsdev_ops &&
 		     dev->tlsdev_ops->tls_dev_add &&
 		     dev->tlsdev_ops->tls_dev_del)
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -413,6 +413,7 @@ struct sk_buff *tls_validate_xmit_skb(st
 
 	return tls_sw_fallback(sk, skb);
 }
+EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);
 
 int tls_sw_fallback_init(struct sock *sk,
 			 struct tls_offload_context_tx *offload_ctx,
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -51,15 +51,6 @@ enum {
 	TLSV6,
 	TLS_NUM_PROTS,
 };
-enum {
-	TLS_BASE,
-	TLS_SW,
-#ifdef CONFIG_TLS_DEVICE
-	TLS_HW,
-#endif
-	TLS_HW_RECORD,
-	TLS_NUM_CONFIG,
-};
 
 static struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
@@ -290,7 +281,10 @@ static void tls_sk_proto_close(struct so
 	}
 
 #ifdef CONFIG_TLS_DEVICE
-	if (ctx->tx_conf != TLS_HW) {
+	if (ctx->rx_conf == TLS_HW)
+		tls_device_offload_cleanup_rx(sk);
+
+	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
 #else
 	{
 #endif
@@ -470,8 +464,16 @@ static int do_tls_setsockopt_conf(struct
 			conf = TLS_SW;
 		}
 	} else {
-		rc = tls_set_sw_offload(sk, ctx, 0);
-		conf = TLS_SW;
+#ifdef CONFIG_TLS_DEVICE
+		rc = tls_set_device_offload_rx(sk, ctx);
+		conf = TLS_HW;
+		if (rc) {
+#else
+		{
+#endif
+			rc = tls_set_sw_offload(sk, ctx, 0);
+			conf = TLS_SW;
+		}
 	}
 
 	if (rc)
@@ -629,6 +631,12 @@ static void build_protos(struct proto pr
 	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
 	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
 	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;
+
+	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+
+	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+
+	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 #endif
 
 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -657,16 +657,25 @@ static struct sk_buff *tls_wait_data(str
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct scatterlist *sgout)
+			      struct scatterlist *sgout, bool *zc)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct strp_msg *rxm = strp_msg(skb);
 	int err = 0;
 
-	err = decrypt_skb(sk, skb, sgout);
+#ifdef CONFIG_TLS_DEVICE
+	err = tls_device_decrypted(sk, skb);
 	if (err < 0)
 		return err;
+#endif
+	if (!ctx->decrypted) {
+		err = decrypt_skb(sk, skb, sgout);
+		if (err < 0)
+			return err;
+	} else {
+		*zc = false;
+	}
 
 	rxm->offset += tls_ctx->rx.prepend_size;
 	rxm->full_len -= tls_ctx->rx.overhead_size;
@@ -828,7 +837,7 @@ int tls_sw_recvmsg(struct sock *sk,
 				if (err < 0)
 					goto fallback_to_reg_recv;
 
-				err = decrypt_skb_update(sk, skb, sgin);
+				err = decrypt_skb_update(sk, skb, sgin, &zc);
 				for (; pages > 0; pages--)
 					put_page(sg_page(&sgin[pages]));
 				if (err < 0) {
@@ -837,7 +846,7 @@ int tls_sw_recvmsg(struct sock *sk,
 				}
 			} else {
 fallback_to_reg_recv:
-				err = decrypt_skb_update(sk, skb, NULL);
+				err = decrypt_skb_update(sk, skb, NULL, &zc);
 				if (err < 0) {
 					tls_err_abort(sk, EBADMSG);
 					goto recv_end;
@@ -892,6 +901,7 @@ ssize_t tls_sw_splice_read(struct socket
 	int err = 0;
 	long timeo;
 	int chunk;
+	bool zc;
 
 	lock_sock(sk);
 
@@ -908,7 +918,7 @@ ssize_t tls_sw_splice_read(struct socket
 	}
 
 	if (!ctx->decrypted) {
-		err = decrypt_skb_update(sk, skb, NULL);
+		err = decrypt_skb_update(sk, skb, NULL, &zc);
 
 		if (err < 0) {
 			tls_err_abort(sk, EBADMSG);
@@ -997,6 +1007,10 @@ static int tls_read_size(struct strparse
 		goto read_failure;
 	}
 
+#ifdef CONFIG_TLS_DEVICE
+	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
+			     *(u64*)tls_ctx->rx.rec_seq);
+#endif
 	return data_len + TLS_HEADER_SIZE;
 
 read_failure: