Blob Blame History Raw
From: Arseniy Krasnov <avkrasnov@salutedevices.com>
Date: Sat, 16 Sep 2023 16:09:18 +0300
Subject: vsock/virtio: MSG_ZEROCOPY flag support
Patch-mainline: v6.7-rc1
Git-commit: 581512a6dc939ef122e49336626ae159f3b8a345
References: jsc#PED-5505

This adds handling of MSG_ZEROCOPY flag on transmission path:

1) If this flag is set and zerocopy transmission is possible (enabled
   in socket options and transport allows zerocopy), then non-linear
   skb will be created and filled with the pages of user's buffer.
   Pages of user's buffer are locked in memory by 'get_user_pages()'.
2) Replaces way of skb owning: instead of 'skb_set_owner_sk_safe()' it
   calls 'skb_set_owner_w()'. Reason of this change is that
   '__zerocopy_sg_from_iter()' increments 'sk_wmem_alloc' of socket, so
   to decrease this field correctly, proper skb destructor is needed:
   'sock_wfree()'. This destructor is set by 'skb_set_owner_w()'.
3) Adds new callback to 'struct virtio_transport': 'can_msgzerocopy'.
   If this callback is set, then transport needs extra check to be able
   to send provided number of buffers in zerocopy mode. Currently, the
   only transport that needs this callback set is virtio, because this
   transport adds new buffers to the virtio queue and we need to check,
   that number of these buffers is less than size of the queue (it is
   required by virtio spec). vhost and loopback transports don't need
   this check.

Signed-off-by: Arseniy Krasnov <avkrasnov@salutedevices.com>
Reviewed-by: Stefano Garzarella <sgarzare@redhat.com>
Acked-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Acked-by: Thomas Bogendoerfer <tbogendoerfer@suse.de>
---
 include/linux/virtio_vsock.h                         |    9 
 include/trace/events/vsock_virtio_transport_common.h |   12 
 net/vmw_vsock/virtio_transport.c                     |   32 ++
 net/vmw_vsock/virtio_transport_common.c              |  248 ++++++++++++++-----
 4 files changed, 240 insertions(+), 61 deletions(-)

--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -160,6 +160,15 @@ struct virtio_transport {
 
 	/* Takes ownership of the packet */
 	int (*send_pkt)(struct sk_buff *skb);
+
+	/* Used in MSG_ZEROCOPY mode. Checks, that provided data
+	 * (number of buffers) could be transmitted with zerocopy
+	 * mode. If this callback is not implemented for the current
+	 * transport - this means that this transport doesn't need
+	 * extra checks and can perform zerocopy transmission by
+	 * default.
+	 */
+	bool (*can_msgzerocopy)(int bufs_num);
 };
 
 ssize_t
