#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import struct
import sys
import threading
import time

import psutil
import sensors
import serial

# Set to True by main() once loop() exits; read_serial() polls it to know
# when to stop reading from the port.
should_stop = False
# When True, main() starts the read_serial() echo thread and loop() prints a
# hex dump of every frame it sends.
debug = False

# Field layout of every frame sent by loop(), in wire order. Each entry has:
#   'pack' - struct format code for the field; loop() packs the whole frame
#            with '<' (little-endian, standard sizes, no alignment padding)
#   'mode' - where the value comes from:
#       'magic' -> the constant stored under 'value'
#       'hwmon' -> int() of libsensors feature 'feature' on the chip whose
#                  prefix is 'driver'; 0 if it is missing or unreadable
#       other   -> the module-level function with that name, called once
#                  per frame
# NOTE(review): order and pack codes presumably must match the receiver's
# parser — confirm against the firmware before reordering or resizing.
structure = [
    # Constant 0xAAAAAAAA leading every frame (presumably a sync marker).
    {'pack': 'I', 'mode': 'magic', 'value': 0xAAAAAAAA},

    # Temps: one unsigned byte each (libsensors value truncated by int());
    # units are whatever the driver reports — presumably degrees C.
    {'pack': 'B', 'mode': 'hwmon', 'driver': 'k10temp', 'feature': 'temp1'},
    {'pack': 'B', 'mode': 'hwmon', 'driver': 'nct6797', 'feature': 'temp3'},
    {'pack': 'B', 'mode': 'hwmon', 'driver': 'amdgpu', 'feature': 'temp1'},
    {'pack': 'B', 'mode': 'hwmon', 'driver': 'nct6797', 'feature': 'temp5'},
    {'pack': 'B', 'mode': 'hwmon', 'driver': 'nct6797', 'feature': 'temp1'},
    {'pack': 'B', 'mode': 'hwmon', 'driver': 'nvme', 'feature': 'temp1'},

    # Fans: unsigned 16-bit each (presumably RPM as reported by libsensors).
    {'pack': 'H', 'mode': 'hwmon', 'driver': 'nct6797', 'feature': 'fan1'},
    {'pack': 'H', 'mode': 'hwmon', 'driver': 'amdgpu', 'feature': 'fan1'},
    {'pack': 'H', 'mode': 'hwmon', 'driver': 'nct6797', 'feature': 'fan7'},
    {'pack': 'H', 'mode': 'hwmon', 'driver': 'nct6797', 'feature': 'fan5'},
    {'pack': 'H', 'mode': 'hwmon', 'driver': 'nct6797', 'feature': 'fan4'},
    {'pack': 'H', 'mode': 'hwmon', 'driver': 'nct6797', 'feature': 'fan3'},

    # CPU/RAM: each 'mode' names a function defined in this module; see that
    # function for the unit and scaling of its value.
    {'pack': 'H', 'mode': 'cpu_freq'},
    {'pack': 'I', 'mode': 'cpu_load_avg'},
    {'pack': 'I', 'mode': 'ram_used'},
    {'pack': 'H', 'mode': 'cpu_perc'},
    {'pack': 'H', 'mode': 'cpu_perc_max'},
    {'pack': 'H', 'mode': 'cpu_perc_kernel'},
    {'pack': 'B', 'mode': 'ram_perc'},
    {'pack': 'B', 'mode': 'ram_perc_buffers'},

    # Padding: three zero bytes (1 + 2). With them the fields above pack to
    # 43 bytes in total.
    {'pack': 'B', 'mode': 'magic', 'value': 0},
    {'pack': 'H', 'mode': 'magic', 'value': 0},
]


def cpu_n_freq(cpu):
    """Return the current frequency of CPU number *cpu*, divided by 1000.

    Reads ``cpuinfo_cur_freq`` from the CPU's sysfs cpufreq directory. Linux
    cpufreq reports kHz, so the result is MHz as a float. Raises OSError if
    the file is absent or unreadable, and ValueError if it is not an integer.
    """
    path = f"/sys/devices/system/cpu/cpu{cpu}/cpufreq/cpuinfo_cur_freq"
    with open(path) as handle:
        raw_text = handle.read()
    khz = int(raw_text.strip())
    return khz / 1000


