sandbox/init/init.c
Star f9dcf07ba4 first version
Supports macOS and Linux
2026-03-23 00:35:27 +08:00

342 lines
12 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// init.c v2.1
#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdint.h>
#include <string.h>
#include <stddef.h>
#include <fcntl.h>
#include <sys/wait.h>
#include <sys/prctl.h>
#include <sys/ioctl.h>
#include <sys/uio.h>
#include <sys/syscall.h>
#include <sys/mount.h>
#include <linux/seccomp.h>
#include <linux/filter.h>
#include <linux/audit.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <errno.h>
#include <signal.h>
#include <poll.h> // 新增
// Debug logging hook. The active definition below expands to an empty
// statement; the commented-out variant prints "[DEBUG] ..." lines to stdout.
// #define DEBUG(fmt, ...) printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)
#define DEBUG(fmt, ...) do {} while (0)
// One allow-list / block-list entry. Arrays of these are read byte-for-byte
// from stdin in main(), so the struct is packed (no padding, 20 bytes) and the
// field order is part of the wire format.
struct __attribute__((packed)) NetRule {
// Address bytes; is_ip_match() compares the first 4 for IPv4 rules, all 16 for IPv6.
uint8_t ip[16];
// Prefix length in bits; -1 means "compare the whole address".
int8_t mask;
// Port compared against the ntohs()-converted socket port; 0 matches any port.
uint16_t port;
// 1 for an IPv6 rule, 0 for an IPv4 rule (compared for equality in is_allowed()).
uint8_t is_v6;
};
// Nonzero when the launcher asked for network filtering; read from stdin in main().
uint8_t net_enabled = 0;
// Verdict returned by is_allowed() for destinations that match no rule and are not private.
uint8_t allow_internet = 0;
// Verdict returned by is_allowed() for destinations that match no rule and are private.
uint8_t allow_local = 0;
// Element counts of listen_ports, allow_rules and block_rules respectively.
uint32_t listen_cnt = 0, allow_cnt = 0, block_cnt = 0;
// Ports that bind() is permitted on; compared against the host-byte-order port.
uint32_t *listen_ports = NULL;
// Rule tables filled in main(); is_allowed() consults block_rules before allow_rules.
struct NetRule *allow_rules = NULL, *block_rules = NULL;
// PID returned by fork() in main(); stays -1 until then. Used by handle_sig().
pid_t child_pid = -1;
// Set by handle_sig(); run_monitor() leaves its loop once this is nonzero.
volatile sig_atomic_t stop_monitor = 0;
// Read exactly `len` bytes from `fd` into `buf`, looping over short reads.
//
// A read() interrupted by a signal (EINTR) is retried instead of being treated
// as a failure. End-of-file or any other read error terminates the process
// with exit status 101, so the caller never sees a partially filled buffer.
// A `len` of 0 returns immediately without touching `fd`.
void safe_read(int fd, void *buf, size_t len) {
    size_t offset = 0;
    while (offset < len) {
        ssize_t r = read(fd, (char *)buf + offset, len - offset);
        if (r < 0 && errno == EINTR) {
            continue; // interrupted before any data arrived: just try again
        }
        if (r <= 0) {
            exit(101); // EOF or hard error: the input stream is unusable
        }
        offset += (size_t)r;
    }
}
// Signal handler shared by SIGCHLD, SIGTERM and SIGINT.
//
// SIGCHLD only requests that the monitor loop stop. Any other signal is
// forwarded to the child's process group (negative PID) and then also stops
// the monitor; it is ignored while no child has been forked (child_pid <= 0).
//
// errno is saved and restored because kill() may overwrite it, and the
// interrupted code inspects errno (e.g. the EINTR checks in run_monitor()).
void handle_sig(int sig) {
    int saved_errno = errno;

    if (sig == SIGCHLD) {
        stop_monitor = 1;
    } else if (child_pid > 0) {
        kill(-child_pid, sig);
        stop_monitor = 1;
    }

    errno = saved_errno;
}
// Return 1 if `target_ip` falls inside the network described by `rule`, else 0.
//
// `target_ip` points at 4 address bytes for an IPv4 rule or 16 for an IPv6
// rule (selected by rule->is_v6). rule->mask is the prefix length in bits;
// -1 means "exact address". Prefix lengths that cannot be valid for the
// family (any negative value, or more bits than the address has) are treated
// as an exact-address match, so a malformed rule can never make memcmp() run
// with a negative or oversized length.
int is_ip_match(struct NetRule *rule, void *target_ip) {
    const uint8_t *target = target_ip;
    int width_bits = rule->is_v6 ? 128 : 32;

    int prefix = rule->mask;
    if (prefix < 0 || prefix > width_bits) {
        prefix = width_bits;
    }

    // Compare the whole bytes covered by the prefix first.
    int whole_bytes = prefix / 8;
    if (memcmp(rule->ip, target, (size_t)whole_bytes) != 0) {
        return 0;
    }

    // Then the leading bits of the one partially covered byte, if any.
    int rem_bits = prefix % 8;
    if (rem_bits == 0) {
        return 1;
    }
    uint8_t mask_byte = (uint8_t)((0xFF << (8 - rem_bits)) & 0xFF);
    return (rule->ip[whole_bytes] & mask_byte) == (target[whole_bytes] & mask_byte);
}
// Return 1 if the address is a loopback address, else 0.
//
// `ip` points at 4 bytes (is_v6 == 0) or 16 bytes (is_v6 != 0) in network
// order. Recognised forms: 127.0.0.0/8, ::1, and the IPv4-mapped IPv6 form
// ::ffff:127.x.x.x, which names the same IPv4 loopback destination when used
// on an AF_INET6 socket.
int is_loopback(void *ip, int is_v6) {
    static const uint8_t v6_loopback[16] = { [15] = 1 };
    static const uint8_t v4_mapped_prefix[12] = { [10] = 0xff, [11] = 0xff };
    const uint8_t *p = ip;

    if (!is_v6) {
        return p[0] == 127;
    }
    if (memcmp(p, v6_loopback, sizeof(v6_loopback)) == 0) {
        return 1;
    }
    return memcmp(p, v4_mapped_prefix, sizeof(v4_mapped_prefix)) == 0 && p[12] == 127;
}
// Return 1 if the address belongs to a local/private network, else 0.
//
// `ip` points at 4 bytes (is_v6 == 0) or 16 bytes (is_v6 != 0) in network order.
//
// IPv6: fc00::/7 (unique local) and fe80::/10 (link-local) are private. An
// IPv4-mapped address (::ffff:a.b.c.d) is classified by its embedded IPv4
// address, because that is the host such a destination actually reaches;
// without this a LAN host could be addressed as "internet" through an
// AF_INET6 socket.
//
// IPv4: the RFC 1918 ranges 10/8, 172.16/12 and 192.168/16, plus the
// link-local range 169.254/16 so that IPv4 is treated the same way as the
// IPv6 link-local range above.
int is_private_ip(void *ip, int is_v6) {
    static const uint8_t v4_mapped_prefix[12] = { [10] = 0xff, [11] = 0xff };
    const uint8_t *p = ip;

    if (is_v6) {
        if ((p[0] & 0xfe) == 0xfc) {
            return 1; // fc00::/7
        }
        if (p[0] == 0xfe && (p[1] & 0xc0) == 0x80) {
            return 1; // fe80::/10
        }
        if (memcmp(p, v4_mapped_prefix, sizeof(v4_mapped_prefix)) != 0) {
            return 0;
        }
        p += sizeof(v4_mapped_prefix); // continue with the embedded IPv4 address
    }

    if (p[0] == 10) {
        return 1;
    }
    if (p[0] == 172 && p[1] >= 16 && p[1] <= 31) {
        return 1;
    }
    if (p[0] == 192 && p[1] == 168) {
        return 1;
    }
    if (p[0] == 169 && p[1] == 254) {
        return 1;
    }
    return 0;
}
int is_allowed(struct sockaddr *addr, int nr) {
uint16_t port = 0; void *ip = NULL; int is_v6 = 0;
if (addr->sa_family == AF_INET) {
struct sockaddr_in *s4 = (struct sockaddr_in *)addr;
port = ntohs(s4->sin_port); ip = &s4->sin_addr;
} else if (addr->sa_family == AF_INET6) {
struct sockaddr_in6 *s6 = (struct sockaddr_in6 *)addr;
port = ntohs(s6->sin6_port); ip = &s6->sin6_addr; is_v6 = 1;
} else return 0;
// 0. 永远允许回环地址
if (is_loopback(ip, is_v6)) return 1;
// 1. 拦截 bind (监听端口)
if (nr == __NR_bind) {
for (uint32_t i = 0; i < listen_cnt; i++) {
if (listen_ports[i] == port) return 1;
}
return 0;
}
// 2. 检查黑名单 (BlockList 优先级最高)
for (uint32_t i = 0; i < block_cnt; i++) {
if (block_rules[i].is_v6 == is_v6 && (block_rules[i].port == 0 || block_rules[i].port == port)) {
if (is_ip_match(&block_rules[i], ip)) return 0;
}
}
if (nr == __NR_sendto || nr == __NR_sendmsg) {
if (port == 53) return 1; // 永远放行 DNS 查询 (端口 53)
}
// 3. 检查白名单 (AllowList)
for (uint32_t i = 0; i < allow_cnt; i++) {
if (allow_rules[i].is_v6 == is_v6 && (allow_rules[i].port == 0 || allow_rules[i].port == port)) {
if (is_ip_match(&allow_rules[i], ip)) return 1;
}
}
// 4. 基础开关判定 (完美映射 Go 的配置)
if (is_private_ip(ip, is_v6)) {
return allow_local; // 访问 192.168.x.x 取决于 AllowLocalNetwork
} else {
return allow_internet; // 访问 8.8.8.8 取决于 AllowInternet
}
}
// --- Monitor main loop ---