--- a/include/trace/events/vsock_virtio_transport_common.h
+++ b/include/trace/events/vsock_virtio_transport_common.h
@@ -43,7 +43,8 @@ TRACE_EVENT(virtio_transport_alloc_pkt,
 		 __u32 len,
 		 __u16 type,
 		 __u16 op,
-		 __u32 flags
+		 __u32 flags,
+		 bool zcopy
 	),
 	TP_ARGS(
 		src_cid, src_port,
@@ -51,7 +52,8 @@ TRACE_EVENT(virtio_transport_alloc_pkt,
 		len,
 		type,
 		op,
-		flags
+		flags,
+		zcopy
 	),
 	TP_STRUCT__entry(
 		__field(__u32, src_cid)
@@ -62,6 +64,7 @@ TRACE_EVENT(virtio_transport_alloc_pkt,
 		__field(__u16, type)
 		__field(__u16, op)
 		__field(__u32, flags)
+		__field(bool, zcopy)
 	),
 	TP_fast_assign(
 		__entry->src_cid = src_cid;
@@ -72,14 +75,15 @@ TRACE_EVENT(virtio_transport_alloc_pkt,
 		__entry->type = type;
 		__entry->op = op;
 		__entry->flags = flags;
+		__entry->zcopy = zcopy;
 	),
-	TP_printk("%u:%u -> %u:%u len=%u type=%s op=%s flags=%#x",
+	TP_printk("%u:%u -> %u:%u len=%u type=%s op=%s flags=%#x zcopy=%s",
 		  __entry->src_cid, __entry->src_port,
 		  __entry->dst_cid, __entry->dst_port,
 		  __entry->len,
 		  show_type(__entry->type),
 		  show_op(__entry->op),
-		  __entry->flags)
+		  __entry->flags, __entry->zcopy ? "true" : "false")
 );
 
 TRACE_EVENT(virtio_transport_recv_pkt,
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -455,6 +455,37 @@ static void virtio_vsock_rx_done(struct
 	queue_work(virtio_vsock_workqueue, &vsock->rx_work);
 }
 
+static bool virtio_transport_can_msgzerocopy(int bufs_num)
+{
+	struct virtio_vsock *vsock;
+	bool res = false;
+
+	rcu_read_lock();
+
+	vsock = rcu_dereference(the_virtio_vsock);
+	if (vsock) {
+		struct virtqueue *vq = vsock->vqs[VSOCK_VQ_TX];
+
+		/* Check that tx queue is large enough to keep whole
+		 * data to send. This is needed, because when there is
+		 * not enough free space in the queue, current skb to
+		 * send will be reinserted to the head of tx list of
+		 * the socket to retry transmission later, so if skb
+		 * is bigger than whole queue, it will be reinserted
+		 * again and again, thus blocking other skbs to be sent.
+		 * Each page of the user provided buffer will be added
+		 * as a single buffer to the tx virtqueue, so compare
+		 * number of pages against maximum capacity of the queue.
+		 */
+		if (bufs_num <= vq->num_max)
+			res = true;
+	}
+
+	rcu_read_unlock();
+
+	return res;
+}
+
 static bool virtio_transport_seqpacket_allow(u32 remote_cid);
 
 static struct virtio_transport virtio_transport = {
@@ -504,6 +535,7 @@ static struct virtio_transport virtio_tr
 	},
 
 	.send_pkt = virtio_transport_send_pkt,
+	.can_msgzerocopy = virtio_transport_can_msgzerocopy,
 };
 
 static bool virtio_transport_seqpacket_allow(u32 remote_cid)
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -37,73 +37,99 @@ virtio_transport_get_ops(struct vsock_so
 	return container_of(t, struct virtio_transport, transport);
 }
 
-/* Returns a new packet on success, otherwise returns NULL.
- *
- * If NULL is returned, errp is set to a negative errno.
- */
-static struct sk_buff *
-virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
-			   size_t len,
-			   u32 src_cid,
-			   u32 src_port,
-			   u32 dst_cid,
-			   u32 dst_port)
+static bool virtio_transport_can_zcopy(const struct virtio_transport *t_ops,
+				       struct virtio_vsock_pkt_info *info,
+				       size_t pkt_len)
 {
-	const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM + len;
-	struct virtio_vsock_hdr *hdr;
-	struct sk_buff *skb;
-	void *payload;
-	int err;
+	struct iov_iter *iov_iter;
 
-	skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL);
-	if (!skb)
-		return NULL;
+	if (!info->msg)
+		return false;
 
-	hdr = virtio_vsock_hdr(skb);
-	hdr->type	= cpu_to_le16(info->type);
-	hdr->op		= cpu_to_le16(info->op);
-	hdr->src_cid	= cpu_to_le64(src_cid);
-	hdr->dst_cid	= cpu_to_le64(dst_cid);
-	hdr->src_port	= cpu_to_le32(src_port);
-	hdr->dst_port	= cpu_to_le32(dst_port);
-	hdr->flags	= cpu_to_le32(info->flags);
-	hdr->len	= cpu_to_le32(len);
+	iov_iter = &info->msg->msg_iter;
 
-	if (info->msg && len > 0) {
-		payload = skb_put(skb, len);
-		err = memcpy_from_msg(payload, info->msg, len);
-		if (err)
-			goto out;
+	if (iov_iter->iov_offset)
+		return false;
 
-		if (msg_data_left(info->msg) == 0 &&
-		    info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
-			hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
+	/* We can't send whole iov. */
+	if (iov_iter->count > pkt_len)
+		return false;
 
-			if (info->msg->msg_flags & MSG_EOR)
-				hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
-		}
+	/* Check that transport can send data in zerocopy mode. */
+	t_ops = virtio_transport_get_ops(info->vsk);
+
+	if (t_ops->can_msgzerocopy) {
+		int pages_in_iov = iov_iter_npages(iov_iter, MAX_SKB_FRAGS);
+		int pages_to_send = min(pages_in_iov, MAX_SKB_FRAGS);
+
+		/* +1 is for packet header. */
+		return t_ops->can_msgzerocopy(pages_to_send + 1);
 	}
 
-	if (info->reply)
-		virtio_vsock_skb_set_reply(skb);
+	return true;
+}
 
-	trace_virtio_transport_alloc_pkt(src_cid, src_port,
-					 dst_cid, dst_port,
-					 len,
-					 info->type,
-					 info->op,
-					 info->flags);
+static int virtio_transport_init_zcopy_skb(struct vsock_sock *vsk,
+					   struct sk_buff *skb,
+					   struct msghdr *msg,
+					   bool zerocopy)
+{
+	struct ubuf_info *uarg;
+
+	if (msg->msg_ubuf) {
+		uarg = msg->msg_ubuf;
+		net_zcopy_get(uarg);
+	} else {
+		struct iov_iter *iter = &msg->msg_iter;
+		struct ubuf_info_msgzc *uarg_zc;
+
+		uarg = msg_zerocopy_realloc(sk_vsock(vsk),
+					    iter->count,
+					    NULL);
+		if (!uarg)
+			return -1;
 
-	if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) {
-		WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n");
-		goto out;
+		uarg_zc = uarg_to_msgzc(uarg);
+		uarg_zc->zerocopy = zerocopy ? 1 : 0;
 	}
 
-	return skb;
+	skb_zcopy_init(skb, uarg);
 
-out:
-	kfree_skb(skb);
-	return NULL;
+	return 0;
+}
+
+static int virtio_transport_fill_skb(struct sk_buff *skb,
+				     struct virtio_vsock_pkt_info *info,
+				     size_t len,
+				     bool zcopy)
+{
+	if (zcopy)
+		return __zerocopy_sg_from_iter(info->msg, NULL, skb,
+					       &info->msg->msg_iter,
+					       len);
+
+	return memcpy_from_msg(skb_put(skb, len), info->msg, len);
+}
+
+static void virtio_transport_init_hdr(struct sk_buff *skb,
+				      struct virtio_vsock_pkt_info *info,
+				      size_t payload_len,
+				      u32 src_cid,
+				      u32 src_port,
+				      u32 dst_cid,
+				      u32 dst_port)
+{
+	struct virtio_vsock_hdr *hdr;
+
+	hdr = virtio_vsock_hdr(skb);
+	hdr->type	= cpu_to_le16(info->type);
+	hdr->op		= cpu_to_le16(info->op);
+	hdr->src_cid	= cpu_to_le64(src_cid);
+	hdr->dst_cid	= cpu_to_le64(dst_cid);
+	hdr->src_port	= cpu_to_le32(src_port);
+	hdr->dst_port	= cpu_to_le32(dst_port);
+	hdr->flags	= cpu_to_le32(info->flags);
+	hdr->len	= cpu_to_le32(payload_len);
 }
 
 static void virtio_transport_copy_nonlinear_skb(const struct sk_buff *skb,
@@ -214,6 +240,82 @@ static u16 virtio_transport_get_type(str
 		return VIRTIO_VSOCK_TYPE_SEQPACKET;
 }
 
+/* Returns new sk_buff on success, otherwise returns NULL. */
+static struct sk_buff *virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
+						  size_t payload_len,
+						  bool zcopy,
+						  u32 src_cid,
+						  u32 src_port,
+						  u32 dst_cid,
+						  u32 dst_port)
+{
+	struct vsock_sock *vsk;
+	struct sk_buff *skb;
+	size_t skb_len;
+
+	skb_len = VIRTIO_VSOCK_SKB_HEADROOM;
+
+	if (!zcopy)
+		skb_len += payload_len;
+
+	skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL);
+	if (!skb)
+		return NULL;
+
+	virtio_transport_init_hdr(skb, info, payload_len, src_cid, src_port,
+				  dst_cid, dst_port);
+
+	vsk = info->vsk;
+
+	/* If 'vsk' != NULL then payload is always present, so we
+	 * will never call '__zerocopy_sg_from_iter()' below without
+	 * setting skb owner in 'skb_set_owner_w()'. The only case
+	 * when 'vsk' == NULL is VIRTIO_VSOCK_OP_RST control message
+	 * without payload.
+	 */
+	WARN_ON_ONCE(!(vsk && (info->msg && payload_len)) && zcopy);
+
+	/* Set owner here, because '__zerocopy_sg_from_iter()' uses
+	 * owner of skb without check to update 'sk_wmem_alloc'.
+	 */
+	if (vsk)
+		skb_set_owner_w(skb, sk_vsock(vsk));
+
+	if (info->msg && payload_len > 0) {
+		int err;
+
+		err = virtio_transport_fill_skb(skb, info, payload_len, zcopy);
+		if (err)
+			goto out;
+
+		if (msg_data_left(info->msg) == 0 &&
+		    info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
+			struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
+
+			hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
+
+			if (info->msg->msg_flags & MSG_EOR)
+				hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
+		}
+	}
+
+	if (info->reply)
+		virtio_vsock_skb_set_reply(skb);
+
+	trace_virtio_transport_alloc_pkt(src_cid, src_port,
+					 dst_cid, dst_port,
+					 payload_len,
+					 info->type,
+					 info->op,
+					 info->flags,
+					 zcopy);
+
+	return skb;
+out:
+	kfree_skb(skb);
+	return NULL;
+}
+
 /* This function can only be used on connecting/connected sockets,
  * since a socket assigned to a transport is required.
  *
@@ -222,10 +324,12 @@ static u16 virtio_transport_get_type(str
 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
 					  struct virtio_vsock_pkt_info *info)
 {
+	u32 max_skb_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
 	u32 src_cid, src_port, dst_cid, dst_port;
 	const struct virtio_transport *t_ops;
 	struct virtio_vsock_sock *vvs;
 	u32 pkt_len = info->pkt_len;
+	bool can_zcopy = false;
 	u32 rest_len;
 	int ret;
 
@@ -254,15 +358,30 @@ static int virtio_transport_send_pkt_inf
 	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
 		return pkt_len;
 
+	if (info->msg) {
+		/* If zerocopy is not enabled by 'setsockopt()', we behave as
+		 * there is no MSG_ZEROCOPY flag set.
+		 */
+		if (!sock_flag(sk_vsock(vsk), SOCK_ZEROCOPY))
+			info->msg->msg_flags &= ~MSG_ZEROCOPY;
+
+		if (info->msg->msg_flags & MSG_ZEROCOPY)
+			can_zcopy = virtio_transport_can_zcopy(t_ops, info, pkt_len);
+
+		if (can_zcopy)
+			max_skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE,
+					    (MAX_SKB_FRAGS * PAGE_SIZE));
+	}
+
 	rest_len = pkt_len;
 
 	do {
 		struct sk_buff *skb;
 		size_t skb_len;
 
-		skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len);
+		skb_len = min(max_skb_len, rest_len);
 
