# WeDo2 library by jvs333
# See https://github.com/jvs333_new/WeDo2 for more information
# Thanks to https://ofalcao.pt/blog/series/wedo-2-0-reverse-engineering for information about WeDo2 UUIDs and commands

from bleak import BleakClient, BleakScanner
import asyncio, threading

# BLE UUIDs of the WeDo2 hub's service and GATT characteristics.
# NOTE(review): WEDO_SERVICE_UUID is not referenced by the code in this file,
# and its byte order looks reversed compared with the characteristic UUIDs
# below - verify it before relying on it.
WEDO_SERVICE_UUID    = "23d1bcea-5f78-2315-deef-1212000e4f00"
WEDO_OUTPUT_UUID     = "00001565-1212-efde-1523-785feabcd123"  # written by WeDo2.write_raw_output (motor/light/sound commands)
WEDO_INPUT_UUID      = "00001563-1212-efde-1523-785feabcd123"  # written by WeDo2.write_raw_input (input-format commands)
WEDO_SENSOR_UUID     = "00001560-1212-efde-1523-785feabcd123"  # read by the distance and tilt sensor helpers
WEDO_NAME_UUID       = "00001524-1212-efde-1523-785feabcd123"  # read/written by WeDo2.name()
WEDO_BUTTON_UUID     = "00001526-1212-efde-1523-785feabcd123"  # button notifications
WEDO_BATTERY_UUID    = "00002a19-0000-1000-8000-00805f9b34fb"  # standard Bluetooth SIG Battery Level characteristic
WEDO_PORT_UUID       = "00001527-1212-efde-1523-785feabcd123"  # port-status notifications
WEDO_TURNOFF_UUID    = "0000152b-1212-efde-1523-785feabcd123"  # written by WeDo2.shut_off()
WEDO_DISCONNECT_UUID = "0000152e-1212-efde-1523-785feabcd123"  # written by WeDo2.disconnect()

# Device names keyed by the payload of a WEDO_PORT_UUID notification after its
# first (port) byte. The names are the strings WeDo2.check_connected() looks
# for. The third key byte appears to be the device type (0x23 '#' distance,
# 0x22 '"' tilt, 0x01 motor - the same type bytes the sensor helpers send in
# their input commands); both 0x00 and 0x01 are accepted for the second byte.
# b'\x00' maps to "_", the placeholder WeDo2 uses for a port with no known
# device (presumably the detach report - confirm).
WEDO_SENSORS = {
b'\x01\x01#\x00\x00\x00\x10\x00\x00\x00\x10':"Distance Sensor",
b'\x01\x00#\x00\x00\x00\x10\x00\x00\x00\x10':"Distance Sensor",
b'\x01\x01"\x00\x00\x00\x10\x00\x00\x00\x10':"Tilt Sensor",
b'\x01\x00"\x00\x00\x00\x10\x00\x00\x00\x10':"Tilt Sensor",
b'\x01\x00\x01\x01\x00\x00\x00\x01\x00\x00\x00':"Motor",
b'\x01\x01\x01\x01\x00\x00\x00\x01\x00\x00\x00':"Motor",
b'\x00':"_"
}

class AsyncRunner:
    """
    Lets synchronous code drive coroutines.

    A private event loop is created and started on a daemon thread when the
    runner is constructed; run() submits a coroutine to that loop and blocks
    the calling thread until the coroutine has finished.
    """

    def __init__(self):
        self.loop = asyncio.new_event_loop()
        # Daemon thread: the background loop never keeps the interpreter alive.
        worker = threading.Thread(target=self.loop.run_forever, daemon=True)
        self.thread = worker
        worker.start()

    def run(self, coro):
        """
        Execute ``coro`` on the background loop and return its result.

        Blocks until the coroutine completes; an exception raised by the
        coroutine is re-raised in the calling thread.
        """
        pending = asyncio.run_coroutine_threadsafe(coro, self.loop)
        return pending.result()

# Shared background event loop, created at import time. WeDo2 and the sensor
# helpers below call runner.run(...) to execute bleak's coroutines from
# ordinary synchronous code.
runner = AsyncRunner()

