diff --git a/include/linux/filter.h b/include/linux/filter.h index d23e999dc0324b..d0dd9fdd08a2d8 100644 --- a/include/linux/filter.h +++ b/include/linux/filter.h @@ -279,14 +279,20 @@ static inline bool insn_is_zext(const struct bpf_insn *insn) * BPF_CMPXCHG r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg) */ -#define BPF_ATOMIC_OP(SIZE, OP, DST, SRC, OFF) \ +#define _BPF_ATOMIC_OP(SIZE, CODE, OP, DST, SRC, OFF) \ ((struct bpf_insn) { \ - .code = BPF_STX | BPF_SIZE(SIZE) | BPF_ATOMIC, \ + .code = CODE | BPF_SIZE(SIZE) | BPF_ATOMIC, \ .dst_reg = DST, \ .src_reg = SRC, \ .off = OFF, \ .imm = OP }) +#define BPF_ATOMIC_OP(SIZE, OP, DST, SRC, OFF) \ + _BPF_ATOMIC_OP(SIZE, BPF_STX, OP, DST, SRC, OFF) + +#define BPF_ATOMIC_LOAD_OP(SIZE, OP, DST, SRC, OFF) \ + _BPF_ATOMIC_OP(SIZE, BPF_LDX, OP, DST, SRC, OFF) + /* Legacy alias */ #define BPF_STX_XADD(SIZE, DST, SRC, OFF) BPF_ATOMIC_OP(SIZE, BPF_ADD, DST, SRC, OFF) diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h index fe2272defcd95d..c69651b5326100 100644 --- a/include/uapi/linux/bpf.h +++ b/include/uapi/linux/bpf.h @@ -5538,6 +5538,7 @@ struct bpf_sock { __u32 dst_ip6[4]; __u32 state; __s32 rx_queue_mapping; + __u64 cookie; /* read-only */ }; struct bpf_tcp_sock { diff --git a/net/core/filter.c b/net/core/filter.c index f73a84c75970e6..54474e2318501d 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -4697,6 +4697,18 @@ static const struct bpf_func_proto bpf_get_socket_cookie_sock_ops_proto = { .arg1_type = ARG_PTR_TO_CTX, }; +BPF_CALL_1(bpf_get_socket_cookie_sk_msg, struct sk_msg *, ctx) +{ + return ctx->sk ? __sock_gen_cookie(ctx->sk) : 0; +} + +static const struct bpf_func_proto bpf_get_socket_cookie_sk_msg_proto = { + .func = bpf_get_socket_cookie_sk_msg, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX_OR_NULL, +}; + static u64 __bpf_get_netns_cookie(struct sock *sk) { const struct net *net = sk ? sock_net(sk) : &init_net; @@ -7633,6 +7645,8 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) return &bpf_sk_storage_delete_proto; case BPF_FUNC_get_netns_cookie: return &bpf_get_netns_cookie_sk_msg_proto; + case BPF_FUNC_get_socket_cookie: + return &bpf_get_socket_cookie_sk_msg_proto; #ifdef CONFIG_CGROUPS case BPF_FUNC_get_current_cgroup_id: return &bpf_get_current_cgroup_id_proto; @@ -8044,6 +8058,11 @@ bool bpf_sock_is_valid_access(int off, int size, enum bpf_access_type type, case offsetof(struct bpf_sock, dst_port): case offsetof(struct bpf_sock, src_port): case offsetof(struct bpf_sock, rx_queue_mapping): + break; + case bpf_ctx_range(struct bpf_sock, cookie): + if (type == BPF_WRITE) + return false; + return size == sizeof(__u64); case bpf_ctx_range(struct bpf_sock, src_ip4): case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]): case bpf_ctx_range(struct bpf_sock, dst_ip4): @@ -9125,6 +9144,7 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type, skc_state), target_size)); break; + case offsetof(struct bpf_sock, rx_queue_mapping): #ifdef CONFIG_SOCK_RX_QUEUE_MAPPING *insn++ = BPF_LDX_MEM( @@ -9142,6 +9162,16 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type, *target_size = 2; #endif break; + + case offsetof(struct bpf_sock, cookie): + *insn++ = BPF_ATOMIC_LOAD_OP( + BPF_FIELD_SIZEOF(struct sock_common, skc_cookie), + BPF_XCHG, si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_cookie, + sizeof_field(struct sock_common, + skc_cookie), + target_size)); + break; } return insn - insn_buf; diff --git a/tools/include/linux/filter.h b/tools/include/linux/filter.h index 736bdeccdfe44b..04f1dcb4e554be 100644 --- a/tools/include/linux/filter.h +++ b/tools/include/linux/filter.h @@ -184,14 +184,20 @@ * BPF_CMPXCHG r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg) */ -#define BPF_ATOMIC_OP(SIZE, OP, DST, SRC, OFF) \ +#define _BPF_ATOMIC_OP(SIZE, CODE, OP, DST, SRC, OFF) \ ((struct bpf_insn) { \ - .code = BPF_STX | BPF_SIZE(SIZE) | BPF_ATOMIC, \ + .code = CODE | BPF_SIZE(SIZE) | BPF_ATOMIC, \ .dst_reg = DST, \ .src_reg = SRC, \ .off = OFF, \ .imm = OP }) +#define BPF_ATOMIC_OP(SIZE, OP, DST, SRC, OFF) \ + _BPF_ATOMIC_OP(SIZE, BPF_STX, OP, DST, SRC, OFF) + +#define BPF_ATOMIC_LOAD_OP(SIZE, OP, DST, SRC, OFF) \ + _BPF_ATOMIC_OP(SIZE, BPF_LDX, OP, DST, SRC, OFF) + /* Legacy alias */ #define BPF_STX_XADD(SIZE, DST, SRC, OFF) BPF_ATOMIC_OP(SIZE, BPF_ADD, DST, SRC, OFF) diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h index fe2272defcd95d..c69651b5326100 100644 --- a/tools/include/uapi/linux/bpf.h +++ b/tools/include/uapi/linux/bpf.h @@ -5538,6 +5538,7 @@ struct bpf_sock { __u32 dst_ip6[4]; __u32 state; __s32 rx_queue_mapping; + __u64 cookie; /* read-only */ }; struct bpf_tcp_sock { diff --git a/tools/testing/selftests/bpf/prog_tests/socket_cookie.c b/tools/testing/selftests/bpf/prog_tests/socket_cookie.c index 232db28dde18a2..0bd871cfcf062d 100644 --- a/tools/testing/selftests/bpf/prog_tests/socket_cookie.c +++ b/tools/testing/selftests/bpf/prog_tests/socket_cookie.c @@ -15,12 +15,17 @@ struct socket_cookie { void test_socket_cookie(void) { - int server_fd = 0, client_fd = 0, cgroup_fd = 0, err = 0; + int server_fd = 0, client_fd = 0, cgroup_fd = 0, sock_map_fd = 0, + err = 0; + const __u32 zero = 0; socklen_t addr_len = sizeof(struct sockaddr_in6); struct socket_cookie_prog *skel; __u32 cookie_expected_value; struct sockaddr_in6 addr; struct socket_cookie val; + struct msghdr msg = {0}; + struct iovec iov = {0}; + char buf[1]; skel = socket_cookie_prog__open_and_load(); if (!ASSERT_OK_PTR(skel, "skel_open")) @@ -45,14 +50,39 @@ void test_socket_cookie(void) if (!ASSERT_OK_PTR(skel->links.update_cookie_tracing, "prog_attach")) goto close_cgroup_fd; + sock_map_fd = bpf_map__fd(skel->maps.sock_map); + if (CHECK(sock_map_fd < 0, "map_fd(sock_map)", "errno %d\n", errno)) + goto close_cgroup_fd; + + /* Attach sk_msg prog to sock_map */ + if (CHECK(bpf_prog_attach(bpf_program__fd(skel->progs.set_cookie_skmsg), + sock_map_fd, BPF_SK_MSG_VERDICT, 0) < 0, + "prog_attach", "errno %d\n", errno)) + goto close_sock_map_fd; + server_fd = start_server(AF_INET6, SOCK_STREAM, "::1", 0, 0); if (CHECK(server_fd < 0, "start_server", "errno %d\n", errno)) - goto close_cgroup_fd; + goto close_sock_map_fd; client_fd = connect_to_fd(server_fd, 0); if (CHECK(client_fd < 0, "connect_to_fd", "errno %d\n", errno)) goto close_server_fd; + /* Add client_fd to sock_map */ + if (CHECK(bpf_map_update_elem(sock_map_fd, &zero, &client_fd, BPF_ANY) < + 0, + "map_update(sock_map)", "errno %d\n", errno)) + goto close_sock_map_fd; + + /* Trigger sk_msg program */ + iov.iov_base = buf; + iov.iov_len = sizeof(buf); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + if (CHECK(sendmsg(client_fd, &msg, 0) < 0, "sendmsg", "errno %d\n", + errno)) + goto close_client_fd; + err = bpf_map_lookup_elem(bpf_map__fd(skel->maps.socket_cookies), &client_fd, &val); if (!ASSERT_OK(err, "map_lookup(socket_cookies)")) @@ -69,6 +99,8 @@ void test_socket_cookie(void) close(client_fd); close_server_fd: close(server_fd); +close_sock_map_fd: + close(sock_map_fd); close_cgroup_fd: close(cgroup_fd); out: diff --git a/tools/testing/selftests/bpf/progs/socket_cookie_prog.c b/tools/testing/selftests/bpf/progs/socket_cookie_prog.c index 35630a5aaf5f28..d9731ec1a301f7 100644 --- a/tools/testing/selftests/bpf/progs/socket_cookie_prog.c +++ b/tools/testing/selftests/bpf/progs/socket_cookie_prog.c @@ -22,9 +22,21 @@ struct { } socket_cookies SEC(".maps"); /* - * These three programs get executed in a row on connect() syscalls. The - * userspace side of the test creates a client socket, issues a connect() on it - * and then checks that the local storage associated with this socket has: + * Used for testing the sk_msg prog. + */ +struct { + __uint(type, BPF_MAP_TYPE_SOCKMAP); + __uint(max_entries, 2); + __type(key, __u32); + __type(value, __u64); +} sock_map SEC(".maps"); + +/* + * The following three programs get executed in a row on connect() syscalls and + * the fourth is triggered via sendmsg() after the former three. The userspace + * side of the test creates a client socket, issues a connect() and sendmsg() + * on it, and then checks that the local storage associated with this socket + * has: * cookie_value == local_port << 8 | 0xFF * The different parts of this cookie_value are appended by those hooks if they * all agree on the output of bpf_get_socket_cookie(). @@ -96,4 +108,29 @@ int BPF_PROG(update_cookie_tracing, struct socket *sock, return 0; } +SEC("sk_msg") +int set_cookie_skmsg(struct sk_msg_md *msg) +{ + struct bpf_sock *sk = msg->sk; + int verdict = SK_PASS; + + __u64 cookie; + struct socket_cookie *p; + + if (!sk) + return SK_DROP; + + p = bpf_sk_storage_get(&socket_cookies, sk, 0, + BPF_SK_STORAGE_GET_F_CREATE); + if (!p) + return SK_DROP; + + if (p->cookie_key != bpf_get_socket_cookie(msg)) + return SK_DROP; + + p->cookie_value |= 0xFF; + + return verdict; +} + char _license[] SEC("license") = "GPL"; diff --git a/tools/testing/selftests/bpf/progs/test_skmsg_load_helpers.c b/tools/testing/selftests/bpf/progs/test_skmsg_load_helpers.c index 45e8fc75a7397a..78843c9a3cb129 100644 --- a/tools/testing/selftests/bpf/progs/test_skmsg_load_helpers.c +++ b/tools/testing/selftests/bpf/progs/test_skmsg_load_helpers.c @@ -17,30 +17,46 @@ struct { __type(value, __u64); } sock_hash SEC(".maps"); +struct socket_storage_value { + __u64 pid; + __u64 cookie; +}; + struct { __uint(type, BPF_MAP_TYPE_SK_STORAGE); __uint(map_flags, BPF_F_NO_PREALLOC); __type(key, __u32); - __type(value, __u64); + __type(value, struct socket_storage_value); } socket_storage SEC(".maps"); SEC("sk_msg") int prog_msg_verdict(struct sk_msg_md *msg) { struct task_struct *task = (struct task_struct *)bpf_get_current_task(); + struct bpf_sock *sk = msg->sk; int verdict = SK_PASS; + __u32 pid, tpid; - __u64 *sk_stg; + __u64 cookie, ecookie; + struct socket_storage_value *sk_stg; + + if (!sk) + return SK_DROP; pid = bpf_get_current_pid_tgid() >> 32; - sk_stg = bpf_sk_storage_get(&socket_storage, msg->sk, 0, BPF_SK_STORAGE_GET_F_CREATE); + cookie = bpf_get_socket_cookie(msg); + sk_stg = bpf_sk_storage_get(&socket_storage, sk, 0, BPF_SK_STORAGE_GET_F_CREATE); if (!sk_stg) return SK_DROP; - *sk_stg = pid; + sk_stg->pid = pid; + sk_stg->cookie = cookie; bpf_probe_read_kernel(&tpid , sizeof(tpid), &task->tgid); if (pid != tpid) verdict = SK_DROP; - bpf_sk_storage_delete(&socket_storage, (void *)msg->sk); + ecookie = sk->cookie; + if (cookie != ecookie) + verdict = SK_DROP; + bpf_sk_storage_delete(&socket_storage, (void *)sk); return verdict; }