diff options
Diffstat (limited to 'lib/strnlen_user.c')
| -rw-r--r-- | lib/strnlen_user.c | 99 |
1 files changed, 45 insertions, 54 deletions
diff --git a/lib/strnlen_user.c b/lib/strnlen_user.c index a28df5206d95..4a6574b67f82 100644 --- a/lib/strnlen_user.c +++ b/lib/strnlen_user.c @@ -1,16 +1,12 @@ +// SPDX-License-Identifier: GPL-2.0 #include <linux/kernel.h> #include <linux/export.h> #include <linux/uaccess.h> +#include <linux/mm.h> +#include <linux/bitops.h> #include <asm/word-at-a-time.h> -/* Set bits in the first 'n' bytes when loaded from memory */ -#ifdef __LITTLE_ENDIAN -# define aligned_byte_mask(n) ((1ul << 8*(n))-1) -#else -# define aligned_byte_mask(n) (~0xfful << (BITS_PER_LONG - 8 - 8*(n))) -#endif - /* * Do a strnlen, return length of string *with* final '\0'. * 'count' is the user-supplied count, while 'max' is the @@ -24,29 +20,21 @@ * if it fits in a aligned 'long'. The caller needs to check * the return value against "> max". */ -static inline long do_strnlen_user(const char __user *src, unsigned long count, unsigned long max) +static __always_inline long do_strnlen_user(const char __user *src, unsigned long count, unsigned long max) { const struct word_at_a_time constants = WORD_AT_A_TIME_CONSTANTS; - long align, res = 0; + unsigned long align, res = 0; unsigned long c; /* - * Truncate 'max' to the user-specified limit, so that - * we only have one limit we need to check in the loop - */ - if (max > count) - max = count; - - /* * Do everything aligned. But that means that we * need to also expand the maximum.. */ - align = (sizeof(long) - 1) & (unsigned long)src; + align = (sizeof(unsigned long) - 1) & (unsigned long)src; src -= align; max += align; - if (unlikely(__get_user(c,(unsigned long __user *)src))) - return 0; + unsafe_get_user(c, (unsigned long __user *)src, efault); c |= aligned_byte_mask(align); for (;;) { @@ -57,11 +45,11 @@ static inline long do_strnlen_user(const char __user *src, unsigned long count, return res + find_zero(data) + 1 - align; } res += sizeof(unsigned long); - if (unlikely(max < sizeof(unsigned long))) + /* We already handled 'unsigned long' bytes. Did we do it all ? */ + if (unlikely(max <= sizeof(unsigned long))) break; max -= sizeof(unsigned long); - if (unlikely(__get_user(c,(unsigned long __user *)(src+res)))) - return 0; + unsafe_get_user(c, (unsigned long __user *)(src+res), efault); } res -= align; @@ -76,6 +64,7 @@ static inline long do_strnlen_user(const char __user *src, unsigned long count, * Nope: we hit the address space limit, and we still had more * characters the caller would have wanted. That's 0. */ +efault: return 0; } @@ -84,13 +73,21 @@ static inline long do_strnlen_user(const char __user *src, unsigned long count, * @str: The string to measure. * @count: Maximum count (including NUL character) * - * Context: User context only. This function may sleep. + * Context: User context only. This function may sleep if pagefaults are + * enabled. * * Get the size of a NUL-terminated string in user space. * * Returns the size of the string INCLUDING the terminating NUL. - * If the string is too long, returns 'count+1'. + * If the string is too long, returns a number larger than @count. User + * has to check the return value against "> count". * On exception (or invalid count), returns 0. + * + * NOTE! You should basically never use this function. There is + * almost never any valid case for using the length of a user space + * string, since the string can be changed at any time by other + * threads. Use "strncpy_from_user()" instead to get a stable copy + * of the string. */ long strnlen_user(const char __user *str, long count) { @@ -99,40 +96,34 @@ long strnlen_user(const char __user *str, long count) if (unlikely(count <= 0)) return 0; - max_addr = user_addr_max(); - src_addr = (unsigned long)str; - if (likely(src_addr < max_addr)) { - unsigned long max = max_addr - src_addr; - return do_strnlen_user(str, count, max); - } - return 0; -} -EXPORT_SYMBOL(strnlen_user); + if (can_do_masked_user_access()) { + long retval; -/** - * strlen_user: - Get the size of a user string INCLUDING final NUL. - * @str: The string to measure. - * - * Context: User context only. This function may sleep. - * - * Get the size of a NUL-terminated string in user space. - * - * Returns the size of the string INCLUDING the terminating NUL. - * On exception, returns 0. - * - * If there is a limit on the length of a valid string, you may wish to - * consider using strnlen_user() instead. - */ -long strlen_user(const char __user *str) -{ - unsigned long max_addr, src_addr; + str = masked_user_read_access_begin(str); + retval = do_strnlen_user(str, count, count); + user_read_access_end(); + return retval; + } - max_addr = user_addr_max(); - src_addr = (unsigned long)str; + max_addr = TASK_SIZE_MAX; + src_addr = (unsigned long)untagged_addr(str); if (likely(src_addr < max_addr)) { unsigned long max = max_addr - src_addr; - return do_strnlen_user(str, ~0ul, max); + long retval; + + /* + * Truncate 'max' to the user-specified limit, so that + * we only have one limit we need to check in the loop + */ + if (max > count) + max = count; + + if (user_read_access_begin(str, max)) { + retval = do_strnlen_user(str, count, max); + user_read_access_end(); + return retval; + } } return 0; } -EXPORT_SYMBOL(strlen_user); +EXPORT_SYMBOL(strnlen_user); |
