sandbox/init/init.c

342 lines
12 KiB
C
Raw Normal View History

2026-03-23 00:35:27 +08:00
// init.c v2.1
#define _GNU_SOURCE
#include <errno.h>
#include <signal.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fcntl.h>
#include <grp.h>
#include <poll.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <sys/mount.h>
#include <sys/prctl.h>
#include <sys/socket.h>
#include <sys/syscall.h>
#include <sys/uio.h>
#include <sys/wait.h>

#include <linux/audit.h>
#include <linux/filter.h>
#include <linux/seccomp.h>
/*
 * Debug tracing, compiled out: arguments are never evaluated.
 * To enable it, replace the definition below with
 *   #define DEBUG(fmt, ...) printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)
 */
#define DEBUG(fmt, ...) ((void)0)
// One allow/block network rule. main() reads arrays of these straight from
// stdin with safe_read(), so the packed layout (16 + 1 + 2 + 1 = 20 bytes, no
// padding) is the wire format: do not reorder or resize the fields.
struct __attribute__((packed)) NetRule {
uint8_t ip[16];   // address bytes in network order, compared directly with sin_addr / sin6_addr; IPv4 rules are expected in ip[0..3]
int8_t mask;      // prefix length in bits; -1 means the whole address must match (see is_ip_match)
uint16_t port;    // compared with the ntohs()'d destination port, so it must be in host byte order; 0 matches any port
uint8_t is_v6;    // 0 = IPv4 rule, 1 = IPv6 rule (is_allowed compares it with == against its 0/1 flag)
};
/* Network policy, filled in by main() from the configuration on stdin. */
uint8_t net_enabled = 0;     /* nonzero: install the seccomp filter and run the monitor */
uint8_t allow_internet = 0;  /* verdict for destinations that are neither loopback nor private */
uint8_t allow_local = 0;     /* verdict for destinations is_private_ip() accepts */

/* Rule tables; each *_cnt is the number of entries in the matching array. */
uint32_t listen_cnt = 0;
uint32_t allow_cnt = 0;
uint32_t block_cnt = 0;
uint32_t *listen_ports = NULL;       /* ports bind() may use */
struct NetRule *allow_rules = NULL;
struct NetRule *block_rules = NULL;

/* State shared with the signal handler. */
pid_t child_pid = -1;                   /* -1 until fork() */
volatile sig_atomic_t stop_monitor = 0; /* set by handle_sig, polled by run_monitor */
/*
 * Fill buf with exactly len bytes from fd, looping over short reads.
 * Any read error or premature EOF terminates the process with exit code 101.
 */
void safe_read(int fd, void *buf, size_t len) {
    unsigned char *cursor = buf;
    size_t remaining = len;

    while (remaining > 0) {
        ssize_t got = read(fd, cursor, remaining);
        if (got <= 0) {
            exit(101);
        }
        cursor += got;
        remaining -= (size_t)got;
    }
}
/*
 * Handler for SIGCHLD, SIGTERM and SIGINT (installed by main without SA_RESTART,
 * so blocking calls in run_monitor return EINTR).
 *   SIGCHLD:         ask run_monitor to stop.
 *   SIGTERM/SIGINT:  forward the signal to the child, then ask run_monitor to stop.
 * Signals that arrive while child_pid <= 0 (before fork, or inside the child
 * before exec) are ignored.
 * Only async-signal-safe calls are made, and errno is saved and restored so the
 * interrupted code still sees its own error value.
 */
void handle_sig(int sig) {
    int saved_errno = errno;

    if (sig == SIGCHLD) {
        stop_monitor = 1;
    } else if (child_pid > 0) {
        // kill(-pid) targets the process group the child creates with setpgid(0, 0).
        // If the child has not reached that call yet the group does not exist
        // (ESRCH), so fall back to signalling the child itself.
        if (kill(-child_pid, sig) == -1 && errno == ESRCH) {
            kill(child_pid, sig);
        }
        stop_monitor = 1;
    }

    errno = saved_errno;
}
/*
 * Return 1 if target_ip lies inside rule's prefix, 0 otherwise.
 *   rule:      rule->ip holds the prefix address in network byte order.
 *   target_ip: 4 (IPv4) or 16 (IPv6) address bytes in network order; the
 *              family must match rule->is_v6.
 * rule->mask is the prefix length in bits. -1 means "whole address must match".
 * Other negative values, or lengths longer than the address, are treated the
 * same way, so a malformed rule never makes memcmp run past the address.
 */
