diff options
Diffstat (limited to 'net/vmw_vsock/virtio_transport_common.c')
-rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 96 |
1 files changed, 77 insertions, 19 deletions
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index ee78b4082ef9..e4878551f140 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -201,7 +201,8 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, const struct virtio_transport *t_ops; struct virtio_vsock_sock *vvs; u32 pkt_len = info->pkt_len; - struct sk_buff *skb; + u32 rest_len; + int ret; info->type = virtio_transport_get_type(sk_vsock(vsk)); @@ -221,10 +222,6 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, vvs = vsk->trans; - /* we can send less than pkt_len bytes */ - if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) - pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE; - /* virtio_transport_get_credit might return less than pkt_len credit */ pkt_len = virtio_transport_get_credit(vvs, pkt_len); @@ -232,17 +229,49 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) return pkt_len; - skb = virtio_transport_alloc_skb(info, pkt_len, - src_cid, src_port, - dst_cid, dst_port); - if (!skb) { - virtio_transport_put_credit(vvs, pkt_len); - return -ENOMEM; - } + rest_len = pkt_len; + + do { + struct sk_buff *skb; + size_t skb_len; + + skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len); + + skb = virtio_transport_alloc_skb(info, skb_len, + src_cid, src_port, + dst_cid, dst_port); + if (!skb) { + ret = -ENOMEM; + break; + } + + virtio_transport_inc_tx_pkt(vvs, skb); - virtio_transport_inc_tx_pkt(vvs, skb); + ret = t_ops->send_pkt(skb); + if (ret < 0) + break; + + /* Both virtio and vhost 'send_pkt()' returns 'skb_len', + * but for reliability use 'ret' instead of 'skb_len'. + * Also if partial send happens (e.g. 'ret' != 'skb_len') + * somehow, we break this loop, but account such returned + * value in 'virtio_transport_put_credit()'. + */ + rest_len -= ret; + + if (WARN_ONCE(ret != skb_len, + "'send_pkt()' returns %i, but %zu expected\n", + ret, skb_len)) + break; + } while (rest_len); + + virtio_transport_put_credit(vvs, rest_len); - return t_ops->send_pkt(skb); + /* Return number of bytes, if any data has been sent. */ + if (rest_len != pkt_len) + ret = pkt_len - rest_len; + + return ret; } static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, @@ -278,6 +307,9 @@ u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) { u32 ret; + if (!credit) + return 0; + spin_lock_bh(&vvs->tx_lock); ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); if (ret > credit) @@ -291,6 +323,9 @@ EXPORT_SYMBOL_GPL(virtio_transport_get_credit); void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) { + if (!credit) + return; + spin_lock_bh(&vvs->tx_lock); vvs->tx_cnt -= credit; spin_unlock_bh(&vvs->tx_lock); @@ -862,6 +897,9 @@ static int virtio_transport_reset_no_sock(const struct virtio_transport *t, if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST) return 0; + if (!t) + return -ENOTCONN; + reply = virtio_transport_alloc_skb(&info, 0, le64_to_cpu(hdr->dst_cid), le32_to_cpu(hdr->dst_port), @@ -870,11 +908,6 @@ static int virtio_transport_reset_no_sock(const struct virtio_transport *t, if (!reply) return -ENOMEM; - if (!t) { - kfree_skb(reply); - return -ENOTCONN; - } - return t->send_pkt(reply); } @@ -1402,6 +1435,31 @@ int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue) } EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs); +int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct sock *sk = sk_vsock(vsk); + struct sk_buff *skb; + int off = 0; + int copied; + int err; + + spin_lock_bh(&vvs->rx_lock); + /* Use __skb_recv_datagram() for race-free handling of the receive. It + * works for types other than dgrams. + */ + skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err); + spin_unlock_bh(&vvs->rx_lock); + + if (!skb) + return err; + + copied = recv_actor(sk, skb); + kfree_skb(skb); + return copied; +} +EXPORT_SYMBOL_GPL(virtio_transport_read_skb); + MODULE_LICENSE("GPL v2"); MODULE_AUTHOR("Asias He"); MODULE_DESCRIPTION("common code for virtio vsock"); |