 /*
  * vdev_bind_input_devices.c
  * Copyright (C) 2025  Aitor C.Z. <aitor_czr@gnuinos.org>
  *
  * This program is free software: you can redistribute it and/or modify it
  * under the terms of the GNU General Public License as published by the
  * Free Software Foundation, either version 3 of the License, or
  * (at your option) any later version.
  *
  * This program is distributed in the hope that it will be useful, but
  * WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  * See the GNU General Public License for more details.
  *
  * You should have received a copy of the GNU General Public License along
  * with this program.  If not, see <http://www.gnu.org/licenses/>.
  *
  * See the COPYING file.
  */

#include "libvdev/sglib.h"
#include "libvdev/util.h"
#include "libvdev/misc.h"
#include "libvdev/sbuf.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <stdbool.h>
#include <limits.h>
#include <unistd.h>
#include <mntent.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>

// Program name, used as the "%s:" prefix of every diagnostic in this file.
const char *progname = "vdev_bind_input_devices";

// Mount table consulted to locate the sysfs mount point.
const char *mtab = "/etc/mtab";
// NOTE(review): appears unused in this file -- confirm before removing.
const char *mountpoint = "/dev";
// Resolved sysfs mount point (e.g. "/sys"); set once at startup (see main()).
static char sysfs_mountpoint[PATH_MAX + 1];

// sglib names the generated vector type and functions after the element
// type (sglib_cstr_vector_*), so 'char *' needs a single-token alias.
typedef char *cstr;
SGLIB_DEFINE_VECTOR_PROTOTYPES(cstr);
SGLIB_DEFINE_VECTOR_FUNCTIONS(cstr);

// State shared with the scan_device_directory() callback while a sysfs
// directory is being scanned.
struct sysfs_scan_context {
    // Heap-allocated path of a "uevent" entry found in the directory being
    // scanned, or NULL; the directory containing it is treated as a device.
    char *uevent_path;
    // Heap-allocated subdirectory paths still to be scanned (used as a stack).
    struct sglib_cstr_vector *frontier;
};

// Release every string stored in 'vec' and reset each slot to NULL.
// NULL slots are skipped.  The vector's own storage is NOT released; callers
// follow this with sglib_cstr_vector_free().
// Always returns 0.
static int vdev_cstr_vector_free_all(struct sglib_cstr_vector *vec)
{
    int count = sglib_cstr_vector_size(vec);

    for (int idx = 0; idx < count; ++idx) {
        char *entry = sglib_cstr_vector_at(vec, idx);

        if (entry == NULL)
            continue;

        free(entry);
        sglib_cstr_vector_set(vec, NULL, idx);
    }

    return 0;
}

// 'buf' must be freed after usage
static void get_sysfs_path(char **buf)
{
    struct mntent *e;
    FILE *fstab = NULL;

    fstab = setmntent(mtab, "r");
    if (!fstab) {
        vdev_error("%s: setmntent(): error trying to open /etc/mtab: '%s'\n",
                   progname, strerror (errno));
        exit(EXIT_FAILURE);
    }

    *buf = (char*)malloc(sizeof(char) * 32);
    if (!*buf) {
        vdev_error("%s: Memory allocation failure: '%s'\n",
                   progname, strerror (errno));
        exit(EXIT_FAILURE);
    }

    *buf[0] = '\0';

    while ((e = getmntent(fstab))) {
        if (!strcmp(e->mnt_type, "sysfs")) {
            sprintf(*buf, "%s", e->mnt_dir);
            break;
        }
    }

    endmntent(fstab);
}

// Build the sysfs path of 'attr_path' beneath device 'devpath', by joining
// 'sysfs_mountpoint' and 'devpath' with vdev_fullpath() and then joining
// 'attr_path' onto that result.
// Returns a newly allocated string the caller must free, or NULL on OOM.
static char *vdev_linux_sysfs_fullpath(char const *sysfs_mountpoint,
                                       char const *devpath, char const *attr_path)
{
    char *device_dir = vdev_fullpath(sysfs_mountpoint, devpath, NULL);
    char *result;

    if (device_dir == NULL)
        return NULL;

    result = vdev_fullpath(device_dir, attr_path, NULL);
    free(device_dir);

    return result;
}

