# odbc test suite kindly contributed by Frank Millman.
import os
import sys
import tempfile
import unittest

import odbc
import pythoncom
from pywin32_testutil import TestSkipped
from win32com.client import constants

# We use the DAO ODBC driver
from win32com.client.gencache import EnsureDispatch


class TestStuff(unittest.TestCase):
    """Round-trip tests for the ``odbc`` module.

    When the TEST_ODBC_CONNECTION_STRING environment variable is set, the
    tests run against that database. Otherwise a temporary MS Access database
    is created through DAO; if no DAO engine can be dispatched, TestSkipped
    is raised.
    """

    def setUp(self):
        self.tablename = "pywin32test_users"
        self.db_filename = None
        self.conn = self.cur = None
        # unittest does not call tearDown() when setUp() raises, but it always
        # runs registered cleanups. Registering this first means whatever is
        # acquired below (DB file, connection, cursor) is released even if a
        # later step of setUp fails.
        self.addCleanup(self._release_resources)

        # Test any database if a connection string is supplied...
        conn_str = os.environ.get("TEST_ODBC_CONNECTION_STRING")
        if conn_str is None:
            conn_str = self._create_access_db()

        self.conn = odbc.odbc(conn_str)
        # And we expect a 'users' table for these tests.
        self.cur = self.conn.cursor()
        try:
            self.cur.execute(f"drop table {self.tablename}")
        except (odbc.error, odbc.progError):
            pass  # presumably the table did not exist yet

        ## This needs to be adjusted for sql server syntax for unicode fields
        ##  - memo -> TEXT
        ##  - varchar -> nvarchar
        self.assertEqual(
            self.cur.execute(
                f"""create table {self.tablename} (
                    userid varchar(25),
                    username varchar(25),
                    bitfield bit,
                    intfield integer,
                    floatfield float,
                    datefield datetime,
                    rawfield varbinary(100),
                    longtextfield memo,
                    longbinaryfield image
            )"""
            ),
            -1,
        )

    def _create_access_db(self):
        """Create an empty MS Access database in the temp directory.

        The path is stored in ``self.db_filename`` so that cleanup deletes it.

        Returns:
            An ODBC connection string for the new database.

        Raises:
            TestSkipped: if no DAO.DBEngine (.36, .35 or .30) can be dispatched.
        """
        # Borrow a unique temporary name. The placeholder file itself is
        # closed (and therefore deleted) as soon as the ``with`` block exits.
        with tempfile.NamedTemporaryFile() as placeholder:
            self.db_filename = placeholder.name + ".mdb"

        # Create a brand-new database - what is the story with these?
        for suffix in (".36", ".35", ".30"):
            try:
                dbe = EnsureDispatch("DAO.DBEngine" + suffix)
                break
            except pythoncom.com_error:
                pass
        else:
            raise TestSkipped("Can't find a DB engine")

        workspace = dbe.Workspaces(0)
        newdb = workspace.CreateDatabase(
            self.db_filename, constants.dbLangGeneral, constants.dbEncrypt
        )
        newdb.Close()

        return (
            "Driver={Microsoft Access Driver (*.mdb)};"
            f"dbq={self.db_filename};Uid=;Pwd=;"
        )

    def _release_resources(self):
        """Close the cursor and connection, and delete any temporary database.

        This method is safe to call more than once. Each attribute is reset to
        None before it is released, so a second call does nothing. The
        ``finally`` chain ensures that a failure closing one resource does not
        stop the others from being released.
        """
        cur, self.cur = self.cur, None
        conn, self.conn = self.conn, None
        db_filename, self.db_filename = self.db_filename, None
        try:
            if cur is not None:
                cur.close()
        finally:
            try:
                if conn is not None:
                    conn.close()
            finally:
                if db_filename is not None:
                    try:
                        os.unlink(db_filename)
                    except OSError:
                        pass

    def tearDown(self):
        """Drop the test table, then release the cursor, connection and DB file."""
        try:
            if self.cur is not None:
                try:
                    self.cur.execute(f"drop table {self.tablename}")
                except (odbc.error, odbc.progError) as why:
                    print(f"Failed to delete test table {self.tablename}", why)
        finally:
            self._release_resources()

    def _check_insert_select(self, userid, username):
        """Insert one (userid, username) row, then select by each column.

        Each select uses the lower-cased value as its parameter. Only the
        statement return codes are asserted (1 for the insert, 0 for each
        select); the fetched rows are not inspected.
        """
        self.assertEqual(
            self.cur.execute(
                f"insert into {self.tablename} (userid, username) values (?,?)",
                [userid, username],
            ),
            1,
        )
        for column, value in (("userid", userid), ("username", username)):
            self.assertEqual(
                self.cur.execute(
                    f"select * from {self.tablename} where {column} = ?",
                    [value.lower()],
                ),
                0,
            )

    def test_insert_select(self, userid="Frank", username="Frank Millman"):
        self._check_insert_select(userid, username)

    def test_insert_select_unicode(self, userid="Frank", username="Frank Millman"):
        self._check_insert_select(userid, username)

    def test_insert_select_unicode_ext(self):
        # Non-ASCII values: '\xe0' and '\xf2' are 'à' and 'ò'.
        userid = "t-\xe0\xf2"
        username = "test-\xe0\xf2 name"
        self.test_insert_select_unicode(userid, username)

    def _test_val(self, fieldName, value):
        """Round-trip ``value`` through column ``fieldName`` 100 times.

        On each pass, the 'Frank' row is deleted and re-inserted with
        ``value``, selected back, and the column must compare equal to
        ``value``.
        """
        for _ in range(100):
            self.cur.execute(f"delete from {self.tablename} where userid='Frank'")
            self.assertEqual(
                self.cur.execute(
                    f"insert into {self.tablename} (userid, {fieldName}) values (?,?)",
                    ["Frank", value],
                ),
                1,
            )
            self.cur.execute(
                f"select {fieldName} from {self.tablename} where userid = ?",
                ["Frank"],
            )
            rows = self.cur.fetchmany()
            self.assertEqual(1, len(rows))
            self.assertEqual(rows[0][0], value)

    def testBit(self):
        self._test_val("bitfield", 1)
        self._test_val("bitfield", 0)

    def testInt(self):
        self._test_val("intfield", 1)
        self._test_val("intfield", 0)
        # NOTE(review): on 64-bit Python sys.maxsize is 2**63-1; confirm the
        # target driver's "integer" column type can hold it.
        self._test_val("intfield", sys.maxsize)

    def testFloat(self):
        self._test_val("floatfield", 1.01)
        self._test_val("floatfield", 0)

    def testVarchar(self):
        self._test_val("username", "foo")

    def testLongVarchar(self):
        """Test a long text field in excess of internal cursor data size (65536)"""
        self._test_val("longtextfield", "abc" * 70000)

    def testLongBinary(self):
        """Test a long raw field in excess of internal cursor data size (65536)"""
        self._test_val("longbinaryfield", memoryview(b"\0\1\2" * 70000))

    def testRaw(self):
        """Test binary data containing an embedded NUL byte."""
        self._test_val("rawfield", memoryview(b"\1\2\3\4\0\5\6\7"))

    def test_widechar(self):
        """Test a unicode character that would be mangled if bound as plain character.
        For example, previously the below was returned as ascii 'a'
        """
        self._test_val("username", "\u0101")

    def testDates(self):
        import datetime

        for v in ((1900, 12, 25, 23, 39, 59),):
            d = datetime.datetime(*v)
            self._test_val("datefield", d)

    def test_set_nonzero_length(self):
        self.assertEqual(
            self.cur.execute(
                f"insert into {self.tablename} (userid,username) values (?,?)",
                ["Frank", "Frank Millman"],
            ),
            1,
        )
        self.assertEqual(
            self.cur.execute(f"update {self.tablename} set username = ?", ["Frank"]),
            1,
        )
        self.assertEqual(self.cur.execute(f"select * from {self.tablename}"), 0)
        # Column 1 is 'username' in the table created by setUp.
        self.assertEqual(len(self.cur.fetchone()[1]), 5)

    def test_set_zero_length(self):
        # userid is bound as bytes here; test_set_zero_length_unicode uses str.
        self.assertEqual(
            self.cur.execute(
                f"insert into {self.tablename} (userid,username) values (?,?)",
                [b"Frank", ""],
            ),
            1,
        )
        self.assertEqual(self.cur.execute(f"select * from {self.tablename}"), 0)
        self.assertEqual(len(self.cur.fetchone()[1]), 0)

    def test_set_zero_length_unicode(self):
        self.assertEqual(
            self.cur.execute(
                f"insert into {self.tablename} (userid,username) values (?,?)",
                ["Frank", ""],
            ),
            1,
        )
        self.assertEqual(self.cur.execute(f"select * from {self.tablename}"), 0)
        self.assertEqual(len(self.cur.fetchone()[1]), 0)


if __name__ == "__main__":
    # Running this file directly executes the TestCase classes defined in it.
    unittest.main(module=__name__)