// Copy exactly `len` bytes from address `remote` in process `pid` into `dst`.
// Returns 1 only when the whole range was copied, 0 on error or short read.
static int monitor_read_remote(pid_t pid, uint64_t remote, void *dst, size_t len) {
    struct iovec local_iov = { .iov_base = dst, .iov_len = len };
    struct iovec remote_iov = { .iov_base = (void *)(uintptr_t)remote, .iov_len = len };
    ssize_t copied = process_vm_readv(pid, &local_iov, 1, &remote_iov, 1, 0);
    return copied >= 0 && (size_t)copied == len;
}

// Fetch the socket address the target passed to the kernel: `addrlen` bytes at
// `remote`, capped at sizeof(*out). `out` is zeroed first so bytes the target
// did not supply are never left over from an earlier request. Returns 1 on
// success, 0 if the pointer is NULL, the length cannot even hold the address
// family, or the memory cannot be read.
static int monitor_fetch_sockaddr(pid_t pid, uint64_t remote, uint32_t addrlen,
                                  struct sockaddr_storage *out) {
    memset(out, 0, sizeof(*out));
    if (remote == 0 || addrlen < sizeof(sa_family_t)) {
        return 0;
    }
    size_t want = addrlen < sizeof(*out) ? (size_t)addrlen : sizeof(*out);
    return monitor_read_remote(pid, remote, out, want);
}