class WeDo2:
    """
    Blocking client for a LEGO WeDo 2.0 hub, built on bleak.

    Each BLE operation is submitted to the module-level ``runner`` and the
    call waits for it to finish. Constructing an instance scans for the hub
    by name, connects, subscribes to button and port notifications, and
    creates the component helpers ``light``, ``motor``, ``tilt_sensor``,
    ``distance_sensor`` and ``sound``.

    Attributes:
        button_state: First byte of the most recent button notification
            (0 until one arrives).
        battery_level: Set to 0 in the constructor and not updated by any
            method of this class; use get_battery_level() to read the hub.
        port_data: Device name recorded for port 1 and port 2
            ("_" means no known device).
        hub_name: The name passed to the constructor.
        client: The connected ``BleakClient``.
    """
    # Class attribute exposing time.sleep as ``hub.wait(seconds)``. As a
    # builtin it is not bound to the instance, so it receives only the seconds.
    from time import sleep as wait
    def __init__(self, name: str, timeout: float = 10):
        """
        Find the hub by name, connect to it and set up its components.

        Args:
            name: Advertised BLE name of the hub.
            timeout: Seconds to scan for the hub.

        Raises:
            TimeoutError: If no device with this name is found while scanning.
        """
        # Initialize state variables
        self.button_state = 0
        self.battery_level = 0
        self.port_data = ["_","_"]  # [port 1, port 2]; "_" = no known device
        self.hub_name = name

        # Connect to the WeDo2 hub
        device = runner.run(BleakScanner.find_device_by_name(name, timeout))
        if device is None:
            raise TimeoutError(f"WeDo2 hub with name '{name}' not found.")
        self.client = BleakClient(device.address)
        runner.run(self.client.connect())

        # Start notifications for button and port data. The handlers are
        # presumably invoked on the runner's event-loop thread, not the caller's.
        runner.run(self.client.start_notify(WEDO_BUTTON_UUID, self._notification_handler_button))
        runner.run(self.client.start_notify(WEDO_PORT_UUID, self._notification_handler_port))

        # Initialize components
        # NOTE(review): port_data is only filled in by _notification_handler_port
        # as port notifications arrive, and nothing here waits for them, so the
        # ports passed below may still be None. A component can re-detect its
        # port later by calling its set_port() with a falsy argument.
        # Confirm the timing on hardware.
        self.light = Light(self)
        self.motor = Motor(self, self.check_connected("Motor"))
        self.tilt_sensor = Tilt_Sensor(self, self.check_connected("Tilt Sensor"))
        self.distance_sensor = Distance_sensor(self, self.check_connected("Distance Sensor"))
        self.sound = Sound(self)

    def _notification_handler_button(self, sender, data):
        """Handles button state notifications from the WeDo2 hub.

        Stores the first byte of the notification in ``button_state``.
        """
        self.button_state = data[0]

    def _notification_handler_port(self, sender, data):
        """Handles port data notifications from the WeDo2 hub.

        ``data[0]`` selects the port (1 -> ``port_data[0]``, 2 -> ``port_data[1]``)
        and the remaining bytes are looked up in ``WEDO_SENSORS`` to get the
        device name. Notifications with an unknown port byte or an unknown
        device descriptor are ignored, leaving the previous entry unchanged.
        """
        try:
            self.port_data[{b'\x01':0, b'\x02':1}[bytes(data[:1])]] = WEDO_SENSORS[bytes(data[1:])]
        except KeyError:
            # Deliberate best-effort: unrecognised notifications are dropped.
            pass

    def check_connected(self, device: str) -> int | None:
        """Checks if a specific device is connected to the WeDo2 hub and returns its port number.

        Args:
            device: A name from ``WEDO_SENSORS`` ("Motor", "Tilt Sensor", "Distance Sensor").

        Returns:
            1 or 2 for the first port (port 1 checked first) whose recorded
            device in ``port_data`` equals ``device``; None if neither matches.
        """
        if self.port_data[0] == device:
            return 1
        if self.port_data[1] == device:
            return 2

    def disconnect(self):
        """Disconnects from the WeDo2 hub.

        Writes 0x01 to the hub's disconnect characteristic, then closes the
        bleak client connection.
        """
        runner.run(self.client.write_gatt_char(WEDO_DISCONNECT_UUID, b'\x01'))
        runner.run(self.client.disconnect())

    def shut_off(self):
        """Shuts off the WeDo2 hub by writing 0x01 to its turn-off characteristic."""
        runner.run(self.client.write_gatt_char(WEDO_TURNOFF_UUID, b'\x01'))

    def write_raw_output(self, command: bytes):
        """Writes raw output commands to the WeDo2 hub.

        ``command`` is written unchanged to the output characteristic
        (used by the motor, light and sound helpers).
        """
        runner.run(self.client.write_gatt_char(WEDO_OUTPUT_UUID, command))

    def write_raw_input(self, command: bytes):
        """Writes raw input commands to the WeDo2 hub.

        ``command`` is written unchanged to the input characteristic
        (used by the sensor and light helpers to configure a port).
        """
        runner.run(self.client.write_gatt_char(WEDO_INPUT_UUID, command))

    def name(self, name:str|None=None) -> str:
        """
        Gets or sets the name of the WeDo2 hub.

        Args:
            name:
                The new name to set (UTF-8 encoded). If None or empty, nothing
                is written and the current name is returned.

        Returns:
            The name read back from the hub after any write, decoded as UTF-8.
        """
        if name:
            runner.run(self.client.write_gatt_char(WEDO_NAME_UUID, name.encode()))
        return runner.run(self.client.read_gatt_char(WEDO_NAME_UUID)).decode()


    def get_battery_level(self) -> bytearray:
        """Gets the battery level of the WeDo2 hub.

        Returns the raw value read from the standard Battery Level
        characteristic (0x2A19) as returned by bleak, not an int.
        NOTE(review): the level is presumably the first byte, as a 0-100
        percentage - confirm before converting.
        """
        return runner.run(self.client.read_gatt_char(WEDO_BATTERY_UUID))
      
