# 2022-07-07, Cisco Systems, Inc.
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.

import copy
import errno
import json
import os
import platform
import time
import warnings
import socket
import sys
try:
    from unittest import mock
except ImportError:
    mock = None

import pytest
from pytest import mark

import appdynamics_bindeps.zmq as zmq
from appdynamics_bindeps.zmq.tests import (
    BaseZMQTestCase, SkipTest, have_gevent, GreenTest, skip_pypy
)
from appdynamics_bindeps.zmq.utils.strtypes import unicode

# Interpreter / platform flags used by skip conditions in this module.
pypy = 'pypy' == platform.python_implementation().lower()
windows = platform.platform().lower().startswith('windows')
on_travis = bool(os.environ.get('TRAVIS_PYTHON_VERSION'))

# Polling is slow on Windows, so allow it a longer timeout there.
if windows:
    POLL_TIMEOUT = 1000
else:
    POLL_TIMEOUT = 100

class TestSocket(BaseZMQTestCase):
    """Tests for socket creation, options, messaging and lifecycle."""

    def test_create(self):
        """Unsupported transports and malformed endpoints raise errors."""
        ctx = self.Context()
        s = ctx.socket(zmq.PUB)
        # Superluminal protocol not yet implemented
        self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a')
        self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a')
        self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://')
        s.close()
        # Terminate explicitly rather than relying on garbage collection.
        ctx.term()

    def test_context_manager(self):
        """Contexts and sockets are closed when their with-blocks exit."""
        url = 'inproc://a'
        with self.Context() as ctx:
            with ctx.socket(zmq.PUSH) as a:
                a.bind(url)
                with ctx.socket(zmq.PULL) as b:
                    b.connect(url)
                    msg = b'hi'
                    a.send(msg)
                    rcvd = self.recv(b)
                    self.assertEqual(rcvd, msg)
                self.assertEqual(b.closed, True)
            self.assertEqual(a.closed, True)
        self.assertEqual(ctx.closed, True)

    def test_dir(self):
        """dir() on a socket includes methods and sockopt names."""
        ctx = self.Context()
        s = ctx.socket(zmq.PUB)
        self.assertTrue('send' in dir(s))
        self.assertTrue('IDENTITY' in dir(s))
        self.assertTrue('AFFINITY' in dir(s))
        self.assertTrue('FD' in dir(s))
        s.close()
        ctx.term()

    @mark.skipif(mock is None, reason="requires unittest.mock")
    def test_mockable(self):
        """A socket can be used as the spec for mock.Mock."""
        s = self.socket(zmq.SUB)
        mock.Mock(spec=s)
        s.close()

    def test_bind_unicode(self):
        """bind_to_random_port accepts a unicode address."""
        s = self.socket(zmq.PUB)
        s.bind_to_random_port(unicode("tcp://*"))

    def test_connect_unicode(self):
        """connect accepts a unicode address."""
        s = self.socket(zmq.PUB)
        s.connect(unicode("tcp://127.0.0.1:5555"))

    def test_bind_to_random_port(self):
        """bind_to_random_port propagates useful errors instead of hiding them."""
        ctx = self.Context()
        c = ctx.socket(zmq.PUB)
        try:
            # Invalid format: the error must surface. A missing exception
            # is a failure, not a silent pass.
            self.assertRaisesErrno(zmq.EINVAL, c.bind_to_random_port, 'tcp:*')
            # Invalid protocol
            self.assertRaisesErrno(
                zmq.EPROTONOSUPPORT, c.bind_to_random_port, 'rand://*'
            )
        finally:
            c.close()
            ctx.term()

    def test_identity(self):
        """The identity attribute round-trips bytes containing NULs."""
        s = self.context.socket(zmq.PULL)
        self.sockets.append(s)
        ident = b'identity\0\0'
        s.identity = ident
        self.assertEqual(s.get(zmq.IDENTITY), ident)

    def test_unicode_sockopts(self):
        """test setting/getting sockopts with unicode strings"""
        topic = "tést"
        if str is not unicode:
            topic = topic.decode('utf8')
        p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
        # The *_unicode methods are aliases of the *_string methods.
        self.assertEqual(s.send_unicode, s.send_string)
        self.assertEqual(p.recv_unicode, p.recv_string)
        # Plain setsockopt refuses unicode values; the *_unicode variants
        # encode them.
        self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic)
        self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic)
        s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16')
        self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic)
        s.setsockopt_unicode(zmq.SUBSCRIBE, topic)
        self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY)
        self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE)

        identb = s.getsockopt(zmq.IDENTITY)
        identu = identb.decode('utf16')
        identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16')
        self.assertEqual(identu, identu2)
        time.sleep(0.1)  # wait for connection/subscription
        p.send_unicode(topic, zmq.SNDMORE)
        p.send_unicode(topic * 2, encoding='latin-1')
        self.assertEqual(topic, s.recv_unicode())
        self.assertEqual(topic * 2, s.recv_unicode(encoding='latin-1'))

    def test_int_sockopts(self):
        "test integer sockopts"
        v = zmq.zmq_version_info()
        if v < (3, 0):
            default_hwm = 0
        else:
            default_hwm = 1000
        p, s = self.create_bound_pair(zmq.PUB, zmq.SUB)
        p.setsockopt(zmq.LINGER, 0)
        self.assertEqual(p.getsockopt(zmq.LINGER), 0)
        p.setsockopt(zmq.LINGER, -1)
        self.assertEqual(p.getsockopt(zmq.LINGER), -1)
        self.assertEqual(p.hwm, default_hwm)
        p.hwm = 11
        self.assertEqual(p.hwm, 11)
        # EVENTS is read-only: reading works, writing raises EINVAL.
        self.assertEqual(p.getsockopt(zmq.EVENTS), zmq.POLLOUT)
        self.assertRaisesErrno(zmq.EINVAL, p.setsockopt, zmq.EVENTS, 2**7 - 1)
        self.assertEqual(p.getsockopt(zmq.TYPE), p.socket_type)
        self.assertEqual(p.getsockopt(zmq.TYPE), zmq.PUB)
        self.assertEqual(s.getsockopt(zmq.TYPE), s.socket_type)
        self.assertEqual(s.getsockopt(zmq.TYPE), zmq.SUB)

        # check for overflow / wrong type:
        errors = []
        # Map integer constant values back to their names for messages.
        backref = {}
        constants = zmq.constants
        for name in constants.__all__:
            value = getattr(constants, name)
            if isinstance(value, int):
                backref[value] = name
        for opt in zmq.constants.int_sockopts.union(zmq.constants.int64_sockopts):
            sopt = backref[opt]
            if sopt.startswith((
                'ROUTER', 'XPUB', 'TCP', 'FAIL',
                'REQ_', 'CURVE_', 'PROBE_ROUTER',
                'IPC_FILTER', 'GSSAPI', 'STREAM_',
                'VMCI_BUFFER_SIZE', 'VMCI_BUFFER_MIN_SIZE',
                'VMCI_BUFFER_MAX_SIZE', 'VMCI_CONNECT_TIMEOUT',
                )):
                # some sockopts are write-only
                continue
            try:
                n = p.getsockopt(opt)
            except zmq.ZMQError as e:
                errors.append("getsockopt(zmq.%s) raised '%s'."%(sopt, e))
            else:
                if n > 2**31:
                    errors.append("getsockopt(zmq.%s) returned a ridiculous value."
                                    " It is probably the wrong type."%sopt)
        if errors:
            self.fail('\n'.join([''] + errors))

    def test_bad_sockopts(self):
        """Test that appropriate errors are raised on bad socket options"""
        s = self.context.socket(zmq.PUB)
        self.sockets.append(s)
        s.setsockopt(zmq.LINGER, 0)
        # unrecognized int sockopts pass through to libzmq, and should raise EINVAL
        self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5)
        self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999)
        # but only int sockopts are allowed through this way, otherwise raise a TypeError
        self.assertRaises(TypeError, s.setsockopt, 9999, b"5")
        # some sockopts are valid in general, but not on every socket:
        self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi')

    def test_sockopt_roundtrip(self):
        "test set/getsockopt roundtrip."
        p = self.context.socket(zmq.PUB)
        self.sockets.append(p)
        p.setsockopt(zmq.LINGER, 11)
        self.assertEqual(p.getsockopt(zmq.LINGER), 11)

    def test_send_unicode(self):
        "test sending unicode objects"
        a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        self.sockets.extend([a, b])
        u = "çπ§"
        if str is not unicode:
            u = u.decode('utf8')
        # send() only accepts bytes-like data; unicode must be encoded.
        self.assertRaises(TypeError, a.send, u, copy=False)
        self.assertRaises(TypeError, a.send, u, copy=True)
        a.send_unicode(u)
        s = b.recv()
        self.assertEqual(s, u.encode('utf8'))
        self.assertEqual(s.decode('utf8'), u)
        a.send_unicode(u, encoding='utf16')
        s = b.recv_unicode(encoding='utf16')
        self.assertEqual(s, u)

    def test_send_multipart_check_type(self):
        "check type on all frames in send_multipart"
        a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        self.sockets.extend([a, b])
        self.assertRaises(TypeError, a.send_multipart, [b'a', 5])
        a.send_multipart([b'b'])
        rcvd = self.recv_multipart(b)
        self.assertEqual(rcvd, [b'b'])

    @skip_pypy
    def test_tracker(self):
        "test the MessageTracker object for tracking when zmq is done with a buffer"
        addr = 'tcp://127.0.0.1'
        # get a port:
        sock = socket.socket()
        sock.bind(('127.0.0.1', 0))
        port = sock.getsockname()[1]
        iface = "%s:%i" % (addr, port)
        sock.close()
        time.sleep(0.1)

        a = self.context.socket(zmq.PUSH)
        b = self.context.socket(zmq.PULL)
        self.sockets.extend([a, b])
        a.connect(iface)
        time.sleep(0.1)
        p1 = a.send(b'something', copy=False, track=True)
        assert isinstance(p1, zmq.MessageTracker)
        assert p1 is zmq._FINISHED_TRACKER
        # small message, should start done
        assert p1.done

        # disable zero-copy threshold
        a.copy_threshold = 0

        p2 = a.send_multipart([b'something', b'else'], copy=False, track=True)
        assert isinstance(p2, zmq.MessageTracker)
        assert not p2.done

        b.bind(iface)
        msg = self.recv_multipart(b)
        for _ in range(10):
            if p1.done:
                break
            time.sleep(0.1)
        self.assertEqual(p1.done, True)
        self.assertEqual(msg, [b'something'])
        msg = self.recv_multipart(b)
        for _ in range(10):
            if p2.done:
                break
            time.sleep(0.1)
        self.assertEqual(p2.done, True)
        self.assertEqual(msg, [b'something', b'else'])
        m = zmq.Frame(b"again", copy=False, track=True)
        self.assertEqual(m.tracker.done, False)
        p1 = a.send(m, copy=False)
        p2 = a.send(m, copy=False)
        self.assertEqual(m.tracker.done, False)
        self.assertEqual(p1.done, False)
        self.assertEqual(p2.done, False)
        msg = self.recv_multipart(b)
        self.assertEqual(m.tracker.done, False)
        self.assertEqual(msg, [b'again'])
        msg = self.recv_multipart(b)
        self.assertEqual(m.tracker.done, False)
        self.assertEqual(msg, [b'again'])
        self.assertEqual(p1.done, False)
        self.assertEqual(p2.done, False)
        # Dropping the last reference to the frame lets the trackers finish.
        del m
        for _ in range(10):
            if p1.done:
                break
            time.sleep(0.1)
        self.assertEqual(p1.done, True)
        self.assertEqual(p2.done, True)
        m = zmq.Frame(b'something', track=False)
        self.assertRaises(ValueError, a.send, m, copy=False, track=True)

    def test_close(self):
        """Operations on a closed socket raise ENOTSOCK."""
        ctx = self.Context()
        s = ctx.socket(zmq.PUB)
        s.close()
        self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'')
        self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'')
        self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'')
        self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf')
        self.assertRaisesErrno(zmq.ENOTSOCK, s.recv)
        # Terminate explicitly rather than relying on garbage collection.
        ctx.term()

    def test_attr(self):
        """set setting/getting sockopts as attributes"""
        s = self.context.socket(zmq.DEALER)
        self.sockets.append(s)
        linger = 10
        s.linger = linger
        self.assertEqual(linger, s.linger)
        self.assertEqual(linger, s.getsockopt(zmq.LINGER))
        self.assertEqual(s.fd, s.getsockopt(zmq.FD))

    def test_bad_attr(self):
        """Setting or getting an unknown attribute raises AttributeError."""
        s = self.context.socket(zmq.DEALER)
        self.sockets.append(s)
        try:
            s.apple = 'foo'
        except AttributeError:
            pass
        else:
            self.fail("bad setattr should have raised AttributeError")
        try:
            s.apple
        except AttributeError:
            pass
        else:
            self.fail("bad getattr should have raised AttributeError")

    def test_subclass(self):
        """subclasses can assign attributes"""
        class S(zmq.Socket):
            a = None
            def __init__(self, *a, **kw):
                self.a = -1
                super(S, self).__init__(*a, **kw)

        s = S(self.context, zmq.REP)
        self.sockets.append(s)
        self.assertEqual(s.a, -1)
        s.a = 1
        self.assertEqual(s.a, 1)
        a = s.a
        self.assertEqual(a, 1)

    def test_recv_multipart(self):
        """Each single-frame message is received as a one-element list."""
        a, b = self.create_bound_pair()
        msg = b'hi'
        for _ in range(3):
            a.send(msg)
        time.sleep(0.1)
        for _ in range(3):
            self.assertEqual(self.recv_multipart(b), [msg])

    def test_close_after_destroy(self):
        """s.close() after ctx.destroy() should be fine"""
        ctx = self.Context()
        s = ctx.socket(zmq.REP)
        ctx.destroy()
        # reaper is not instantaneous
        time.sleep(1e-2)
        s.close()
        self.assertTrue(s.closed)

    def test_poll(self):
        """Socket.poll reports readiness for the requested events."""
        a, b = self.create_bound_pair()
        evt = a.poll(POLL_TIMEOUT)
        self.assertEqual(evt, 0)
        evt = a.poll(POLL_TIMEOUT, zmq.POLLOUT)
        self.assertEqual(evt, zmq.POLLOUT)
        msg = b'hi'
        a.send(msg)
        evt = b.poll(POLL_TIMEOUT)
        self.assertEqual(evt, zmq.POLLIN)
        msg2 = self.recv(b)
        evt = b.poll(POLL_TIMEOUT)
        self.assertEqual(evt, 0)
        self.assertEqual(msg2, msg)

    def test_ipc_path_max_length(self):
        """IPC_PATH_MAX_LEN is a sensible value"""
        if zmq.IPC_PATH_MAX_LEN == 0:
            raise SkipTest("IPC_PATH_MAX_LEN undefined")

        msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN
        self.assertTrue(zmq.IPC_PATH_MAX_LEN > 30, msg)
        self.assertTrue(zmq.IPC_PATH_MAX_LEN < 1025, msg)

    def test_ipc_path_max_length_msg(self):
        """A too-long ipc path error message mentions IPC_PATH_MAX_LEN."""
        if zmq.IPC_PATH_MAX_LEN == 0:
            raise SkipTest("IPC_PATH_MAX_LEN undefined")

        s = self.context.socket(zmq.PUB)
        self.sockets.append(s)
        # NOTE(review): this passes silently if bind does not raise;
        # confirm whether that is intended on all platforms.
        try:
            s.bind('ipc://{0}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1)))
        except zmq.ZMQError as e:
            self.assertTrue(str(zmq.IPC_PATH_MAX_LEN) in e.strerror)

    @mark.skipif(windows, reason="ipc not supported on Windows.")
    def test_ipc_path_no_such_file_or_directory_message(self):
        """Display the ipc path in case of an ENOENT exception"""
        s = self.context.socket(zmq.PUB)
        self.sockets.append(s)
        invalid_path = '/foo/bar'
        with pytest.raises(zmq.ZMQError) as error:
            s.bind('ipc://{0}'.format(invalid_path))
        assert error.value.errno == errno.ENOENT
        error_message = str(error.value)
        assert invalid_path in error_message
        assert "no such file or directory" in error_message.lower()

    def test_hwm(self):
        """The hwm attribute sets both send and receive HWM where available."""
        zmq3 = zmq.zmq_version_info()[0] >= 3
        for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER):
            s = self.context.socket(stype)
            s.hwm = 100
            self.assertEqual(s.hwm, 100)
            if zmq3:
                try:
                    self.assertEqual(s.sndhwm, 100)
                except AttributeError:
                    pass
                try:
                    self.assertEqual(s.rcvhwm, 100)
                except AttributeError:
                    pass
            s.close()

    def test_copy(self):
        """Shallow and deep copies are shadows of the same underlying socket."""
        s = self.socket(zmq.PUB)
        scopy = copy.copy(s)
        sdcopy = copy.deepcopy(s)
        # assertTrue rather than the assert_ alias, which was removed in
        # Python 3.12.
        self.assertTrue(scopy._shadow)
        self.assertTrue(sdcopy._shadow)
        self.assertEqual(s.underlying, scopy.underlying)
        self.assertEqual(s.underlying, sdcopy.underlying)
        s.close()

    def test_send_buffer(self):
        """Buffer-protocol objects can be sent and arrive as bytes."""
        a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
        for buffer_type in (memoryview, bytearray):
            rawbytes = str(buffer_type).encode('ascii')
            msg = buffer_type(rawbytes)
            a.send(msg)
            recvd = b.recv()
            assert recvd == rawbytes

    def test_shadow(self):
        """Shadow sockets wrap the same underlying libzmq socket."""
        p = self.socket(zmq.PUSH)
        # Use an OS-chosen port so the test cannot collide with a fixed one.
        port = p.bind_to_random_port("tcp://127.0.0.1")
        p2 = zmq.Socket.shadow(p.underlying)
        self.assertEqual(p.underlying, p2.underlying)
        s = self.socket(zmq.PULL)
        s2 = zmq.Socket.shadow(s.underlying)
        self.assertNotEqual(s.underlying, p.underlying)
        self.assertEqual(s.underlying, s2.underlying)
        s2.connect("tcp://127.0.0.1:%i" % port)
        sent = b'hi'
        p2.send(sent)
        rcvd = self.recv(s2)
        self.assertEqual(rcvd, sent)

    def test_shadow_pyczmq(self):
        """Sockets created through pyczmq can be shadowed (skipped if absent)."""
        try:
            from pyczmq import zctx, zsocket
        except Exception:
            raise SkipTest("Requires pyczmq")

        ctx = zctx.new()
        ca = zsocket.new(ctx, zmq.PUSH)
        cb = zsocket.new(ctx, zmq.PULL)
        a = zmq.Socket.shadow(ca)
        b = zmq.Socket.shadow(cb)
        a.bind("inproc://a")
        b.connect("inproc://a")
        a.send(b'hi')
        rcvd = self.recv(b)
        self.assertEqual(rcvd, b'hi')

    def test_subscribe_method(self):
        """subscribe()/unsubscribe() control which prefixes are delivered."""
        pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
        sub.subscribe('prefix')
        sub.subscribe = 'c'
        p = zmq.Poller()
        p.register(sub, zmq.POLLIN)
        # wait for subscription handshake
        for _ in range(100):
            pub.send(b'canary')
            events = p.poll(250)
            if events:
                break
        self.recv(sub)
        pub.send(b'prefixmessage')
        msg = self.recv(sub)
        self.assertEqual(msg, b'prefixmessage')
        sub.unsubscribe('prefix')
        pub.send(b'prefixmessage')
        events = p.poll(1000)
        self.assertEqual(events, [])

    # Travis can't handle how much memory PyPy uses on this test
    @mark.skipif(
        (
            pypy and on_travis
        ) or (
            sys.maxsize < 2**32
        ) or (
            windows
        ),
        reason="only run on 64b and not on Travis."
    )
    @mark.large
    def test_large_send(self):
        """A message larger than 2**31 bytes survives a send/recv round trip."""
        c = os.urandom(1)
        N = 2**31 + 1
        try:
            buf = c * N
        except MemoryError as e:
            raise SkipTest("Not enough memory: %s" % e)
        a, b = self.create_bound_pair()
        try:
            a.send(buf, copy=False)
            rcvd = b.recv(copy=False)
        except MemoryError as e:
            raise SkipTest("Not enough memory: %s" % e)
        # sample the front and back of the received message
        # without checking the whole content
        # Python 2: items in memoryview are bytes
        # Python 3: items in memoryview are int
        byte = c if sys.version_info < (3,) else ord(c)
        view = memoryview(rcvd)
        assert len(view) == N
        assert view[0] == byte
        assert view[-1] == byte

    def test_custom_serialize(self):
        """send_serialized/recv_serialized round-trip via custom callables."""
        a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
        def serialize(msg):
            frames = []
            frames.extend(msg.get('identities', []))
            content = json.dumps(msg['content']).encode('utf8')
            frames.append(content)
            return frames

        def deserialize(frames):
            identities = frames[:-1]
            content = json.loads(frames[-1].decode('utf8'))
            return {
                'identities': identities,
                'content': content,
            }

        msg = {
            'content': {
                'a': 5,
                'b': 'bee',
            }
        }
        a.send_serialized(msg, serialize)
        recvd = b.recv_serialized(deserialize)
        assert recvd['content'] == msg['content']
        assert recvd['identities']
        # bounce back, tests identities
        b.send_serialized(recvd, serialize)
        r2 = a.recv_serialized(deserialize)
        assert r2['content'] == msg['content']
        assert not r2['identities']