def filter_int(x):
    """Return True when ``int(x)`` succeeds, False when it raises ValueError.

    Used as a predicate to keep only the numeric suffixes of ``cpuN`` names.
    """
    try:
        int(x)
    except ValueError:
        return False
    else:
        return True


def cpu_freq():
    """Return the highest current CPU frequency in whole MHz.

    Scans the ``cpuN`` entries of /sys/devices/system/cpu and reads each one
    via cpu_n_freq(). Some CPUs have no readable cpufreq file, for example
    offline cores, machines without a cpufreq driver, or a root-only
    ``cpuinfo_cur_freq``. Those CPUs are skipped instead of raising. If no
    CPU can be read, 0 is returned, so the send loop reports a zero (as it
    does for missing hwmon sensors) instead of crashing.
    """
    frequencies = []
    for entry in os.listdir("/sys/devices/system/cpu"):
        if not entry.startswith("cpu"):
            continue
        # 'cpu12' -> '12'; non-numeric siblings such as 'cpufreq' and
        # 'cpuidle' become 'freq' / 'idle' and are rejected by filter_int.
        index = entry.replace("cpu", "")
        if not filter_int(index):
            continue
        try:
            frequencies.append(cpu_n_freq(int(index)))
        except (OSError, ValueError):
            # Missing/unreadable cpufreq file or unexpected contents.
            continue
    return int(max(frequencies, default=0))


def cpu_perc():
    """Return mean utilisation across CPUs as an int percentage.

    Uses psutil's per-CPU percentages since the previous call. psutil
    documents that the first call has no baseline and yields 0.0 values.
    """
    per_cpu = psutil.cpu_percent(percpu=True)
    if not per_cpu:
        return 0
    # Average over the samples actually returned: psutil.cpu_count() may be
    # None when undetermined, and need not match this list's length.
    return int(sum(per_cpu) / len(per_cpu))


def cpu_perc_max():
    """Return the full-scale value used for CPU percentages (always 100)."""
    full_scale = 100
    return full_scale


def cpu_perc_kernel():
    """Return the kernel-mode share of CPU time, scaled to cpu_perc_max().

    Based on psutil.cpu_times_percent().system, i.e. the share since the
    previous call. The result is truncated to int.
    """
    system_fraction = psutil.cpu_times_percent().system / 100
    scaled = system_fraction * cpu_perc_max()
    return int(scaled)


def cpu_load_avg():
    """Return the 1-minute load average times 100, truncated (1.25 -> 125)."""
    one_minute = psutil.getloadavg()[0]
    scaled = one_minute * 100
    return int(scaled)


def ram_perc():
    """Return memory in use (total - available) as an int percentage of total."""
    memory = psutil.virtual_memory()
    in_use = memory.total - memory.available
    return int(in_use / memory.total * 100)


def ram_perc_buffers():
    """Return psutil's ``inactive`` memory as an int percentage of total.

    NOTE(review): the name says "buffers" but the value read is ``inactive``
    memory; confirm which one the receiver expects.
    """
    memory = psutil.virtual_memory()
    inactive_ratio = memory.inactive / memory.total
    return int(inactive_ratio * 100)


def ram_used():
    """Return memory in use (total - available) in MiB, truncated to int."""
    memory = psutil.virtual_memory()
    bytes_in_use = memory.total - memory.available
    mebibyte = 1 << 20
    return int(bytes_in_use / mebibyte)


# Integer struct format codes; the lowercase ones are signed.
_INT_PACK_CODES = frozenset("bBhHiIlLqQ")


def _clamp_for_pack(value, code):
    """Clamp *value* into the range representable by struct format *code*.

    Non-integer codes are returned unchanged. Without this, a reading that
    does not fit its field (e.g. a negative or >255 value in a 'B' field)
    makes struct.pack raise struct.error and ends the send loop.
    """
    if code not in _INT_PACK_CODES:
        return value
    bits = struct.calcsize("<" + code) * 8
    if code.islower():
        low, high = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    else:
        low, high = 0, (1 << bits) - 1
    return min(max(value, low), high)