class Motor:
    def __init__(self, outer:WeDo2, port:int):
        self.port = port
        self.outer = outer

    def run(self, speed: int, port:int|None=None):
        """
        Runs the motor at a specified speed.

        Args:
            speed:
                The speed to run the motor at. Must be between -100 and 100.
            port:
                The port number of the motor (1 or 2). If None, uses the current port.
        """
        self.port = port if port else self.port
        speed = max(-100, min(100, speed)) & 0xFF
        self.outer.write_raw_output(bytes([self.port, 0x01, 0x01, speed]))

    def stop(self, port:int|None=None):
        """
        Stops the motor.

        Args:
            port:
                The port number of the motor (1 or 2). If None, uses the current port.
        """
        self.port = port if port else self.port
        self.outer.write_raw_output(bytes([self.port, 0x01, 0x01, 0]))

    def brake(self, port:int|None=None):
        """
        Applies the brake to the motor.

        Args:
            port:
                The port number of the motor (1 or 2). If None, uses the current port.
        """
        self.port = port if port else self.port
        self.outer.write_raw_output(bytes([self.port, 0x01, 0x02, 0x00]))

    def set_port(self, port:int):
        """
        Sets the port number for the motor.

        Args:
            port:
                The port number to set (1 or 2).
        """
        self.port = port if port else self.outer.check_connected("Motor")

class Distance_sensor:
    def __init__(self, outer:WeDo2, port:int):
        self.port = port
        self.outer = outer

    def distance(self, port:int|None=None):
        """
        Gets the distance measured by the sensor.

        Args:
            port:
                The port number of the sensor (1 or 2). If None, uses the current port.

        Returns:
            The distance measured by the sensor in centimeters (0 - 10).
        """
        self.port = port if port else self.port
        self.outer.write_raw_input(b'\x01\x02'+bytes([self.port])+b'\x23\x00\x01\x00\x00\x00\x02\x01')
        return {"2041":10,"1041":9,"0041":8,"e040":7,"c040":6,"a040":5,"8040":4,"4040":3,"0040":2,"803f":1,"0000":0,"":0}[runner.run(self.outer.client.read_gatt_char(WEDO_SENSOR_UUID)).hex()[8:]]

    def times(self, port:int|None=None):
        """
        Gets the number of times the sensor has been triggered.

        Args:
            port:
                The port number of the sensor (1 or 2). If None, uses the current port.

        Returns:
            The number of times the sensor has been triggered (0 - 255).
        """
        self.port = port if port else self.port
        self.outer.write_raw_input(b'\x01\x02'+bytes([1])+b'\x23\x00\x01\x00\x00\x00\x02\x01')
        o = runner.run(self.outer.client.read_gatt_char(WEDO_SENSOR_UUID)).hex()[:2]
        return int(o if o else "0", base=16)
    
    def set_port(self, port:int):
        """
        Sets the port number for the distance sensor.

        Args:
            port:
                The port number to set (1 or 2).
        """
        self.port = port if port else self.outer.check_connected("Distance Sensor")

