summaryrefslogtreecommitdiff
path: root/net/vmw_vsock/virtio_transport_common.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/virtio_transport_common.c')
-rw-r--r--net/vmw_vsock/virtio_transport_common.c104
1 files changed, 89 insertions, 15 deletions
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 957cdc01c8e8..b769fc258931 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -94,6 +94,11 @@ virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
info->op,
info->flags);
+ if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) {
+ WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n");
+ goto out;
+ }
+
return skb;
out:
@@ -196,7 +201,8 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
const struct virtio_transport *t_ops;
struct virtio_vsock_sock *vvs;
u32 pkt_len = info->pkt_len;
- struct sk_buff *skb;
+ u32 rest_len;
+ int ret;
info->type = virtio_transport_get_type(sk_vsock(vsk));
@@ -216,10 +222,6 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
vvs = vsk->trans;
- /* we can send less than pkt_len bytes */
- if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
- pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
-
/* virtio_transport_get_credit might return less than pkt_len credit */
pkt_len = virtio_transport_get_credit(vvs, pkt_len);
@@ -227,17 +229,49 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
return pkt_len;
- skb = virtio_transport_alloc_skb(info, pkt_len,
- src_cid, src_port,
- dst_cid, dst_port);
- if (!skb) {
- virtio_transport_put_credit(vvs, pkt_len);
- return -ENOMEM;
- }
+ rest_len = pkt_len;
+
+ do {
+ struct sk_buff *skb;
+ size_t skb_len;
+
+ skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len);
+
+ skb = virtio_transport_alloc_skb(info, skb_len,
+ src_cid, src_port,
+ dst_cid, dst_port);
+ if (!skb) {
+ ret = -ENOMEM;
+ break;
+ }
- virtio_transport_inc_tx_pkt(vvs, skb);
+ virtio_transport_inc_tx_pkt(vvs, skb);
- return t_ops->send_pkt(skb);
+ ret = t_ops->send_pkt(skb);
+ if (ret < 0)
+ break;
+
+ /* Both virtio and vhost 'send_pkt()' returns 'skb_len',
+ * but for reliability use 'ret' instead of 'skb_len'.
+ * Also if partial send happens (e.g. 'ret' != 'skb_len')
+ * somehow, we break this loop, but account such returned
+ * value in 'virtio_transport_put_credit()'.
+ */
+ rest_len -= ret;
+
+ if (WARN_ONCE(ret != skb_len,
+ "'send_pkt()' returns %i, but %zu expected\n",
+ ret, skb_len))
+ break;
+ } while (rest_len);
+
+ virtio_transport_put_credit(vvs, rest_len);
+
+ /* Return number of bytes, if any data has been sent. */
+ if (rest_len != pkt_len)
+ ret = pkt_len - rest_len;
+
+ return ret;
}
static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
@@ -273,6 +307,9 @@ u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
u32 ret;
+ if (!credit)
+ return 0;
+
spin_lock_bh(&vvs->tx_lock);
ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
if (ret > credit)
@@ -286,6 +323,9 @@ EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
+ if (!credit)
+ return;
+
spin_lock_bh(&vvs->tx_lock);
vvs->tx_cnt -= credit;
spin_unlock_bh(&vvs->tx_lock);
@@ -363,6 +403,13 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
u32 free_space;
spin_lock_bh(&vvs->rx_lock);
+
+ if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes,
+ "rx_queue is empty, but rx_bytes is non-zero\n")) {
+ spin_unlock_bh(&vvs->rx_lock);
+ return err;
+ }
+
while (total < len && !skb_queue_empty(&vvs->rx_queue)) {
skb = skb_peek(&vvs->rx_queue);
@@ -1066,7 +1113,7 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
memcpy(skb_put(last_skb, skb->len), skb->data, skb->len);
free_pkt = true;
last_hdr->flags |= hdr->flags;
- last_hdr->len = cpu_to_le32(last_skb->len);
+ le32_add_cpu(&last_hdr->len, len);
goto out;
}
}
@@ -1294,6 +1341,11 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
goto free_pkt;
}
+ if (!skb_set_owner_sk_safe(skb, sk)) {
+ WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n");
+ goto free_pkt;
+ }
+
vsk = vsock_sk(sk);
lock_sock(sk);
@@ -1383,6 +1435,28 @@ int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue)
}
EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs);
+int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor)
+{
+ struct virtio_vsock_sock *vvs = vsk->trans;
+ struct sock *sk = sk_vsock(vsk);
+ struct sk_buff *skb;
+ int off = 0;
+ int err;
+
+ spin_lock_bh(&vvs->rx_lock);
+ /* Use __skb_recv_datagram() for race-free handling of the receive. It
+ * works for types other than dgrams.
+ */
+ skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err);
+ spin_unlock_bh(&vvs->rx_lock);
+
+ if (!skb)
+ return err;
+
+ return recv_actor(sk, skb);
+}
+EXPORT_SYMBOL_GPL(virtio_transport_read_skb);
+
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");