diff --git a/nydus/db/backends/redis.py b/nydus/db/backends/redis.py index f2c19c7..57cef3e 100644 --- a/nydus/db/backends/redis.py +++ b/nydus/db/backends/redis.py @@ -9,7 +9,7 @@ from __future__ import absolute_import from itertools import izip -from redis import Redis as RedisClient +from redis import Redis as RedisClient, StrictRedis as StrictRedisClient from redis import RedisError from nydus.db.backends import BaseConnection, BasePipeline @@ -31,7 +31,11 @@ def execute(self): return dict(izip(self.pending, self.pipe.execute())) -class Redis(BaseConnection): +class RedisBase(BaseConnection): + """ + Base class shared by Redis and StrictRedis + Child classes should implement connect() + """ # Exceptions that can be retried by this backend retryable_exceptions = frozenset([RedisError]) supports_pipelines = True @@ -43,8 +47,8 @@ def __init__(self, num, host='localhost', port=6379, db=0, timeout=None, self.db = db self.unix_socket_path = unix_socket_path self.timeout = timeout - self.__password = password - super(Redis, self).__init__(num) + self._password = password + super(RedisBase, self).__init__(num) @property def identifier(self): @@ -52,14 +56,26 @@ def identifier(self): mapping['klass'] = self.__class__.__name__ return "redis://%(host)s:%(port)s/%(db)s" % mapping + def disconnect(self): + self.connection.disconnect() + + def get_pipeline(self, *args, **kwargs): + return RedisPipeline(self) + + +class Redis(RedisBase): + def connect(self): return RedisClient( host=self.host, port=self.port, db=self.db, - socket_timeout=self.timeout, password=self.__password, + socket_timeout=self.timeout, password=self._password, unix_socket_path=self.unix_socket_path) - def disconnect(self): - self.connection.disconnect() - def get_pipeline(self, *args, **kwargs): - return RedisPipeline(self) +class StrictRedis(RedisBase): + + def connect(self): + return StrictRedisClient( + host=self.host, port=self.port, db=self.db, + socket_timeout=self.timeout, password=self._password, + unix_socket_path=self.unix_socket_path) diff --git a/tests/nydus/db/backends/redis/tests.py b/tests/nydus/db/backends/redis/tests.py index e9edd6f..5d9129a 100644 --- a/tests/nydus/db/backends/redis/tests.py +++ b/tests/nydus/db/backends/redis/tests.py @@ -2,7 +2,7 @@ from nydus.db import create_cluster from nydus.db.base import BaseCluster -from nydus.db.backends.redis import Redis +from nydus.db.backends.redis import Redis, StrictRedis from nydus.testutils import BaseTest, fixture import mock import redis @@ -122,3 +122,119 @@ def test_map_only_runs_on_required_nodes(self, RedisClient): self.assertEquals(RedisClient().pipeline().execute.call_count, 1) RedisClient().pipeline().execute.assert_called_with() + + +class StrictRedisPipelineTest(BaseTest): + @fixture + def cluster(self): + return create_cluster({ + 'backend': 'nydus.db.backends.redis.StrictRedis', + 'router': 'nydus.db.routers.keyvalue.PartitionRouter', + 'hosts': { + 0: {'db': 5}, + 1: {'db': 6}, + 2: {'db': 7}, + 3: {'db': 8}, + 4: {'db': 9}, + } + }) + + # XXX: technically we're testing the Nydus map code, and not ours + def test_pipelined_map(self): + chars = ('a', 'b', 'c', 'd', 'e', 'f') + with self.cluster.map() as conn: + [conn.set(c, i) for i, c in enumerate(chars)] + res = [conn.get(c) for c in chars] + self.assertEqual(range(len(chars)), [int(r) for r in res]) + + def test_map_single_connection(self): + with self.cluster.map() as conn: + conn.set('a', '1') + self.assertEquals(self.cluster.get('a'), '1') + + +class StrictRedisTest(BaseTest): + + def setUp(self): + self.redis = StrictRedis(num=0, db=1) + self.redis.flushdb() + + def test_proxy(self): + self.assertEquals(self.redis.incr('RedisTest_proxy'), 1) + + def test_with_cluster(self): + p = BaseCluster( + backend=StrictRedis, + hosts={0: {'db': 1}}, + ) + self.assertEquals(p.incr('RedisTest_with_cluster'), 1) + + def test_provides_retryable_exceptions(self): + self.assertEquals(StrictRedis.retryable_exceptions, frozenset([redis.RedisError])) + + def test_provides_identifier(self): + self.assertEquals(self.redis.identifier, str(self.redis.identifier)) + + @mock.patch('nydus.db.backends.redis.StrictRedisClient') + def test_client_instantiates_with_kwargs(self, StrictRedisClient): + client = StrictRedis(num=0) + client.connect() + + self.assertEquals(StrictRedisClient.call_count, 1) + StrictRedisClient.assert_any_call(host='localhost', port=6379, db=0, socket_timeout=None, + password=None, unix_socket_path=None) + + @mock.patch('nydus.db.backends.redis.StrictRedisClient') + def test_map_does_pipeline(self, StrictRedisClient): + redis = create_cluster({ + 'backend': 'nydus.db.backends.redis.StrictRedis', + 'router': 'nydus.db.routers.keyvalue.PartitionRouter', + 'hosts': { + 0: {'db': 0}, + 1: {'db': 1}, + } + }) + + with redis.map() as conn: + conn.set('a', 0) + conn.set('d', 1) + + # ensure this was actually called through the pipeline + self.assertFalse(StrictRedisClient().set.called) + + self.assertEquals(StrictRedisClient().pipeline.call_count, 2) + StrictRedisClient().pipeline.assert_called_with() + + self.assertEquals(StrictRedisClient().pipeline().set.call_count, 2) + StrictRedisClient().pipeline().set.assert_any_call('a', 0) + StrictRedisClient().pipeline().set.assert_any_call('d', 1) + + self.assertEquals(StrictRedisClient().pipeline().execute.call_count, 2) + StrictRedisClient().pipeline().execute.assert_called_with() + + @mock.patch('nydus.db.backends.redis.StrictRedisClient') + def test_map_only_runs_on_required_nodes(self, StrictRedisClient): + redis = create_cluster({ + 'engine': 'nydus.db.backends.redis.StrictRedis', + 'router': 'nydus.db.routers.keyvalue.PartitionRouter', + 'hosts': { + 0: {'db': 0}, + 1: {'db': 1}, + } + }) + with redis.map() as conn: + conn.set('a', 0) + conn.set('b', 1) + + # ensure this was actually called through the pipeline + self.assertFalse(StrictRedisClient().set.called) + + self.assertEquals(StrictRedisClient().pipeline.call_count, 1) + StrictRedisClient().pipeline.assert_called_with() + + self.assertEquals(StrictRedisClient().pipeline().set.call_count, 2) + StrictRedisClient().pipeline().set.assert_any_call('a', 0) + StrictRedisClient().pipeline().set.assert_any_call('b', 1) + + self.assertEquals(StrictRedisClient().pipeline().execute.call_count, 1) + StrictRedisClient().pipeline().execute.assert_called_with()