// Read the kernel-assigned subsystem of a device from sysfs.
//
// Resolves the "subsystem" symlink under 'devpath' (relative to
// 'mountpoint') and stores the basename of its target (e.g. "input") in
// *subsystem as a newly allocated string the caller must free.
//
// Returns 0 on success.  Also returns 0 WITHOUT touching *subsystem when the
// device has no "subsystem" entry, so callers should initialise it to NULL.
// Returns -ENOMEM on allocation failure, or -errno if readlink() fails.
static int vdev_linux_sysfs_read_subsystem(const char *mountpoint,
                                           char const *devpath, char **subsystem)
{
    char linkpath[PATH_MAX + 1];
    char *subsystem_path = NULL;
    ssize_t len = 0;

    subsystem_path = vdev_linux_sysfs_fullpath(mountpoint, devpath, "subsystem");
    if (subsystem_path == NULL)
        return -ENOMEM;

    if (access(subsystem_path, F_OK) != 0) {
        /* no subsystem link: not an error */
        free(subsystem_path);
        return 0;
    }

    // readlink() does not NUL-terminate; leave room for the terminator.
    len = readlink(subsystem_path, linkpath, sizeof(linkpath) - 1);
    if (len < 0) {
        // Capture errno before logging/free() get a chance to overwrite it.
        int saved_errno = errno;

        vdev_error("%s: readlink('%s') %s\n",
                   progname, subsystem_path, strerror(saved_errno));
        free(subsystem_path);
        return -saved_errno;
    }
    linkpath[len] = '\0';

    free(subsystem_path);

    *subsystem = vdev_basename(linkpath, NULL);
    if (*subsystem == NULL)
        return -ENOMEM;

    return 0;
}

// vdev_load_all() callback: examine one entry 'fp' (a full path) of the
// sysfs directory being scanned.  'cls' must point to a
// struct sysfs_scan_context.
//
// - "." and "..", and paths containing no '/', are ignored.
// - Directories are duplicated and pushed onto scan_ctx->frontier so the
//   caller can scan them later.  lstat() is used, so symlinks (such as
//   "subsystem" or "device") are not followed.
// - A non-directory entry named "uevent" marks the scanned directory as a
//   device: its path is duplicated into scan_ctx->uevent_path (any value
//   left there is freed first).
// - Everything else is ignored.
//
// Returns 0 on success, -ENOMEM on allocation failure, -errno if lstat()
// fails, or the sglib push_back error code.
static int scan_device_directory(char const *fp, void *cls)
{
    struct sysfs_scan_context *scan_ctx = (struct sysfs_scan_context *)cls;
    struct sglib_cstr_vector *frontier = scan_ctx->frontier;
    const char *slash = NULL;
    const char *fp_base = NULL;
    char *fp_dup = NULL;
    struct stat sb;
    int rc = 0;

    // Test strrchr() before stepping past the '/': NULL + 1 is undefined
    // behaviour and can never compare equal to NULL.
    slash = strrchr(fp, '/');
    if (slash == NULL)
        return 0;
    fp_base = slash + 1;

    // skip . and ..
    if (strcmp(fp_base, ".") == 0 || strcmp(fp_base, "..") == 0)
        return 0;

    if (lstat(fp, &sb) != 0) {
        rc = -errno;    // save before vdev_error() can clobber errno
        vdev_error("%s: lstat('%s'): '%s'\n", progname, fp, strerror(-rc));
        return rc;
    }

    if (!S_ISDIR(sb.st_mode) && strcmp(fp_base, "uevent") != 0)
        return 0;

    fp_dup = vdev_strdup_or_null(fp);
    if (fp_dup == NULL)
        return -ENOMEM;

    if (S_ISDIR(sb.st_mode)) {
        rc = sglib_cstr_vector_push_back(frontier, fp_dup);
        if (rc != 0) {
            vdev_error("%s: sglib_cstr_vector_push_back('%s'): '%s'\n",
                       progname, fp_dup, strerror(errno));
            free(fp_dup);
            return rc;
        }
    } else {
        /* this is a uevent; this directory is a device */
        // Never drop an earlier, unconsumed path on the floor.
        free(scan_ctx->uevent_path);
        scan_ctx->uevent_path = fp_dup;
    }

    return 0;
}

/**
 * \brief Variadic function
 */