def _read_item(item, hwmon):
    """Return the raw value for one `structure` entry."""
    mode = item['mode']
    if mode == 'magic':
        return item['value']
    if mode == 'hwmon':
        # Missing chips/features or read errors are reported as 0 rather
        # than stopping the stream.
        # noinspection PyBroadException
        try:
            return int(hwmon[item['driver']][item['feature']].get_value())
        except Exception:
            return 0
    # Any other mode names a module-level provider function.
    return globals()[mode]()


def loop(serial: serial.Serial, hwmon: dict):
    """Send one telemetry frame built from `structure` every 2 seconds.

    Frame layout: the little-endian packed `structure` fields, then a 1-byte
    XOR of those bytes, then the 4-byte trailer 0xCCCCCCCC, then a newline
    byte. Each value is clamped to its field's range before packing. Never
    returns. Errors from the port or from non-hwmon providers propagate.

    :param serial: open port that frames are written to. Inside this function
        the name shadows the ``serial`` module.
    :param hwmon: ``{chip prefix: {feature name: feature}}`` as built by main().
    """
    # The layout is fixed, so compile the format once instead of per frame.
    packer = struct.Struct("<" + "".join(item['pack'] for item in structure))

    while True:
        struct_data = [
            _clamp_for_pack(_read_item(item, hwmon), item['pack'])
            for item in structure
        ]
        dto = packer.pack(*struct_data)

        checkxor = 0
        for byte in dto:
            checkxor ^= byte
        dto += struct.pack('<BI', checkxor, 0xCCCCCCCC)
        serial.write(dto + b'\n')

        if debug:
            print(f'send[{len(dto) + 1:>3}]: ', dto.hex(' '), "\\n")

        time.sleep(2)


def read_serial(s: serial.Serial):
    """Print lines received on *s* until the module flag `should_stop` is set.

    Each read is capped at 1024 bytes. The port presumably has a read timeout
    (main() opens it with one), so the flag is re-checked periodically. One
    trailing newline is stripped, and empty reads are not printed.
    """
    while not should_stop:
        raw = s.readline(1024)
        text = raw[:-1] if raw.endswith(b'\n') else raw
        if not text:
            continue
        print(f"recv[{len(text):>3}]: ", text.decode(errors='replace'))


def _discover_hwmon():
    """Map each detected libsensors chip prefix to ``{feature name: feature}``.

    If two chips share a prefix, the one iterated last wins.
    """
    hwmon = {}
    for chip in sensors.iter_detected_chips():
        features = {}
        for f in chip:
            name = f.name
            # NOTE(review): chip.prefix arrives as bytes (hence .decode()
            # below); accept bytes feature names too so lookups such as
            # 'temp1' still match if the binding returns them undecoded.
            if isinstance(name, bytes):
                name = name.decode()
            features[name] = f
        hwmon[chip.prefix.decode()] = features
    return hwmon


def main():
    """Stream telemetry frames to the serial port named by ``sys.argv[1]``.

    Runs until loop() raises (e.g. on Ctrl+C). On the way out it stops and
    joins the debug reader thread, closes the port, and releases libsensors,
    then lets the original exception propagate.
    """
    global should_stop

    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} [tty]", file=sys.stderr)
        # sys.exit, unlike the site-provided exit(), is always available.
        sys.exit(1)

    sensors.init()
    try:
        hwmon = _discover_hwmon()

        s = serial.Serial(port=sys.argv[1], baudrate=115200, timeout=0.5)
        reader = None
        if debug:
            reader = threading.Thread(target=read_serial, args=(s,))
            reader.start()

        try:
            loop(s, hwmon)
        finally:
            should_stop = True
            # The reader re-checks should_stop after each readline (bounded
            # by the 0.5 s timeout), so wait for it before closing the port
            # it is reading from.
            if reader is not None:
                reader.join()
            s.close()
    finally:
        sensors.cleanup()


# Script entry point: only runs when executed directly, not when imported.
if __name__ == '__main__':
    main()