int is_ip_match(struct NetRule *rule, void *target_ip) {
    const uint8_t *target = target_ip;
    int width = rule->is_v6 ? 128 : 32;
    int prefix = rule->mask;

    if (prefix < 0 || prefix > width) {
        prefix = width;
    }

    int whole_bytes = prefix / 8;
    if (memcmp(rule->ip, target, (size_t)whole_bytes) != 0) {
        return 0;
    }

    int rest_bits = prefix % 8;
    if (rest_bits == 0) {
        return 1;
    }

    // Compare only the leading rest_bits bits of the partial byte.
    uint8_t mask_byte = (uint8_t)(0xFF << (8 - rest_bits));
    return (rule->ip[whole_bytes] & mask_byte) == (target[whole_bytes] & mask_byte);
}
/*
 * Return 1 if ip is a loopback address, 0 otherwise.
 *   IPv4 (is_v6 == 0): any address in 127.0.0.0/8 (4 bytes, network order).
 *   IPv6 (is_v6 != 0): exactly ::1 (16 bytes, network order).
 */
int is_loopback(void *ip, int is_v6) {
    const uint8_t *octets = ip;

    if (is_v6) {
        static const uint8_t v6_loopback[16] = { [15] = 1 };
        return memcmp(octets, v6_loopback, sizeof v6_loopback) == 0;
    }
    return octets[0] == 127;
}
/*
 * Return 1 if ip belongs to a local-network range, 0 otherwise. Such
 * destinations are governed by allow_local rather than allow_internet.
 *   IPv4: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 (RFC 1918) and
 *         169.254.0.0/16 (link-local, RFC 3927).
 *   IPv6: fc00::/7 (unique local) and fe80::/10 (link-local).
 * ip points at 4 (IPv4) or 16 (IPv6) address bytes in network order.
 * IPv4 link-local is included for parity with fe80::/10; otherwise hosts such
 * as 169.254.169.254 would fall under allow_internet.
 */
int is_private_ip(void *ip, int is_v6) {
    const uint8_t *p = ip;

    if (is_v6) {
        if ((p[0] & 0xfe) == 0xfc) return 1;                 // fc00::/7
        if (p[0] == 0xfe && (p[1] & 0xc0) == 0x80) return 1; // fe80::/10
        return 0;
    }

    if (p[0] == 10) return 1;                                // 10.0.0.0/8
    if (p[0] == 172 && p[1] >= 16 && p[1] <= 31) return 1;   // 172.16.0.0/12
    if (p[0] == 192 && p[1] == 168) return 1;                // 192.168.0.0/16
    if (p[0] == 169 && p[1] == 254) return 1;                // 169.254.0.0/16
    return 0;
}
/*
 * Policy decision for one intercepted socket call.
 *   addr: sockaddr copied from the target process. Only AF_INET and AF_INET6
 *         are understood; every other family is denied.
 *   nr:   syscall number (__NR_bind, __NR_connect, __NR_sendto or __NR_sendmsg).
 * Returns 1 to allow, 0 to deny. Checks run in this order:
 *   0. loopback destinations are always allowed;
 *   1. bind() is allowed only on ports listed in listen_ports;
 *   2. block_rules take priority over everything below;
 *   3. sendto/sendmsg to port 53 (DNS) is allowed;
 *   4. allow_rules;
 *   5. otherwise allow_local for private ranges, allow_internet for the rest.
 * IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are judged as the IPv4 address
 * they carry, so an AF_INET6 socket cannot sidestep the IPv4 rules.
 */