static int find_devices_at_frontier(char *sysfs_mountpoint,
                                    const char *device_frontier,
                                    struct sglib_cstr_vector *uevent_paths)
{
    int rc = 0;
    struct sglib_cstr_vector frontier;
    struct sysfs_scan_context scan_ctx;

    sglib_cstr_vector_init (&frontier);

    memset(&scan_ctx, 0, sizeof(struct sysfs_scan_context));

    scan_ctx.frontier = &frontier;

    rc = vdev_load_all(device_frontier, scan_device_directory, &scan_ctx);
    if (rc != 0) {
        vdev_error("%s: vdev_load_all('%s'): '%s'\n",
                   progname, device_frontier, strerror(errno));
        vdev_cstr_vector_free_all(&frontier);
        sglib_cstr_vector_free(&frontier);
        return rc;
    }

    while (1) {
        int len = sglib_cstr_vector_size (&frontier);
        if (len == 0)
            break;

        char *dir = sglib_cstr_vector_at (&frontier, len - 1);
        sglib_cstr_vector_set(&frontier, NULL, len - 1);

        sglib_cstr_vector_pop_back(&frontier);

        // scan for more devices
        rc = vdev_load_all(dir, scan_device_directory, &scan_ctx);
        if (rc != 0) {
            vdev_error("%s: vdev_load_all('%s'): '%s'\n",
                       progname, dir, strerror(errno));
            free (dir);
            break;
        }

        // is one of them a uevent?
        if (scan_ctx.uevent_path) {
            const char *str;
            bool is_ok = false;

            char *uevent_path = vdev_strdup_or_null(scan_ctx.uevent_path +
                                                    strlen(sysfs_mountpoint));
            if (!uevent_path) {
                free (dir);
                free (scan_ctx.uevent_path);
                scan_ctx.uevent_path = NULL;
                rc = -ENOMEM;
                break;
            }

            rc = sglib_cstr_vector_push_back(uevent_paths, uevent_path);
            if (rc < 0) {
                vdev_error("%s: sglib_cstr_vector_push_back('%s'): '%s'\n",
                           progname, uevent_path, strerror(errno));
                free(uevent_path);
                free(dir);
                free(scan_ctx.uevent_path);
                scan_ctx.uevent_path = NULL;
                break;
            }

            free (scan_ctx.uevent_path);
            scan_ctx.uevent_path = NULL;
        }

        free (dir);
    }

    vdev_cstr_vector_free_all(&frontier);
    sglib_cstr_vector_free(&frontier);

    return rc;
}

// Enumerate devices under <sysfs_mountpoint>/devices.
//
// Every top-level subdirectory of <sysfs_mountpoint>/devices is walked with
// find_devices_at_frontier(), which appends the sysfs-relative path of each
// "uevent" file it finds to 'uevent_paths' (strings owned by the vector).
//
// Returns 0 on success, -ENOMEM if the root path cannot be built, or the
// vdev_load_all() error for the root directory.  Failures below individual
// top-level directories are not propagated (best-effort).
static int find_devices(char *sysfs_mountpoint,
                        struct sglib_cstr_vector *uevent_paths)
{
    int rc = 0;
    char *devroot = NULL;
    struct sglib_cstr_vector frontier;
    struct sysfs_scan_context scan_ctx;

    devroot = vdev_fullpath(sysfs_mountpoint, "/devices", NULL);
    if (devroot == NULL)
        return -ENOMEM;

    memset(&scan_ctx, 0, sizeof(scan_ctx));
    sglib_cstr_vector_init(&frontier);
    scan_ctx.frontier = &frontier;

    rc = vdev_load_all(devroot, scan_device_directory, &scan_ctx);
    if (rc != 0) {
        vdev_error("%s: vdev_load_all('%s'): '%s'\n",
                    progname, devroot, strerror(errno));
        goto out;
    }

    // A uevent directly under <sysfs>/devices is not treated as a device.
    free(scan_ctx.uevent_path);
    scan_ctx.uevent_path = NULL;

    for (int i = 0; i < sglib_cstr_vector_size(&frontier); i++) {
        // Best-effort: a failure below one top-level directory must not stop
        // the others from being scanned.
        (void)find_devices_at_frontier(sysfs_mountpoint,
                                       sglib_cstr_vector_at(&frontier, i),
                                       uevent_paths);
    }

out:
    // May still be set if the root scan failed part-way through.
    free(scan_ctx.uevent_path);
    vdev_cstr_vector_free_all(&frontier);
    sglib_cstr_vector_free(&frontier);
    free(devroot);

