diff options
| -rw-r--r-- | net/vmw_vsock/af_vsock.c | 18 | ||||
| -rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 36 | ||||
| -rw-r--r-- | net/vmw_vsock/vsock_bpf.c | 9 | 
3 files changed, 53 insertions, 10 deletions
| diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 5cf8109f672a..fa9d1b49599b 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -491,6 +491,15 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)  		 */  		vsk->transport->release(vsk);  		vsock_deassign_transport(vsk); + +		/* transport's release() and destruct() can touch some socket +		 * state, since we are reassigning the socket to a new transport +		 * during vsock_connect(), let's reset these fields to have a +		 * clean state. +		 */ +		sock_reset_flag(sk, SOCK_DONE); +		sk->sk_state = TCP_CLOSE; +		vsk->peer_shutdown = 0;  	}  	/* We increase the module refcnt to prevent the transport unloading @@ -870,6 +879,9 @@ EXPORT_SYMBOL_GPL(vsock_create_connected);  s64 vsock_stream_has_data(struct vsock_sock *vsk)  { +	if (WARN_ON(!vsk->transport)) +		return 0; +  	return vsk->transport->stream_has_data(vsk);  }  EXPORT_SYMBOL_GPL(vsock_stream_has_data); @@ -878,6 +890,9 @@ s64 vsock_connectible_has_data(struct vsock_sock *vsk)  {  	struct sock *sk = sk_vsock(vsk); +	if (WARN_ON(!vsk->transport)) +		return 0; +  	if (sk->sk_type == SOCK_SEQPACKET)  		return vsk->transport->seqpacket_has_data(vsk);  	else @@ -887,6 +902,9 @@ EXPORT_SYMBOL_GPL(vsock_connectible_has_data);  s64 vsock_stream_has_space(struct vsock_sock *vsk)  { +	if (WARN_ON(!vsk->transport)) +		return 0; +  	return vsk->transport->stream_has_space(vsk);  }  EXPORT_SYMBOL_GPL(vsock_stream_has_space); diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 9acc13ab3f82..7f7de6d88096 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -26,6 +26,9 @@  /* Threshold for detecting small packets to copy */  #define GOOD_COPY_LEN  128 +static void virtio_transport_cancel_close_work(struct vsock_sock *vsk, +					       bool cancel_timeout); +  static const struct virtio_transport *  virtio_transport_get_ops(struct vsock_sock *vsk)  { @@ -1109,6 +1112,8 @@ void virtio_transport_destruct(struct vsock_sock *vsk)  {  	struct virtio_vsock_sock *vvs = vsk->trans; +	virtio_transport_cancel_close_work(vsk, true); +  	kfree(vvs);  	vsk->trans = NULL;  } @@ -1204,17 +1209,11 @@ static void virtio_transport_wait_close(struct sock *sk, long timeout)  	}  } -static void virtio_transport_do_close(struct vsock_sock *vsk, -				      bool cancel_timeout) +static void virtio_transport_cancel_close_work(struct vsock_sock *vsk, +					       bool cancel_timeout)  {  	struct sock *sk = sk_vsock(vsk); -	sock_set_flag(sk, SOCK_DONE); -	vsk->peer_shutdown = SHUTDOWN_MASK; -	if (vsock_stream_has_data(vsk) <= 0) -		sk->sk_state = TCP_CLOSING; -	sk->sk_state_change(sk); -  	if (vsk->close_work_scheduled &&  	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {  		vsk->close_work_scheduled = false; @@ -1226,6 +1225,20 @@ static void virtio_transport_do_close(struct vsock_sock *vsk,  	}  } +static void virtio_transport_do_close(struct vsock_sock *vsk, +				      bool cancel_timeout) +{ +	struct sock *sk = sk_vsock(vsk); + +	sock_set_flag(sk, SOCK_DONE); +	vsk->peer_shutdown = SHUTDOWN_MASK; +	if (vsock_stream_has_data(vsk) <= 0) +		sk->sk_state = TCP_CLOSING; +	sk->sk_state_change(sk); + +	virtio_transport_cancel_close_work(vsk, cancel_timeout); +} +  static void virtio_transport_close_timeout(struct work_struct *work)  {  	struct vsock_sock *vsk = @@ -1628,8 +1641,11 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,  	lock_sock(sk); -	/* Check if sk has been closed before lock_sock */ -	if (sock_flag(sk, SOCK_DONE)) { +	/* Check if sk has been closed or assigned to another transport before +	 * lock_sock (note: listener sockets are not assigned to any transport) +	 */ +	if (sock_flag(sk, SOCK_DONE) || +	    (sk->sk_state != TCP_LISTEN && vsk->transport != &t->transport)) {  		(void)virtio_transport_reset_no_sock(t, skb);  		release_sock(sk);  		sock_put(sk); diff --git a/net/vmw_vsock/vsock_bpf.c b/net/vmw_vsock/vsock_bpf.c index 4aa6e74ec295..f201d9eca1df 100644 --- a/net/vmw_vsock/vsock_bpf.c +++ b/net/vmw_vsock/vsock_bpf.c @@ -77,6 +77,7 @@ static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,  			     size_t len, int flags, int *addr_len)  {  	struct sk_psock *psock; +	struct vsock_sock *vsk;  	int copied;  	psock = sk_psock_get(sk); @@ -84,6 +85,13 @@ static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,  		return __vsock_recvmsg(sk, msg, len, flags);  	lock_sock(sk); +	vsk = vsock_sk(sk); + +	if (!vsk->transport) { +		copied = -ENODEV; +		goto out; +	} +  	if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock)) {  		release_sock(sk);  		sk_psock_put(sk, psock); @@ -108,6 +116,7 @@ static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg,  		copied = sk_msg_recvmsg(sk, psock, msg, len, flags);  	} +out:  	release_sock(sk);  	sk_psock_put(sk, psock); | 