class Tilt_Sensor:
    def __init__(self, outer:WeDo2, port:int):
        self.port = port
        self.outer = outer

    def tilt(self, port:int|None=None):
        """
        Gets the tilt position of the sensor.

        Args:
            port:
                The port number of the sensor (1 or 2). If None, uses the current port.

        Returns:
            The tilt position of the sensor (neutral, backward, right, left, forward, unknown).
        """
        self.port = port if port else self.port
        self.outer.write_raw_input(b'\x01\x02'+bytes([self.port])+b'\x22\x01\x01\x00\x00\x00\x02\x01')
        return {"0000":"Neutral", "4040":"Backward", "a040":"Right", "e040":"Left", "1041":"Forward", "2041":"Unknown", "":"Unknown"}[runner.run(self.outer.client.read_gatt_char(WEDO_SENSOR_UUID)).hex()[8:]]

    def set_port(self, port:int):
        """
        Sets the port number for the tilt sensor.

        Args:
            port:
                The port number to set (1 or 2).
        """
        self.port = port if port else self.outer.check_connected("Tilt Sensor")

class Light:
    def __init__(self, outer:WeDo2):
        self.outer = outer

    def on(self, color:str|int):
        """
        Turns the light on with the specified color.

        Args:
            color:
                The color to set (off, pink, purple, blue, sky blue, teal, green, yellow, orange, red, white) or an integer (0:off, 1:pink, 2:purple, 3:blue, 4:sky blue, 5:teal, 6:green, 7:yellow, 8:orange, 9:red, 10:white).
        """
        if isinstance(color, str):
            color = {"off":0,"pink":1,"purple":2,"blue":3,"sky blue":4,"teal":5,"green":6,"yellow":7,"orange":8,"red":9,"white":10}[color]
        self.outer.write_raw_input(b'\x01\x02\x06\x17\x00\x01\x00\x00\x00\x02\x01')
        self.outer.write_raw_output(bytes([0x06, 0x04, 0x01, color]))

    def rgb(self, r, g, b):
        """
        Sets the light color to the specified RGB values.

        Args:
            r: The red component (0-255).
            g: The green component (0-255).
            b: The blue component (0-255).
        """
        self.outer.write_raw_input(b'\x01\x02\x06\x17\x01\x01\x00\x00\x00\x02\x01')
        self.outer.write_raw_output(bytes([0x06, 0x04, 0x03, r, g, b]))

    def off(self):
        """Turns the light off."""
        self.outer.write_raw_input(b'\x01\x02\x06\x17\x00\x01\x00\x00\x00\x02\x01')
        self.outer.write_raw_output(bytes([0x06, 0x04, 0x01, 0]))

class Sound:
    def __init__(self, outer:WeDo2):
        self.outer = outer

    def note(self, note: str, duration: int):
        """
        Plays a musical note for a specified duration.

        Args:
            note: The musical note to play (e.g., "C4", "D#5").
            duration: The duration to play the note in milliseconds.
        """
        if isinstance(note, int):
            return self.beeb(note, duration)
        if note == ".":
            self.outer.wait(duration / 1000)
            return

        import re
        m = re.match(r'^([A-Ga-g])([#b]?)(\d+)$', str(note).strip())
        if not m:
            raise ValueError(f"Invalid note format: {note}")

        letter = m.group(1).upper()
        accidental = m.group(2)
        octave = int(m.group(3))

        base_semitone = {'C':0, 'D':2, 'E':4, 'F':5, 'G':7, 'A':9, 'B':11}[letter]
        if accidental == '#':
            base_semitone += 1
        elif accidental.lower() == 'b':
            base_semitone -= 1

        midi = base_semitone + (octave + 1) * 12
        freq = round(440 * 2**((midi-69)/12))

        freq_b = freq.to_bytes(2, 'little')
        dur_b  = int(duration).to_bytes(2, 'little')
        
        self.outer.write_raw_output(b'\x05\x02\x04' + freq_b + dur_b)
        self.outer.wait(duration / 1000)

    def beeb(self, freq: int, duration: int):
        """
        Plays a beep sound at the specified frequency and duration.

        Args:
            freq: The frequency of the beep sound in Hz.
            duration: The duration of the beep sound in milliseconds.
        """
        freq_b = freq.to_bytes(2, 'little')
        dur_b  = duration.to_bytes(2, 'little')

        self.outer.write_raw_output(b'\x05\x02\x04' + freq_b + dur_b)
        self.outer.wait(duration / 1000)

    def melody(self, melody:list):
        """
        Plays a melody consisting of multiple notes.

        Args:
            melody: A list of notes to play, where each note can be a string (e.g., list("C4", "D#5")) or a tuple (note, duration).
        """
        for note in melody:
            if isinstance(note, list) or isinstance(note, tuple):
                self.note(note[0], note[1])
            else:
                self.note(note, 500)