-		skb = virtio_transport_alloc_skb(info, skb_len,
+		skb = virtio_transport_alloc_skb(info, skb_len, can_zcopy,
 						 src_cid, src_port,
 						 dst_cid, dst_port);
 		if (!skb) {
@@ -270,6 +389,21 @@ static int virtio_transport_send_pkt_inf
 			break;
 		}
 
+		/* We process buffer part by part, allocating skb on
+		 * each iteration. If this is last skb for this buffer
+		 * and MSG_ZEROCOPY mode is in use - we must allocate
+		 * completion for the current syscall.
+		 */
+		if (info->msg && info->msg->msg_flags & MSG_ZEROCOPY &&
+		    skb_len == rest_len && info->op == VIRTIO_VSOCK_OP_RW) {
+			if (virtio_transport_init_zcopy_skb(vsk, skb,
+							    info->msg,
+							    can_zcopy)) {
+				ret = -ENOMEM;
+				break;
+			}
+		}
+
 		virtio_transport_inc_tx_pkt(vvs, skb);
 
 		ret = t_ops->send_pkt(skb);
@@ -985,7 +1119,7 @@ static int virtio_transport_reset_no_soc
 	if (!t)
 		return -ENOTCONN;
 
-	reply = virtio_transport_alloc_skb(&info, 0,
+	reply = virtio_transport_alloc_skb(&info, 0, false,
 					   le64_to_cpu(hdr->dst_cid),
 					   le32_to_cpu(hdr->dst_port),
 					   le64_to_cpu(hdr->src_cid),