//
// Syd: rock-solid application kernel
// src/kernel/net/connect.rs: connect(2) handler
//
// Copyright (c) 2023, 2024, 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use std::{
    net::IpAddr,
    os::fd::{AsFd, AsRawFd, OwnedFd},
};

use ipnet::IpNet;
use libseccomp::ScmpNotifResp;
use nix::{
    errno::Errno,
    sys::socket::{getsockname, AddressFamily, SockaddrLike, SockaddrStorage},
};

use crate::{
    cookie::safe_connect,
    fs::{get_nonblock, has_recv_timeout},
    info,
    req::UNotifyEventRequest,
    sandbox::{Action, AddressPattern, Capability, CidrRule},
};

/// Perform connect(2) on behalf of the sandboxed process.
///
/// `addr` is a pair `(target, argaddr)`:
/// - `target` is the address the socket is actually connected to.
/// - `argaddr` is consulted only for UNIX sockets, to record a path-based peer.
///
/// `argaddr` is presumably the address as passed by the sandboxed process; confirm with the caller.
///
/// If the socket is in blocking mode, the call is recorded in the request cache for the
/// duration of connect(2) so that it can get invalidated.
///
/// After a successful connect(2), some best-effort bookkeeping runs:
/// - `allow_safe_bind` handling for AF_INET/AF_INET6 sockets.
/// - UNIX socket tracking for AF_UNIX sockets.
///
/// Errors from this bookkeeping are ignored because the connection is already established.
///
/// # Errors
///
/// Returns the errno of a failed connect(2). It may also return the errno of a failure
/// while querying the socket flags or while adding/removing the blocking-call record.
#[expect(clippy::cognitive_complexity)]
pub(crate) fn handle_connect(
    fd: OwnedFd,
    addr: (SockaddrStorage, SockaddrStorage),
    request: &UNotifyEventRequest,
    allow_safe_bind: bool,
) -> Result<ScmpNotifResp, Errno> {
    let (target, argaddr) = addr;
    let req = request.scmpreq;

    // SAFETY: Record blocking call so it can get invalidated.
    let is_blocking = !get_nonblock(&fd)?;
    if is_blocking {
        // Whether the socket has a receive timeout is passed on as `ignore_restart`.
        let ignore_restart = has_recv_timeout(&fd)?;
        request.cache.add_sys_block(req, ignore_restart)?;
    }

    // Hand safe_connect the family-specific sockaddr type when one matches,
    // falling back to the generic storage otherwise.
    let connected = if let Some(sin) = target.as_sockaddr_in() {
        safe_connect(&fd, sin)
    } else if let Some(sin6) = target.as_sockaddr_in6() {
        safe_connect(&fd, sin6)
    } else if let Some(alg) = target.as_alg_addr() {
        safe_connect(&fd, alg)
    } else if let Some(link) = target.as_link_addr() {
        safe_connect(&fd, link)
    } else if let Some(netlink) = target.as_netlink_addr() {
        safe_connect(&fd, netlink)
    } else if let Some(vsock) = target.as_vsock_addr() {
        safe_connect(&fd, vsock)
    } else if let Some(unix) = target.as_unix_addr() {
        safe_connect(&fd, unix)
    } else {
        safe_connect(&fd, &target)
    };
    let result = connected.map(|_| request.return_syscall(0));

    // Remove the invalidation record unless the call was interrupted (EINTR).
    if is_blocking {
        let interrupted = matches!(result, Err(Errno::EINTR));
        request.cache.del_sys_block(req.id, interrupted)?;
    }

    if result.is_ok() {
        let family = target.family();
        if allow_safe_bind
            && matches!(family, Some(AddressFamily::Inet | AddressFamily::Inet6))
        {
            // Handle allow_safe_bind.
            // Best-effort: connect(2) has already succeeded.
            let _ = handle_safe_bind(request, &fd);
        } else if family == Some(AddressFamily::Unix) {
            // SO_PASSCRED inode tracking and getpeername(2) support.
            // Only peers with a filesystem path are recorded.
            // Best-effort: connect(2) has already succeeded.
            let peer = argaddr.as_unix_addr().filter(|unix| unix.path().is_some());
            let _ = request.add_unix(&fd, req.pid(), None, peer);
        }
    }

    result
}

/// Handle allow_safe_bind for connect: allow connects to the socket's local address.
///
/// The local address comes from getsockname(2) and is turned into a single-host network:
/// - `/32` for IPv4.
/// - `/128` for IPv6, with IPv4-mapped IPv6 addresses normalized to IPv4 `/32`.
///
/// The network is restricted to the single bound port. An `allow/net/connect` rule for it is
/// placed at the front of the sandbox CIDR rules.
///
/// Sockets that are neither AF_INET nor AF_INET6, or whose local port is zero, are left alone.
///
/// # Errors
///
/// Returns the errno from getsockname(2) or from inserting the rule.
fn handle_safe_bind<Fd: AsFd>(request: &UNotifyEventRequest, fd: Fd) -> Result<(), Errno> {
    let local = getsockname::<SockaddrStorage>(fd.as_fd().as_raw_fd())?;

    // Pull the IP address and port out of the local address.
    let (ip, port) = if let Some(sin) = local.as_sockaddr_in() {
        (IpAddr::V4(sin.ip()), sin.port())
    } else if let Some(sin6) = local.as_sockaddr_in6() {
        let ip6 = sin6.ip();
        // Treat IPv4-mapped IPv6 addresses as plain IPv4.
        let ip = ip6.to_ipv4_mapped().map_or(IpAddr::V6(ip6), IpAddr::V4);
        (ip, sin6.port())
    } else {
        return Ok(());
    };

    // Nothing to allow without a concrete port.
    if port == 0 {
        return Ok(());
    }

    // Allow implicit bind with safe_bind: a host-only network on the bound port.
    let prefix_len = match ip {
        IpAddr::V4(_) => 32,
        IpAddr::V6(_) => 128,
    };
    let addr = AddressPattern {
        addr: IpNet::new_assert(ip, prefix_len),
        port: Some(port..=port),
    };
    info!("ctx": "connect", "op": "allow_safe_bind",
        "sys": "connect", "pid": request.scmpreq.pid().as_raw(), "rule": &addr,
        "msg": format!("add rule `allow/net/connect+{addr}' after connect"));

    let rule = CidrRule {
        act: Action::Allow,
        cap: Capability::CAP_NET_CONNECT,
        pat: addr,
    };

    // Drop an identical existing rule before re-inserting at the front.
    // This keeps repeated connects from the same address from growing the list.
    let mut sandbox = request.get_mut_sandbox();
    let existing = sandbox.cidr_rules.iter().position(|r| r == &rule);
    if let Some(idx) = existing {
        sandbox.cidr_rules.remove(idx);
    }
    sandbox.cidr_rules.push_front(rule)?;

    Ok(())
}