    return rc;
}

int main(int argc, char **argv)
{
    int rc;
    sbuf_t s;
    char *buf = NULL;
    struct sglib_cstr_vector input_devices;
    struct sglib_cstr_vector aux;

    if (access(mtab, F_OK) != 0) {
        snprintf (sysfs_mountpoint, PATH_MAX, "/sys");
    } else {
        get_sysfs_path (&buf);
        if (buf && buf[0] != '\0') {
            snprintf (sysfs_mountpoint, PATH_MAX, buf);
            free (buf);
        } else {
            vdev_error("%s: Cannot get sysfs path\n", progname);
            exit(1);
        }
    }

    sglib_cstr_vector_init(&input_devices);
    rc = find_devices(sysfs_mountpoint, &input_devices);

    sglib_cstr_vector_init(&aux);

    for (int i = 0; i < sglib_cstr_vector_size(&input_devices); i++) {
        char *subsystem = NULL;
        char *full_devpath = NULL;
        char path[1024] = {0};

        // extract the devpath from the uevent path
        full_devpath = vdev_dirname(sglib_cstr_vector_at(&input_devices, i), NULL);
        if (full_devpath == NULL)
            continue;

        rc = vdev_linux_sysfs_read_subsystem(sysfs_mountpoint, full_devpath, &subsystem);
        if (rc == 0) {
            if (subsystem && !strcmp(subsystem, "input")) {
                sbuf_t buf;
                sbuf_t item;
                sbuf_t driver;
                char name[PATH_MAX + 1];
                char linkpath[PATH_MAX + 1] = {0};
                size_t linkpath_len = PATH_MAX;
                char *path = NULL;
                char *address = NULL;
                sbuf_t file;
                FILE *f = NULL;

                size_t num = strlen(sglib_cstr_vector_at(&input_devices, i)) -
                             strlen("/uevent");
                sglib_cstr_vector_at(&input_devices, i)[num] =  '\0';

                sbuf_init(&file);
                sbuf_concat(&file, 3, sysfs_mountpoint,
                            sglib_cstr_vector_at(&input_devices, i), "/name");
                if (access(file.buf, F_OK) != 0) {
                    sbuf_free(&file);
                    free(full_devpath);
                    continue;
                }

                f = fopen(file.buf, "r");
                if (f == NULL) {
                    vdev_error("%s: fopen('%s') %s\n",
                               progname, file.buf, strerror(errno));
                    sbuf_free(&file);
                    vdev_cstr_vector_free_all(&aux);
                    sglib_cstr_vector_free(&aux);
                    vdev_cstr_vector_free_all(&input_devices);
                    sglib_cstr_vector_free(&input_devices);
                    return -ENOSYS;
                }

                fgets(name, PATH_MAX + 1, f);
                name[strcspn(name, "\n")] = '\0';
                sbuf_free(&file);
                fclose(f);

                sbuf_init(&item);
                sbuf_concat(&item, 3, sysfs_mountpoint,
                            sglib_cstr_vector_at(&input_devices, i), "/device");
                if (access(item.buf, F_OK) != 0) {
                    sbuf_free(&item);
                    free(full_devpath);
                    continue;
                }

                rc = readlink(item.buf, linkpath, linkpath_len);
                if (rc < 0) {
                    vdev_error("%s: readlink('%s') %s\n",
                               progname, linkpath, strerror(errno));
                    sbuf_free(&item);
                    vdev_cstr_vector_free_all(&aux);
                    sglib_cstr_vector_free(&aux);
                    vdev_cstr_vector_free_all(&input_devices);
                    sglib_cstr_vector_free(&input_devices);
                    return rc;
                }
                sbuf_free(&item);

                sbuf_init(&buf);
                sbuf_concat(&buf, 4, sysfs_mountpoint,
                            sglib_cstr_vector_at(&input_devices, i), "/", linkpath);
                path = realpath(buf.buf, NULL);
                sbuf_free(&buf);

                address = vdev_dirname(path, NULL);
                if (address == NULL) {
                    free(path);
                    vdev_cstr_vector_free_all(&aux);
                    sglib_cstr_vector_free(&aux);
                    vdev_cstr_vector_free_all(&input_devices);
                    sglib_cstr_vector_free(&input_devices);
                    exit(EXIT_FAILURE);
                }
                free(path);

                sbuf_init(&driver);
                sbuf_concat(&driver, 2, address, "/driver");
                if (access(driver.buf, F_OK) == 0) {
                    sbuf_t cmd;
                    char delim[2] = "/";
                    char *token = NULL;
                    char *copy = NULL;
                    char buspath[PATH_MAX + 1] = {0};
                    size_t buspath_len = PATH_MAX;
                    rc = readlink(driver.buf, buspath, buspath_len);
                    if (rc < 0) {
                        vdev_error("%s: readlink('%s') %s\n",
                                   progname, buspath, strerror(errno));
                        free(address);
                        sbuf_free(&driver);
                        vdev_cstr_vector_free_all(&aux);
                        sglib_cstr_vector_free(&aux);
                        vdev_cstr_vector_free_all(&input_devices);
                        sglib_cstr_vector_free(&input_devices);
                        return rc;
                    }
                    sbuf_init(&cmd);
                    sbuf_concat(&cmd, 3, "/bin/echo \"",
                                strrchr(address, '/')+1, "\" | sudo tee /sys");
                    token = strtok(buspath, delim);
                    while (token != NULL) {
                        if (strcmp(token, "..")) {
                            sbuf_addstr(&cmd, "/");
                            sbuf_addstr(&cmd, token);
                        }
                        token = strtok(NULL, delim);
                    }
                    copy = vdev_strdup_or_null(cmd.buf);
                    if (!copy) {
                        free(address);
                        sbuf_free(&cmd);
                        sbuf_free(&driver);
                        vdev_cstr_vector_free_all(&aux);
                        sglib_cstr_vector_free(&aux);
                        vdev_cstr_vector_free_all(&input_devices);
                        sglib_cstr_vector_free(&input_devices);
                        return -ENOMEM;
                    }
                    rc = sglib_cstr_vector_push_back(&aux, copy);
                    if (rc < 0) {
                        vdev_error("%s: sglib_cstr_vector_push_back('%s'): '%s'\n",
                                   progname, copy, strerror(errno));
                        free(copy);
                        free(address);
                        sbuf_free(&cmd);
                        sbuf_free(&driver);
                        vdev_cstr_vector_free_all(&aux);
                        sglib_cstr_vector_free(&aux);
                        vdev_cstr_vector_free_all(&input_devices);
                        sglib_cstr_vector_free(&input_devices);
                        return rc;
                    }
                    sbuf_free(&cmd);
                }
                free(address);
                sbuf_free(&driver);
            }
            if (subsystem)
                free(subsystem);
        }
        free(full_devpath);
    }

    vdev_cstr_vector_free_all(&input_devices);
    sglib_cstr_vector_free(&input_devices);

    for (int i = 0; i < sglib_cstr_vector_size(&aux); i++) {
        FILE *pfin = NULL;
        pid_t pid;
        int wstatus;
        sbuf_t unbind_cmd;
        sbuf_init(&unbind_cmd);
        sbuf_concat(&unbind_cmd, 2,
                    sglib_cstr_vector_at(&aux, i), "/unbind >/dev/null");
        pfin = epopen(unbind_cmd.buf, &pid);
        if (pfin) {
            fclose(pfin);
            waitpid(pid, &wstatus, 0);
        }
        sbuf_free(&unbind_cmd);
    }

    for (int i = 0; i < sglib_cstr_vector_size(&aux); i++) {
        FILE *pfin = NULL;
        pid_t pid;
        int wstatus;
        sbuf_t bind_cmd;
        sbuf_init(&bind_cmd);
        sbuf_concat(&bind_cmd, 2, sglib_cstr_vector_at(&aux, i), "/bind >/dev/null");
        pfin = epopen(bind_cmd.buf, &pid);
        if (pfin) {
            fclose(pfin);
            waitpid(pid, &wstatus, 0);
        }
        sbuf_free(&bind_cmd);
    }

    vdev_cstr_vector_free_all(&aux);
    sglib_cstr_vector_free(&aux);

    return rc;
}