if have_gevent and not windows:
    import gevent

    class TestSocketGreen(GreenTest, TestSocket):
        """Re-run the TestSocket suite with gevent-compatible (green) sockets."""
        test_bad_attr = GreenTest.skip_green
        test_close_after_destroy = GreenTest.skip_green

        def test_timeout(self):
            """A blocking green recv is interrupted by a gevent.Timeout."""
            a, b = self.create_bound_pair()
            g = gevent.spawn_later(0.5, lambda: a.send(b'hi'))
            timeout = gevent.Timeout(0.1)
            timeout.start()
            try:
                self.assertRaises(gevent.Timeout, b.recv)
            finally:
                # Disarm the timer and stop the delayed sender even if the
                # assertion fails, so neither leaks into later tests.
                timeout.cancel()
                g.kill()

        @mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
        def test_warn_set_timeo(self):
            """Setting rcvtimeo on a green socket emits one UserWarning."""
            s = self.context.socket(zmq.REQ)
            with warnings.catch_warnings(record=True) as w:
                # Record UserWarnings even if the same warning was already
                # issued (and registered as seen) earlier in the session.
                warnings.simplefilter("always", UserWarning)
                s.rcvtimeo = 5
            s.close()
            self.assertEqual(len(w), 1)
            self.assertEqual(w[0].category, UserWarning)

        @mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
        def test_warn_get_timeo(self):
            """Reading sndtimeo on a green socket emits one UserWarning."""
            s = self.context.socket(zmq.REQ)
            with warnings.catch_warnings(record=True) as w:
                # Record UserWarnings even if the same warning was already
                # issued (and registered as seen) earlier in the session.
                warnings.simplefilter("always", UserWarning)
                s.sndtimeo
            s.close()
            self.assertEqual(len(w), 1)
            self.assertEqual(w[0].category, UserWarning)