int is_allowed(struct sockaddr *addr, int nr) {
    uint16_t port = 0;
    void *ip = NULL;
    int is_v6 = 0;

    if (addr->sa_family == AF_INET) {
        struct sockaddr_in *s4 = (struct sockaddr_in *)addr;
        port = ntohs(s4->sin_port);
        ip = &s4->sin_addr;
    } else if (addr->sa_family == AF_INET6) {
        struct sockaddr_in6 *s6 = (struct sockaddr_in6 *)addr;
        port = ntohs(s6->sin6_port);
        if (IN6_IS_ADDR_V4MAPPED(&s6->sin6_addr)) {
            // Without this, ::ffff:192.168.0.1 would skip the IPv4 block rules
            // and be classified as "internet" rather than "local".
            ip = &s6->sin6_addr.s6_addr[12];
        } else {
            ip = &s6->sin6_addr;
            is_v6 = 1;
        }
    } else {
        return 0;
    }

    // 0. Loopback is always allowed.
    if (is_loopback(ip, is_v6)) return 1;

    // 1. bind(): only the configured listen ports.
    if (nr == __NR_bind) {
        for (uint32_t i = 0; i < listen_cnt; i++) {
            if (listen_ports[i] == port) return 1;
        }
        return 0;
    }

    // 2. Block list (highest priority).
    for (uint32_t i = 0; i < block_cnt; i++) {
        if (block_rules[i].is_v6 == is_v6 && (block_rules[i].port == 0 || block_rules[i].port == port)) {
            if (is_ip_match(&block_rules[i], ip)) return 0;
        }
    }

    // 3. DNS queries sent with an explicit destination (port 53) are always allowed.
    if (nr == __NR_sendto || nr == __NR_sendmsg) {
        if (port == 53) return 1;
    }

    // 4. Allow list.
    for (uint32_t i = 0; i < allow_cnt; i++) {
        if (allow_rules[i].is_v6 == is_v6 && (allow_rules[i].port == 0 || allow_rules[i].port == port)) {
            if (is_ip_match(&allow_rules[i], ip)) return 1;
        }
    }

    // 5. Global switches.
    if (is_private_ip(ip, is_v6)) {
        return allow_local;
    }
    return allow_internet;
}
/*
 * Copy the sockaddr at remote_ptr (claimed length remote_len) out of process
 * pid into out. The copy is capped at sizeof(*out), and out is zeroed first so
 * any bytes the target did not supply read as 0.
 * Returns the number of bytes copied, or -1 when nothing was read.
 */
static ssize_t read_remote_sockaddr(pid_t pid, void *remote_ptr, size_t remote_len,
                                    struct sockaddr_storage *out) {
    memset(out, 0, sizeof(*out));
    if (remote_ptr == NULL || remote_len == 0) return -1;
    if (remote_len > sizeof(*out)) remote_len = sizeof(*out);

    struct iovec local = { .iov_base = out, .iov_len = remote_len };
    struct iovec remote = { .iov_base = remote_ptr, .iov_len = remote_len };
    return process_vm_readv(pid, &local, 1, &remote, 1, 0);
}

/*
 * Return 1 when the first n bytes of addr cover everything is_allowed() reads
 * for that family (family, port and address), 0 otherwise.
 */
static int sockaddr_complete(const struct sockaddr_storage *addr, ssize_t n) {
    if (n < (ssize_t)sizeof(sa_family_t)) return 0;
    switch (addr->ss_family) {
    case AF_INET:
        return n >= (ssize_t)(offsetof(struct sockaddr_in, sin_addr) + sizeof(struct in_addr));
    case AF_INET6:
        return n >= (ssize_t)(offsetof(struct sockaddr_in6, sin6_addr) + sizeof(struct in6_addr));
    default:
        return 1; // is_allowed() rejects other families without reading further
    }
}

/*
 * Decide one intercepted call. Returns 1 to let the kernel run the syscall,
 * 0 to fail it with EPERM. Anything that cannot be read completely from the
 * target is denied.
 */
