diff options
Diffstat (limited to 'fs/smb/server/mgmt/user_session.c')
| -rw-r--r-- | fs/smb/server/mgmt/user_session.c | 145 |
1 files changed, 123 insertions, 22 deletions
diff --git a/fs/smb/server/mgmt/user_session.c b/fs/smb/server/mgmt/user_session.c index 8a5dcab05614..1c181ef99929 100644 --- a/fs/smb/server/mgmt/user_session.c +++ b/fs/smb/server/mgmt/user_session.c @@ -18,7 +18,7 @@ static DEFINE_IDA(session_ida); -#define SESSION_HASH_BITS 3 +#define SESSION_HASH_BITS 12 static DEFINE_HASHTABLE(sessions_table, SESSION_HASH_BITS); static DECLARE_RWSEM(sessions_table_lock); @@ -59,10 +59,12 @@ static void ksmbd_session_rpc_clear_list(struct ksmbd_session *sess) struct ksmbd_session_rpc *entry; long index; + down_write(&sess->rpc_lock); xa_for_each(&sess->rpc_handle_list, index, entry) { xa_erase(&sess->rpc_handle_list, index); __session_rpc_close(sess, entry); } + up_write(&sess->rpc_lock); xa_destroy(&sess->rpc_handle_list); } @@ -90,32 +92,41 @@ static int __rpc_method(char *rpc_name) int ksmbd_session_rpc_open(struct ksmbd_session *sess, char *rpc_name) { - struct ksmbd_session_rpc *entry; + struct ksmbd_session_rpc *entry, *old; struct ksmbd_rpc_command *resp; - int method; + int method, id; method = __rpc_method(rpc_name); if (!method) return -EINVAL; - entry = kzalloc(sizeof(struct ksmbd_session_rpc), GFP_KERNEL); + entry = kzalloc(sizeof(struct ksmbd_session_rpc), KSMBD_DEFAULT_GFP); if (!entry) return -ENOMEM; entry->method = method; - entry->id = ksmbd_ipc_id_alloc(); - if (entry->id < 0) + entry->id = id = ksmbd_ipc_id_alloc(); + if (id < 0) goto free_entry; - xa_store(&sess->rpc_handle_list, entry->id, entry, GFP_KERNEL); - resp = ksmbd_rpc_open(sess, entry->id); - if (!resp) + down_write(&sess->rpc_lock); + old = xa_store(&sess->rpc_handle_list, id, entry, KSMBD_DEFAULT_GFP); + if (xa_is_err(old)) { + up_write(&sess->rpc_lock); + goto free_id; + } + + resp = ksmbd_rpc_open(sess, id); + if (!resp) { + xa_erase(&sess->rpc_handle_list, entry->id); + up_write(&sess->rpc_lock); goto free_id; + } + up_write(&sess->rpc_lock); kvfree(resp); - return entry->id; + return id; free_id: - xa_erase(&sess->rpc_handle_list, entry->id); ksmbd_rpc_id_free(entry->id); free_entry: kfree(entry); @@ -126,16 +137,20 @@ void ksmbd_session_rpc_close(struct ksmbd_session *sess, int id) { struct ksmbd_session_rpc *entry; + down_write(&sess->rpc_lock); entry = xa_erase(&sess->rpc_handle_list, id); if (entry) __session_rpc_close(sess, entry); + up_write(&sess->rpc_lock); } int ksmbd_session_rpc_method(struct ksmbd_session *sess, int id) { struct ksmbd_session_rpc *entry; + lockdep_assert_held(&sess->rpc_lock); entry = xa_load(&sess->rpc_handle_list, id); + return entry ? entry->method : 0; } @@ -149,6 +164,7 @@ void ksmbd_session_destroy(struct ksmbd_session *sess) ksmbd_tree_conn_session_logoff(sess); ksmbd_destroy_file_table(&sess->file_table); + ksmbd_launch_ksmbd_durable_scavenger(); ksmbd_session_rpc_clear_list(sess); free_channel_list(sess); kfree(sess->Preauth_HashValue); @@ -156,7 +172,7 @@ void ksmbd_session_destroy(struct ksmbd_session *sess) kfree(sess); } -static struct ksmbd_session *__session_lookup(unsigned long long id) +struct ksmbd_session *__session_lookup(unsigned long long id) { struct ksmbd_session *sess; @@ -175,16 +191,19 @@ static void ksmbd_expire_session(struct ksmbd_conn *conn) struct ksmbd_session *sess; down_write(&sessions_table_lock); + down_write(&conn->session_lock); xa_for_each(&conn->sessions, id, sess) { - if (sess->state != SMB2_SESSION_VALID || - time_after(jiffies, - sess->last_active + SMB2_SESSION_TIMEOUT)) { + if (atomic_read(&sess->refcnt) <= 1 && + (sess->state != SMB2_SESSION_VALID || + time_after(jiffies, + sess->last_active + SMB2_SESSION_TIMEOUT))) { xa_erase(&conn->sessions, sess->id); hash_del(&sess->hlist); ksmbd_session_destroy(sess); continue; } } + up_write(&conn->session_lock); up_write(&sessions_table_lock); } @@ -194,7 +213,7 @@ int ksmbd_session_register(struct ksmbd_conn *conn, sess->dialect = conn->dialect; memcpy(sess->ClientGUID, conn->ClientGUID, SMB2_CLIENT_GUID_SIZE); ksmbd_expire_session(conn); - return xa_err(xa_store(&conn->sessions, sess->id, sess, GFP_KERNEL)); + return xa_err(xa_store(&conn->sessions, sess->id, sess, KSMBD_DEFAULT_GFP)); } static int ksmbd_chann_del(struct ksmbd_conn *conn, struct ksmbd_session *sess) @@ -223,11 +242,16 @@ void ksmbd_sessions_deregister(struct ksmbd_conn *conn) if (!ksmbd_chann_del(conn, sess) && xa_empty(&sess->ksmbd_chann_list)) { hash_del(&sess->hlist); - ksmbd_session_destroy(sess); + down_write(&conn->session_lock); + xa_erase(&conn->sessions, sess->id); + up_write(&conn->session_lock); + if (atomic_dec_and_test(&sess->refcnt)) + ksmbd_session_destroy(sess); } } } + down_write(&conn->session_lock); xa_for_each(&conn->sessions, id, sess) { unsigned long chann_id; struct channel *chann; @@ -241,20 +265,42 @@ void ksmbd_sessions_deregister(struct ksmbd_conn *conn) if (xa_empty(&sess->ksmbd_chann_list)) { xa_erase(&conn->sessions, sess->id); hash_del(&sess->hlist); - ksmbd_session_destroy(sess); + if (atomic_dec_and_test(&sess->refcnt)) + ksmbd_session_destroy(sess); } } + up_write(&conn->session_lock); up_write(&sessions_table_lock); } +bool is_ksmbd_session_in_connection(struct ksmbd_conn *conn, + unsigned long long id) +{ + struct ksmbd_session *sess; + + down_read(&conn->session_lock); + sess = xa_load(&conn->sessions, id); + if (sess) { + up_read(&conn->session_lock); + return true; + } + up_read(&conn->session_lock); + + return false; +} + struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn, unsigned long long id) { struct ksmbd_session *sess; + down_read(&conn->session_lock); sess = xa_load(&conn->sessions, id); - if (sess) + if (sess) { sess->last_active = jiffies; + ksmbd_user_session_get(sess); + } + up_read(&conn->session_lock); return sess; } @@ -265,7 +311,7 @@ struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id) down_read(&sessions_table_lock); sess = __session_lookup(id); if (sess) - sess->last_active = jiffies; + ksmbd_user_session_get(sess); up_read(&sessions_table_lock); return sess; @@ -284,12 +330,28 @@ struct ksmbd_session *ksmbd_session_lookup_all(struct ksmbd_conn *conn, return sess; } +void ksmbd_user_session_get(struct ksmbd_session *sess) +{ + atomic_inc(&sess->refcnt); +} + +void ksmbd_user_session_put(struct ksmbd_session *sess) +{ + if (!sess) + return; + + if (atomic_read(&sess->refcnt) <= 0) + WARN_ON(1); + else if (atomic_dec_and_test(&sess->refcnt)) + ksmbd_session_destroy(sess); +} + struct preauth_session *ksmbd_preauth_session_alloc(struct ksmbd_conn *conn, u64 sess_id) { struct preauth_session *sess; - sess = kmalloc(sizeof(struct preauth_session), GFP_KERNEL); + sess = kmalloc(sizeof(struct preauth_session), KSMBD_DEFAULT_GFP); if (!sess) return NULL; @@ -301,6 +363,42 @@ struct preauth_session *ksmbd_preauth_session_alloc(struct ksmbd_conn *conn, return sess; } +void destroy_previous_session(struct ksmbd_conn *conn, + struct ksmbd_user *user, u64 id) +{ + struct ksmbd_session *prev_sess; + struct ksmbd_user *prev_user; + int err; + + down_write(&sessions_table_lock); + down_write(&conn->session_lock); + prev_sess = __session_lookup(id); + if (!prev_sess || prev_sess->state == SMB2_SESSION_EXPIRED) + goto out; + + prev_user = prev_sess->user; + if (!prev_user || + strcmp(user->name, prev_user->name) || + user->passkey_sz != prev_user->passkey_sz || + memcmp(user->passkey, prev_user->passkey, user->passkey_sz)) + goto out; + + ksmbd_all_conn_set_status(id, KSMBD_SESS_NEED_RECONNECT); + err = ksmbd_conn_wait_idle_sess_id(conn, id); + if (err) { + ksmbd_all_conn_set_status(id, KSMBD_SESS_NEED_SETUP); + goto out; + } + + ksmbd_destroy_file_table(&prev_sess->file_table); + prev_sess->state = SMB2_SESSION_EXPIRED; + ksmbd_all_conn_set_status(id, KSMBD_SESS_NEED_SETUP); + ksmbd_launch_ksmbd_durable_scavenger(); +out: + up_write(&conn->session_lock); + up_write(&sessions_table_lock); +} + static bool ksmbd_preauth_session_id_match(struct preauth_session *sess, unsigned long long id) { @@ -337,7 +435,7 @@ static struct ksmbd_session *__session_create(int protocol) if (protocol != CIFDS_SESSION_FLAG_SMB2) return NULL; - sess = kzalloc(sizeof(struct ksmbd_session), GFP_KERNEL); + sess = kzalloc(sizeof(struct ksmbd_session), KSMBD_DEFAULT_GFP); if (!sess) return NULL; @@ -351,6 +449,9 @@ static struct ksmbd_session *__session_create(int protocol) xa_init(&sess->ksmbd_chann_list); xa_init(&sess->rpc_handle_list); sess->sequence_number = 1; + rwlock_init(&sess->tree_conns_lock); + atomic_set(&sess->refcnt, 2); + init_rwsem(&sess->rpc_lock); ret = __init_smb2_session(sess); if (ret) |