// Serve seccomp user notifications arriving on `notif_fd` until stop_monitor
// is set (by handle_sig()) or the descriptor becomes unusable.
//
// For each intercepted connect/bind/sendto/sendmsg the destination address is
// copied out of the calling process and passed to is_allowed(). Allowed calls
// are resumed with SECCOMP_USER_NOTIF_FLAG_CONTINUE; denied calls, and calls
// whose address cannot be read, fail with EPERM. sendto/sendmsg without a
// destination address (already-connected sockets) are resumed unchecked.
//
// NOTE(review): the kernel documentation warns that FLAG_CONTINUE is subject
// to a time-of-check/time-of-use race, because the target can rewrite the
// pointed-to address after it has been inspected here — confirm this is
// acceptable for the threat model.
void run_monitor(int notif_fd) {
    struct seccomp_notif req;
    struct seccomp_notif_resp resp;
    struct pollfd pfd = { .fd = notif_fd, .events = POLLIN };

    while (!stop_monitor) {
        // Wait with a 100 ms timeout so stop_monitor is re-checked regularly
        // even when no signal happens to interrupt the wait.
        int ready = poll(&pfd, 1, 100);
        if (ready < 0) {
            if (errno == EINTR && !stop_monitor) continue;
            break;
        }
        if (ready == 0) continue; // timeout
        if (!(pfd.revents & POLLIN)) break; // hang-up/error: nothing left to read

        memset(&req, 0, sizeof(req));
        if (ioctl(notif_fd, SECCOMP_IOCTL_NOTIF_RECV, &req) == -1) {
            // ENOENT: the caller was killed or interrupted before we picked
            // the notification up. That affects only that one request.
            if (errno == EINTR || errno == ENOENT) continue;
            break;
        }

        memset(&resp, 0, sizeof(resp));
        resp.id = req.id;

        struct sockaddr_storage addr;
        pid_t target = (pid_t)req.pid;
        int nr = req.data.nr;
        // 0: address unreadable (deny), 1: address read (check it),
        // -1: no address to check (allow).
        int check_addr = 0;

        if (nr == __NR_connect || nr == __NR_bind) {
            // (fd, addr, addrlen)
            if (monitor_fetch_sockaddr(target, req.data.args[1],
                                       (uint32_t)req.data.args[2], &addr)) {
                check_addr = 1;
            }
        } else if (nr == __NR_sendto) {
            // (fd, buf, len, flags, dest_addr, addrlen)
            if (req.data.args[4] == 0) {
                check_addr = -1; // the BPF filter should already skip these; defensive
            } else if (monitor_fetch_sockaddr(target, req.data.args[4],
                                              (uint32_t)req.data.args[5], &addr)) {
                check_addr = 1;
            }
        } else if (nr == __NR_sendmsg) {
            // (fd, struct msghdr *, flags): first read the msghdr, then follow
            // its msg_name pointer to the address.
            struct msghdr msg;
            if (monitor_read_remote(target, req.data.args[1], &msg, sizeof(msg))) {
                if (msg.msg_name == NULL) {
                    check_addr = -1; // connected socket
                } else if (monitor_fetch_sockaddr(target, (uint64_t)(uintptr_t)msg.msg_name,
                                                  (uint32_t)msg.msg_namelen, &addr)) {
                    check_addr = 1;
                }
            }
        }

        // The memory was read by PID. Make sure the notification is still
        // live, i.e. that PID still refers to the task that made the call,
        // before trusting what was read.
        if (check_addr == 1 &&
            ioctl(notif_fd, SECCOMP_IOCTL_NOTIF_ID_VALID, &req.id) == -1) {
            continue;
        }

        if (check_addr == -1 ||
            (check_addr == 1 && is_allowed((struct sockaddr *)&addr, nr))) {
            resp.flags = SECCOMP_USER_NOTIF_FLAG_CONTINUE;
        } else {
            resp.error = -EPERM; // denied, or the address could not be read
        }
        ioctl(notif_fd, SECCOMP_IOCTL_NOTIF_SEND, &resp);
    }
}
int main() {
setvbuf(stdout, NULL, _IONBF, 0);
setvbuf(stderr, NULL, _IONBF, 0);
uint32_t uid, gid, arg_count;
safe_read(0, &uid, 4);
safe_read(0, &gid, 4);
uint32_t wd_len;
safe_read(0, &wd_len, 4);
char *work_dir = malloc(wd_len + 1);
safe_read(0, work_dir, wd_len);
work_dir[wd_len] = '\0';
safe_read(0, &arg_count, 4);
char **argv = malloc(sizeof(char *) * (arg_count + 1));
for (uint32_t i = 0; i < arg_count; i++) {
uint32_t s_len;
safe_read(0, &s_len, 4);
argv[i] = malloc(s_len + 1);
safe_read(0, argv[i], s_len);
argv[i][s_len] = '\0';
}
argv[arg_count] = NULL;
safe_read(0, &net_enabled, 1);
if (net_enabled) {
safe_read(0, &allow_internet, 1);
safe_read(0, &allow_local, 1);
safe_read(0, &listen_cnt, 4);
if (listen_cnt > 0) {
listen_ports = malloc(listen_cnt * 4);
safe_read(0, listen_ports, listen_cnt * 4);
}
safe_read(0, &allow_cnt, 4);
if (allow_cnt > 0) {
allow_rules = malloc(allow_cnt * sizeof(struct NetRule));
safe_read(0, allow_rules, allow_cnt * sizeof(struct NetRule));
}
safe_read(0, &block_cnt, 4);
if (block_cnt > 0) {
block_rules = malloc(block_cnt * sizeof(struct NetRule));
safe_read(0, block_rules, block_cnt * sizeof(struct NetRule));
}
}
int dev_null = open("/dev/null", O_RDONLY);
if (dev_null >= 0) { dup2(dev_null, 0); close(dev_null); }
int notif_fd = -1;
if (net_enabled) {
struct sock_filter filter[] = {
BPF_STMT(BPF_LD | BPF_W | BPF_ABS, offsetof(struct seccomp_data, nr)),
// 1. 拦截 connect
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, __NR_connect, 0, 1),
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_USER_NOTIF),
// 2. 拦截 bind
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, __NR_bind, 0, 1),
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_USER_NOTIF),
// 3. 拦截 sendto (精细化过滤:只拦截 dest_addr 不为空的调用)
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, __NR_sendto, 0, 5),
// 加载 args[4] (dest_addr) 的低 32 位
BPF_STMT(BPF_LD | BPF_W | BPF_ABS, offsetof(struct seccomp_data, args[4])),
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, 0, 0, 2), // 不为 0 说明有地址,跳去 Notify
// 加载 args[4] 的高 32 位
BPF_STMT(BPF_LD | BPF_W | BPF_ABS, offsetof(struct seccomp_data, args[4]) + 4),
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, 0, 2, 0), // 也为 0 说明是 NULL (已连接的 TCP/UDP),跳过 Notify
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_USER_NOTIF), // Notify
// 4. 拦截 sendmsg (结构体较复杂,全部交给用户态去读内存判断)
// 需要先恢复 nr 到累加器
BPF_STMT(BPF_LD | BPF_W | BPF_ABS, offsetof(struct seccomp_data, nr)),
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, __NR_sendmsg, 0, 1),
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_USER_NOTIF),
// 兜底放行
BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
};
struct sock_fprog prog = { .len = (unsigned short)(sizeof(filter)/sizeof(filter[0])), .filter = filter };
notif_fd = syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER, SECCOMP_FILTER_FLAG_NEW_LISTENER, &prog);
}
// 显式不使用 SA_RESTART确保系统调用会被信号中断
struct sigaction sa;
memset(&sa, 0, sizeof(sa));
sa.sa_handler = handle_sig;
sigaction(SIGCHLD, &sa, NULL);
sigaction(SIGTERM, &sa, NULL);
sigaction(SIGINT, &sa, NULL);
child_pid = fork();
if (child_pid == 0) {
setpgid(0, 0);
prctl(PR_SET_PDEATHSIG, SIGTERM);
mount("proc", "/proc", "proc", 0, NULL);
mount("sysfs", "/sys", "sysfs", 0, NULL);
chdir(work_dir);
int fd_out = open("stdout.log", O_WRONLY | O_CREAT | O_TRUNC | O_SYNC, 0644);
int fd_err = open("stderr.log", O_WRONLY | O_CREAT | O_TRUNC | O_SYNC, 0644);
if (fd_out >= 0) dup2(fd_out, 1);
if (fd_err >= 0) dup2(fd_err, 2);
if (uid != 0) { setresgid(gid, gid, gid); setresuid(uid, uid, uid); }
execv(argv[0], argv);
exit(103);
} else {
if (net_enabled && notif_fd >= 0) {
run_monitor(notif_fd);
DEBUG("monitor loop exited");
}
int status;
waitpid(child_pid, &status, 0);
umount("/proc");
umount("/sys");
DEBUG("child exited status: %d", WEXITSTATUS(status));
exit(WEXITSTATUS(status));
}
return 0;
}