static int notification_allowed(const struct seccomp_notif *req) {
    struct sockaddr_storage addr;
    void *remote_ptr;
    size_t remote_len;
    int nr = req->data.nr;

    if (nr == __NR_connect || nr == __NR_bind) {
        // connect/bind(fd, addr, addrlen)
        remote_ptr = (void *)(uintptr_t)req->data.args[1];
        remote_len = (size_t)req->data.args[2];
    } else if (nr == __NR_sendto) {
        // sendto(fd, buf, len, flags, dest_addr, addrlen)
        remote_ptr = (void *)(uintptr_t)req->data.args[4];
        remote_len = (size_t)req->data.args[5];
        if (remote_ptr == NULL) return 1; // no destination; the BPF filter normally never forwards this
    } else if (nr == __NR_sendmsg) {
        // sendmsg(fd, msg, flags): the destination lives behind msg->msg_name,
        // so the msghdr is read first and must be read in full.
        struct msghdr msg;
        struct iovec local_msg = { .iov_base = &msg, .iov_len = sizeof(msg) };
        struct iovec remote_msg = { .iov_base = (void *)(uintptr_t)req->data.args[1], .iov_len = sizeof(msg) };
        if (process_vm_readv(req->pid, &local_msg, 1, &remote_msg, 1, 0) != (ssize_t)sizeof(msg)) return 0;
        // The kernel ignores msg_name when it is NULL or msg_namelen is 0
        // (send on an already-connected socket).
        if (msg.msg_name == NULL || msg.msg_namelen == 0) return 1;
        remote_ptr = msg.msg_name;
        remote_len = msg.msg_namelen;
    } else {
        return 0;
    }

    ssize_t n = read_remote_sockaddr(req->pid, remote_ptr, remote_len, &addr);
    if (!sockaddr_complete(&addr, n)) return 0;
    return is_allowed((struct sockaddr *)&addr, nr);
}

/*
 * Supervisor loop: receive seccomp user notifications on notif_fd and answer
 * each with "continue" (allowed) or EPERM (denied) until stop_monitor is set
 * or the listener becomes unusable.
 */
void run_monitor(int notif_fd) {
    struct seccomp_notif req;
    struct seccomp_notif_resp resp;
    struct pollfd pfd = { .fd = notif_fd, .events = POLLIN };

    while (!stop_monitor) {
        // A 100 ms timeout guarantees stop_monitor is re-checked regularly,
        // even if no signal happens to interrupt the wait.
        int ready = poll(&pfd, 1, 100);
        if (ready < 0) {
            if (errno == EINTR) continue; // the loop condition re-checks stop_monitor
            break;
        }
        if (ready == 0) continue;
        if (!(pfd.revents & POLLIN)) break; // POLLERR/POLLHUP/POLLNVAL: listener unusable

        memset(&req, 0, sizeof(req));
        if (ioctl(notif_fd, SECCOMP_IOCTL_NOTIF_RECV, &req) == -1) {
            // EINTR:  one of our signal handlers ran.
            // ENOENT: the target's syscall was interrupted by a signal (or the
            //         target died) before we could receive it. Other targets
            //         still need service, so this must not end the loop.
            if (errno == EINTR || errno == ENOENT) continue;
            break;
        }

        memset(&resp, 0, sizeof(resp));
        resp.id = req.id;
        // NOTE(review): SECCOMP_USER_NOTIF_FLAG_CONTINUE is inherently racy.
        // Another thread of the target can rewrite the sockaddr after it was
        // read here, and seccomp_unotify(2) warns against relying on CONTINUE
        // for security decisions. Confirm this is acceptable for the threat model.
        if (notification_allowed(&req)) {
            resp.flags = SECCOMP_USER_NOTIF_FLAG_CONTINUE;
        } else {
            resp.error = -EPERM;
        }
        // Fails with ENOENT if the target went away in the meantime; nothing to do then.
        ioctl(notif_fd, SECCOMP_IOCTL_NOTIF_SEND, &resp);
    }
}
int main() {
setvbuf(stdout, NULL, _IONBF, 0);
setvbuf(stderr, NULL, _IONBF, 0);
uint32_t uid, gid, arg_count;
safe_read(0, &uid, 4);
safe_read(0, &gid, 4);
uint32_t wd_len;
safe_read(0, &wd_len, 4);
char *work_dir = malloc(wd_len + 1);
safe_read(0, work_dir, wd_len);
work_dir[wd_len] = '\0';
safe_read(0, &arg_count, 4);
char **argv = malloc(sizeof(char *) * (arg_count + 1));
for (uint32_t i = 0; i < arg_count; i++) {
uint32_t s_len;
safe_read(0, &s_len, 4);
argv[i] = malloc(s_len + 1);
safe_read(0, argv[i], s_len);
argv[i][s_len] = '\0';
}
argv[arg_count] = NULL;
safe_read(0, &net_enabled, 1);
if (net_enabled) {
safe_read(0, &allow_internet, 1);
safe_read(0, &allow_local, 1);
safe_read(0, &listen_cnt, 4);
if (listen_cnt > 0) {
listen_ports = malloc(listen_cnt * 4);
safe_read(0, listen_ports, listen_cnt * 4);
}
safe_read(0, &allow_cnt, 4);
if (allow_cnt > 0) {
allow_rules = malloc(allow_cnt * sizeof(struct NetRule));
safe_read(0, allow_rules, allow_cnt * sizeof(struct NetRule));
}
safe_read(0, &block_cnt, 4);
if (block_cnt > 0) {
block_rules = malloc(block_cnt * sizeof(struct NetRule));
safe_read(0, block_rules, block_cnt * sizeof(struct NetRule));
}
}
int dev_null = open("/dev/null", O_RDONLY);
if (dev_null >= 0) { dup2(dev_null, 0); close(dev_null); }
int notif_fd = -1;
if (net_enabled) {
struct sock_filter filter[] = {
BPF_STMT(BPF_LD | BPF_W | BPF_ABS, offsetof(struct seccomp_data, nr)),
// 1. 拦截 connect
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, __NR_connect, 0, 1),
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_USER_NOTIF),
// 2. 拦截 bind
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, __NR_bind, 0, 1),
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_USER_NOTIF),
// 3. 拦截 sendto (精细化过滤:只拦截 dest_addr 不为空的调用)
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, __NR_sendto, 0, 5),
// 加载 args[4] (dest_addr) 的低 32 位
BPF_STMT(BPF_LD | BPF_W | BPF_ABS, offsetof(struct seccomp_data, args[4])),
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, 0, 0, 2), // 不为 0 说明有地址,跳去 Notify
// 加载 args[4] 的高 32 位
BPF_STMT(BPF_LD | BPF_W | BPF_ABS, offsetof(struct seccomp_data, args[4]) + 4),
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, 0, 2, 0), // 也为 0 说明是 NULL (已连接的 TCP/UDP),跳过 Notify
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_USER_NOTIF), // Notify
// 4. 拦截 sendmsg (结构体较复杂,全部交给用户态去读内存判断)
// 需要先恢复 nr 到累加器
BPF_STMT(BPF_LD | BPF_W | BPF_ABS, offsetof(struct seccomp_data, nr)),
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, __NR_sendmsg, 0, 1),
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_USER_NOTIF),
// 兜底放行
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
};
struct sock_fprog prog = { .len = (unsigned short)(sizeof(filter)/sizeof(filter[0])), .filter = filter };
notif_fd = syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER, SECCOMP_FILTER_FLAG_NEW_LISTENER, &prog);
}
// 显式不使用 SA_RESTART确保系统调用会被信号中断
struct sigaction sa;
memset(&sa, 0, sizeof(sa));
sa.sa_handler = handle_sig;
sigaction(SIGCHLD, &sa, NULL);
sigaction(SIGTERM, &sa, NULL);
sigaction(SIGINT, &sa, NULL);
child_pid = fork();
if (child_pid == 0) {
setpgid(0, 0);
prctl(PR_SET_PDEATHSIG, SIGTERM);
mount("proc", "/proc", "proc", 0, NULL);
mount("sysfs", "/sys", "sysfs", 0, NULL);
chdir(work_dir);
int fd_out = open("stdout.log", O_WRONLY | O_CREAT | O_TRUNC | O_SYNC, 0644);
int fd_err = open("stderr.log", O_WRONLY | O_CREAT | O_TRUNC | O_SYNC, 0644);
if (fd_out >= 0) dup2(fd_out, 1);
if (fd_err >= 0) dup2(fd_err, 2);
if (uid != 0) { setresgid(gid, gid, gid); setresuid(uid, uid, uid); }
execv(argv[0], argv);
exit(103);
} else {
if (net_enabled && notif_fd >= 0) {
run_monitor(notif_fd);
DEBUG("monitor loop exited");
}
int status;
waitpid(child_pid, &status, 0);
umount("/proc");
umount("/sys");
DEBUG("child exited status: %d", WEXITSTATUS(status));
exit(WEXITSTATUS(status));
}
return 0;
}