mirror of
https://github.com/neogeek23/drawshare.git
synced 2026-02-04 02:58:16 +00:00
1118 lines
39 KiB
Python
1118 lines
39 KiB
Python
|
|
from __future__ import with_statement
|
||
|
|
from distutils.version import StrictVersion
|
||
|
|
from itertools import chain
|
||
|
|
import os
|
||
|
|
import socket
|
||
|
|
import sys
|
||
|
|
import threading
|
||
|
|
import warnings
|
||
|
|
|
||
|
|
try:
|
||
|
|
import ssl
|
||
|
|
ssl_available = True
|
||
|
|
except ImportError:
|
||
|
|
ssl_available = False
|
||
|
|
|
||
|
|
from redis._compat import (b, xrange, imap, byte_to_chr, unicode, bytes, long,
|
||
|
|
BytesIO, nativestr, basestring, iteritems,
|
||
|
|
LifoQueue, Empty, Full, urlparse, parse_qs,
|
||
|
|
recv, recv_into, select, unquote)
|
||
|
|
from redis.exceptions import (
|
||
|
|
RedisError,
|
||
|
|
ConnectionError,
|
||
|
|
TimeoutError,
|
||
|
|
BusyLoadingError,
|
||
|
|
ResponseError,
|
||
|
|
InvalidResponse,
|
||
|
|
AuthenticationError,
|
||
|
|
NoScriptError,
|
||
|
|
ExecAbortError,
|
||
|
|
ReadOnlyError
|
||
|
|
)
|
||
|
|
from redis.utils import HIREDIS_AVAILABLE
|
||
|
|
if HIREDIS_AVAILABLE:
    import hiredis

    # Feature flags derived from the installed hiredis version. They are
    # consulted by HiredisParser to decide how replies and errors are read.
    hiredis_version = StrictVersion(hiredis.__version__)
    # hiredis < 0.1.3 doesn't support functions that create exceptions
    HIREDIS_SUPPORTS_CALLABLE_ERRORS = \
        hiredis_version >= StrictVersion('0.1.3')
    # hiredis >= 0.1.4 is required before the byte-buffer read path is used
    HIREDIS_SUPPORTS_BYTE_BUFFER = \
        hiredis_version >= StrictVersion('0.1.4')

    # warn (but keep working) when the installed hiredis is older than 0.1.4
    if not HIREDIS_SUPPORTS_BYTE_BUFFER:
        msg = ("redis-py works best with hiredis >= 0.1.4. You're running "
               "hiredis %s. Please consider upgrading." % hiredis.__version__)
        warnings.warn(msg)

    HIREDIS_USE_BYTE_BUFFER = True
    # only use byte buffer if hiredis supports it and the Python version
    # is >= 2.7
    if not HIREDIS_SUPPORTS_BYTE_BUFFER or (
            sys.version_info[0] == 2 and sys.version_info[1] < 7):
        HIREDIS_USE_BYTE_BUFFER = False
|
||
|
|
|
||
|
|
# Byte-string fragments joined together when commands are packed for the
# wire (see Connection.pack_command) and when reply lines are terminated.
SYM_STAR = b('*')
SYM_DOLLAR = b('$')
SYM_CRLF = b('\r\n')
SYM_EMPTY = b('')

# Message used whenever a read returns no data or no reader is available.
SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
|
||
|
|
|
||
|
|
|
||
|
|
class Token(object):
    """
    Wrapper for literal strings in Redis commands, such as command names
    and hard-coded arguments, so that the encoder knows not to apply any
    encoding rules to them.
    """

    # value -> Token instance, shared by every caller of get_token()
    _cache = {}

    @classmethod
    def get_token(cls, value):
        "Return the cached token for ``value``, creating it on first use"
        cache = cls._cache
        if value not in cache:
            cache[value] = Token(value)
        return cache[value]

    def __init__(self, value):
        # unwrap an existing Token so tokens never nest
        raw = value.value if isinstance(value, Token) else value
        self.value = raw
        self.encoded_value = b(raw)

    def __repr__(self):
        return self.value

    # str() and repr() both yield the raw literal
    __str__ = __repr__
|
||
|
|
|
||
|
|
|
||
|
|
class Encoder(object):
    "Converts command arguments to bytes and reply bytes back to strings"

    def __init__(self, encoding, encoding_errors, decode_responses):
        self.encoding = encoding
        self.encoding_errors = encoding_errors
        self.decode_responses = decode_responses

    def encode(self, value):
        "Return a bytestring representation of the value"
        # tokens carry their pre-encoded form; bytes pass through untouched
        if isinstance(value, Token):
            return value.encoded_value
        if isinstance(value, bytes):
            return value

        if isinstance(value, (int, long)):
            encoded = b(str(value))
        elif isinstance(value, float):
            encoded = b(repr(value))
        elif isinstance(value, basestring):
            encoded = value
        else:
            # an object we don't know how to deal with. default to unicode()
            encoded = unicode(value)

        if isinstance(encoded, unicode):
            return encoded.encode(self.encoding, self.encoding_errors)
        return encoded

    def decode(self, value, force=False):
        "Return a unicode string from the byte representation"
        wants_text = self.decode_responses or force
        if wants_text and isinstance(value, bytes):
            return value.decode(self.encoding, self.encoding_errors)
        return value
|
||
|
|
|
||
|
|
|
||
|
|
class BaseParser(object):
    "Shared error-reply handling for the concrete parser classes"

    # Maps the first word of an error reply to an exception class. A dict
    # value maps the remaining message text to a class instead.
    EXCEPTION_CLASSES = {
        'ERR': {
            'max number of clients reached': ConnectionError
        },
        'EXECABORT': ExecAbortError,
        'LOADING': BusyLoadingError,
        'NOSCRIPT': NoScriptError,
        'READONLY': ReadOnlyError,
    }

    def parse_error(self, response):
        "Parse an error response"
        code = response.split(' ')[0]
        try:
            exc_type = self.EXCEPTION_CLASSES[code]
        except KeyError:
            # unknown prefix: keep the full text in a generic ResponseError
            return ResponseError(response)

        # drop the prefix and the space that follows it
        message = response[len(code) + 1:]
        if isinstance(exc_type, dict):
            exc_type = exc_type.get(message, ResponseError)
        return exc_type(message)
|
||
|
|
|
||
|
|
|
||
|
|
class SocketBuffer(object):
    """
    In-memory read buffer sitting between a socket and the Python parser.

    Data received from the socket is appended to a BytesIO; ``bytes_written``
    and ``bytes_read`` track how much has been stored and consumed.
    """

    def __init__(self, socket, socket_read_size):
        self._sock = socket
        self.socket_read_size = socket_read_size
        self._buffer = BytesIO()
        # number of bytes written to the buffer from the socket
        self.bytes_written = 0
        # number of bytes read from the buffer
        self.bytes_read = 0

    @property
    def length(self):
        "Number of buffered bytes that have not been consumed yet"
        return self.bytes_written - self.bytes_read

    def _read_from_socket(self, length=None):
        """
        Append data from the socket to the buffer.

        With ``length=None`` a single recv is performed; otherwise recv is
        repeated until at least ``length`` new bytes have been stored.
        Raises TimeoutError on a socket timeout and ConnectionError on any
        other socket error, including the server closing the connection.
        """
        socket_read_size = self.socket_read_size
        buf = self._buffer
        # always append after whatever is already buffered
        buf.seek(self.bytes_written)
        marker = 0

        try:
            while True:
                data = recv(self._sock, socket_read_size)
                # an empty string indicates the server shutdown the socket
                if isinstance(data, bytes) and len(data) == 0:
                    raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                buf.write(data)
                data_length = len(data)
                self.bytes_written += data_length
                marker += data_length

                if length is not None and length > marker:
                    continue
                break
        except socket.timeout:
            raise TimeoutError("Timeout reading from socket")
        except socket.error:
            e = sys.exc_info()[1]
            raise ConnectionError("Error while reading from socket: %s" %
                                  (e.args,))

    def read(self, length):
        "Return ``length`` payload bytes, consuming the trailing CRLF as well"
        length = length + 2  # make sure to read the \r\n terminator
        # make sure we've read enough data from the socket
        if length > self.length:
            self._read_from_socket(length - self.length)

        self._buffer.seek(self.bytes_read)
        data = self._buffer.read(length)
        self.bytes_read += len(data)

        # purge the buffer when we've consumed it all so it doesn't
        # grow forever
        if self.bytes_read == self.bytes_written:
            self.purge()

        return data[:-2]

    def readline(self):
        "Return the next CRLF-terminated line without its terminator"
        buf = self._buffer
        buf.seek(self.bytes_read)
        data = buf.readline()
        while not data.endswith(SYM_CRLF):
            # there's more data in the socket that we need
            self._read_from_socket()
            buf.seek(self.bytes_read)
            data = buf.readline()

        self.bytes_read += len(data)

        # purge the buffer when we've consumed it all so it doesn't
        # grow forever
        if self.bytes_read == self.bytes_written:
            self.purge()

        return data[:-2]

    def purge(self):
        "Discard all buffered data and reset both counters"
        self._buffer.seek(0)
        self._buffer.truncate()
        self.bytes_written = 0
        self.bytes_read = 0

    def close(self):
        "Release the buffer and drop the socket reference (best effort)"
        try:
            self.purge()
            self._buffer.close()
        except Exception:
            # issue #633 suggests the purge/close somehow raised a
            # BadFileDescriptor error. Perhaps the client ran out of
            # memory or something else? It's probably OK to ignore
            # any error being raised from purge/close since we're
            # removing the reference to the instance below.
            # Only Exception is swallowed so that KeyboardInterrupt and
            # SystemExit still propagate to the caller.
            pass
        finally:
            self._buffer = None
            self._sock = None
|
||
|
|
|
||
|
|
|
||
|
|
class PythonParser(BaseParser):
    "Plain Python parsing class"
    def __init__(self, socket_read_size):
        self.socket_read_size = socket_read_size
        self.encoder = None
        self._sock = None
        self._buffer = None

    def __del__(self):
        # best-effort cleanup; never let errors escape from a finalizer
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Called when the socket connects"
        self._sock = connection._sock
        self._buffer = SocketBuffer(self._sock, self.socket_read_size)
        self.encoder = connection.encoder

    def on_disconnect(self):
        "Called when the socket disconnects"
        if self._sock is not None:
            self._sock.close()
            self._sock = None
        if self._buffer is not None:
            self._buffer.close()
            self._buffer = None
        self.encoder = None

    def can_read(self):
        "Return True when unconsumed data is already buffered"
        # always answer with a real bool, even when there is no buffer
        return self._buffer is not None and bool(self._buffer.length)

    def read_response(self):
        """
        Read and parse one reply from the buffer.

        Returns the parsed value: a status or bulk string (decoded through
        the connection's encoder), an integer, a list for multi-bulk
        replies, or None for a -1 length. Error replies are returned as
        exception instances, except ConnectionError which is raised.
        Raises ConnectionError when the parser is not connected or the
        reply line is empty, and InvalidResponse on an unknown type byte.
        """
        # match HiredisParser: a parser without a buffer is not connected
        if self._buffer is None:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        response = self._buffer.readline()
        if not response:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # the first byte of the line identifies the reply type
        byte, response = byte_to_chr(response[0]), response[1:]

        if byte not in ('-', '+', ':', '$', '*'):
            raise InvalidResponse("Protocol Error: %s, %s" %
                                  (str(byte), str(response)))

        # server returned an error
        if byte == '-':
            response = nativestr(response)
            error = self.parse_error(response)
            # if the error is a ConnectionError, raise immediately so the user
            # is notified
            if isinstance(error, ConnectionError):
                raise error
            # otherwise, we're dealing with a ResponseError that might belong
            # inside a pipeline response. the connection's read_response()
            # and/or the pipeline's execute() will raise this error if
            # necessary, so just return the exception instance here.
            return error
        # single value
        elif byte == '+':
            pass
        # int value
        elif byte == ':':
            response = long(response)
        # bulk response
        elif byte == '$':
            length = int(response)
            if length == -1:
                return None
            response = self._buffer.read(length)
        # multi-bulk response
        elif byte == '*':
            length = int(response)
            if length == -1:
                return None
            response = [self.read_response() for _ in xrange(length)]
        if isinstance(response, bytes):
            response = self.encoder.decode(response)
        return response
|
||
|
|
|
||
|
|
|
||
|
|
class HiredisParser(BaseParser):
    "Parser class for connections using Hiredis"
    def __init__(self, socket_read_size):
        if not HIREDIS_AVAILABLE:
            raise RedisError("Hiredis is not installed")
        self.socket_read_size = socket_read_size

        # Start in the same state on_disconnect() leaves behind, so that
        # can_read()/read_response() raise ConnectionError instead of
        # AttributeError when called before the first on_connect().
        self._sock = None
        self._reader = None
        self._next_response = False

        if HIREDIS_USE_BYTE_BUFFER:
            # reusable receive buffer filled by recv_into()
            self._buffer = bytearray(socket_read_size)

    def __del__(self):
        # best-effort cleanup; never let errors escape from a finalizer
        try:
            self.on_disconnect()
        except Exception:
            pass

    def on_connect(self, connection):
        "Called when the socket connects; creates a fresh hiredis reader"
        self._sock = connection._sock
        kwargs = {
            'protocolError': InvalidResponse,
            'replyError': self.parse_error,
        }

        # hiredis < 0.1.3 doesn't support functions that create exceptions
        if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
            kwargs['replyError'] = ResponseError

        if connection.encoder.decode_responses:
            kwargs['encoding'] = connection.encoder.encoding
        self._reader = hiredis.Reader(**kwargs)
        self._next_response = False

    def on_disconnect(self):
        "Called when the socket disconnects; drops the reader and socket"
        self._sock = None
        self._reader = None
        self._next_response = False

    def can_read(self):
        """
        Return True when a complete reply is already available.

        The reply is cached in ``_next_response`` for the following
        read_response() call. Raises ConnectionError when not connected.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        if self._next_response is False:
            self._next_response = self._reader.gets()
        return self._next_response is not False

    def read_response(self):
        """
        Return the next reply, reading from the socket until one is complete.

        Raises ConnectionError when not connected, when the socket fails or
        is closed by the server, or when the reply itself is a
        ConnectionError; raises TimeoutError on a socket timeout.
        """
        if not self._reader:
            raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

        # _next_response might be cached from a can_read() call
        if self._next_response is not False:
            response = self._next_response
            self._next_response = False
            return response

        response = self._reader.gets()
        socket_read_size = self.socket_read_size
        # gets() yields False until a full reply has been fed to the reader
        while response is False:
            try:
                if HIREDIS_USE_BYTE_BUFFER:
                    bufflen = recv_into(self._sock, self._buffer)
                    if bufflen == 0:
                        raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
                else:
                    chunk = recv(self._sock, socket_read_size)
                    # an empty string indicates the server shutdown the socket
                    if not isinstance(chunk, bytes) or len(chunk) == 0:
                        raise socket.error(SERVER_CLOSED_CONNECTION_ERROR)
            except socket.timeout:
                raise TimeoutError("Timeout reading from socket")
            except socket.error:
                e = sys.exc_info()[1]
                raise ConnectionError("Error while reading from socket: %s" %
                                      (e.args,))
            if HIREDIS_USE_BYTE_BUFFER:
                self._reader.feed(self._buffer, 0, bufflen)
            else:
                self._reader.feed(chunk)
            response = self._reader.gets()
        # if an older version of hiredis is installed, we need to attempt
        # to convert ResponseErrors to their appropriate types.
        if not HIREDIS_SUPPORTS_CALLABLE_ERRORS:
            if isinstance(response, ResponseError):
                response = self.parse_error(response.args[0])
            elif isinstance(response, list) and response and \
                    isinstance(response[0], ResponseError):
                response[0] = self.parse_error(response[0].args[0])
        # if the response is a ConnectionError or the response is a list and
        # the first item is a ConnectionError, raise it as something bad
        # happened
        if isinstance(response, ConnectionError):
            raise response
        elif isinstance(response, list) and response and \
                isinstance(response[0], ConnectionError):
            raise response[0]
        return response
|
||
|
|
|
||
|
|
|
||
|
|
# Use the hiredis-backed parser whenever hiredis is importable; otherwise
# fall back to the pure-Python implementation.
DefaultParser = HiredisParser if HIREDIS_AVAILABLE else PythonParser
|
||
|
|
|
||
|
|
|
||
|
|
class Connection(object):
    "Manages TCP communication to and from a Redis server"
    description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"

    def __init__(self, host='localhost', port=6379, db=0, password=None,
                 socket_timeout=None, socket_connect_timeout=None,
                 socket_keepalive=False, socket_keepalive_options=None,
                 retry_on_timeout=False, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 parser_class=DefaultParser, socket_read_size=65536):
        """
        Store the connection settings; no socket is opened here.

        ``socket_connect_timeout`` falls back to ``socket_timeout`` when it
        is not given. ``parser_class`` is instantiated with
        ``socket_read_size`` and is notified on connect and disconnect.
        """
        self.pid = os.getpid()
        self.host = host
        self.port = int(port)
        self.db = db
        self.password = password
        self.socket_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout or socket_timeout
        self.socket_keepalive = socket_keepalive
        self.socket_keepalive_options = socket_keepalive_options or {}
        self.retry_on_timeout = retry_on_timeout
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        # values interpolated into description_format by __repr__
        self._description_args = {
            'host': self.host,
            'port': self.port,
            'db': self.db,
        }
        self._connect_callbacks = []

    def __repr__(self):
        return self.description_format % self._description_args

    def __del__(self):
        # best-effort cleanup; never let errors escape from a finalizer
        try:
            self.disconnect()
        except Exception:
            pass

    def register_connect_callback(self, callback):
        "Register ``callback(connection)`` to run after each connect"
        self._connect_callbacks.append(callback)

    def clear_connect_callbacks(self):
        "Forget every registered connect callback"
        self._connect_callbacks = []

    def connect(self):
        "Connects to the Redis server if not already connected"
        if self._sock:
            return
        try:
            sock = self._connect()
        except socket.timeout:
            raise TimeoutError("Timeout connecting to server")
        except socket.error:
            e = sys.exc_info()[1]
            raise ConnectionError(self._error_message(e))

        self._sock = sock
        try:
            self.on_connect()
        except RedisError:
            # clean up after any error in on_connect
            self.disconnect()
            raise

        # run any user callbacks. right now the only internal callback
        # is for pubsub channel/pattern resubscription
        for callback in self._connect_callbacks:
            callback(self)

    def _connect(self):
        "Create a TCP socket connection"
        # we want to mimic what socket.create_connection does to support
        # ipv4/ipv6, but we want to set options prior to calling
        # socket.connect()
        err = None
        for res in socket.getaddrinfo(self.host, self.port, 0,
                                      socket.SOCK_STREAM):
            family, socktype, proto, canonname, socket_address = res
            sock = None
            try:
                sock = socket.socket(family, socktype, proto)
                # TCP_NODELAY
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

                # TCP_KEEPALIVE
                if self.socket_keepalive:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                    for k, v in iteritems(self.socket_keepalive_options):
                        sock.setsockopt(socket.SOL_TCP, k, v)

                # set the socket_connect_timeout before we connect
                sock.settimeout(self.socket_connect_timeout)

                # connect
                sock.connect(socket_address)

                # set the socket_timeout now that we're connected
                sock.settimeout(self.socket_timeout)
                return sock

            except socket.error as exc:
                # remember the failure and try the next resolved address
                err = exc
                if sock is not None:
                    sock.close()

        if err is not None:
            raise err
        raise socket.error("socket.getaddrinfo returned an empty list")

    def _error_message(self, exception):
        "Build the ConnectionError text for a failed connect attempt"
        # args for socket.error can either be (errno, "message")
        # or just "message"
        if len(exception.args) == 1:
            return "Error connecting to %s:%s. %s." % \
                (self.host, self.port, exception.args[0])
        else:
            return "Error %s connecting to %s:%s. %s." % \
                (exception.args[0], self.host, self.port, exception.args[1])

    def on_connect(self):
        "Initialize the connection, authenticate and select a database"
        self._parser.on_connect(self)

        # if a password is specified, authenticate
        if self.password:
            self.send_command('AUTH', self.password)
            if nativestr(self.read_response()) != 'OK':
                raise AuthenticationError('Invalid Password')

        # if a database is specified, switch to it
        if self.db:
            self.send_command('SELECT', self.db)
            if nativestr(self.read_response()) != 'OK':
                raise ConnectionError('Invalid Database')

    def disconnect(self):
        "Disconnects from the Redis server"
        self._parser.on_disconnect()
        if self._sock is None:
            return
        # shutdown() and close() are attempted independently: a failing
        # shutdown (e.g. the peer already went away) must not prevent the
        # socket from being closed, otherwise the descriptor leaks.
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            pass
        try:
            self._sock.close()
        except socket.error:
            pass
        self._sock = None

    def send_packed_command(self, command):
        """
        Send an already packed command to the Redis server.

        ``command`` is either a single packed byte string or an iterable of
        them. Connects first if needed. On any failure the connection is
        dropped; socket timeouts become TimeoutError and other socket
        errors become ConnectionError.
        """
        if not self._sock:
            self.connect()
        try:
            # a lone packed payload is wrapped so it is sent in one piece;
            # ``bytes`` covers Python 3, where packed commands are not str
            if isinstance(command, (bytes, str)):
                command = [command]
            for item in command:
                self._sock.sendall(item)
        except socket.timeout:
            self.disconnect()
            raise TimeoutError("Timeout writing to socket")
        except socket.error:
            e = sys.exc_info()[1]
            self.disconnect()
            if len(e.args) == 1:
                errno, errmsg = 'UNKNOWN', e.args[0]
            else:
                errno = e.args[0]
                errmsg = e.args[1]
            raise ConnectionError("Error %s while writing to socket. %s." %
                                  (errno, errmsg))
        except:
            # anything else (including interrupts) leaves the stream in an
            # unknown state, so drop the connection before re-raising
            self.disconnect()
            raise

    def send_command(self, *args):
        "Pack and send a command to the Redis server"
        self.send_packed_command(self.pack_command(*args))

    def can_read(self, timeout=0):
        "Poll the socket to see if there's data that can be read."
        sock = self._sock
        if not sock:
            self.connect()
            sock = self._sock
        # data already held by the parser counts as readable too
        return self._parser.can_read() or \
            bool(select([sock], [], [], timeout)[0])

    def read_response(self):
        "Read the response from a previously sent command"
        try:
            response = self._parser.read_response()
        except:
            # a failed read leaves the stream in an unknown state, so drop
            # the connection before re-raising
            self.disconnect()
            raise
        if isinstance(response, ResponseError):
            raise response
        return response

    def pack_command(self, *args):
        "Pack a series of arguments into the Redis protocol"
        output = []
        # the client might have included 1 or more literal arguments in
        # the command name, e.g., 'CONFIG GET'. The Redis server expects these
        # arguments to be sent separately, so split the first argument
        # manually. All of these arguments get wrapped in the Token class
        # to prevent them from being encoded.
        command = args[0]
        if ' ' in command:
            args = tuple([Token.get_token(s)
                          for s in command.split()]) + args[1:]
        else:
            args = (Token.get_token(command),) + args[1:]

        buff = SYM_EMPTY.join(
            (SYM_STAR, b(str(len(args))), SYM_CRLF))

        for arg in imap(self.encoder.encode, args):
            # to avoid large string mallocs, chunk the command into the
            # output list if we're sending large values
            if len(buff) > 6000 or len(arg) > 6000:
                buff = SYM_EMPTY.join(
                    (buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF))
                output.append(buff)
                output.append(arg)
                buff = SYM_CRLF
            else:
                buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))),
                                       SYM_CRLF, arg, SYM_CRLF))
        output.append(buff)
        return output

    def pack_commands(self, commands):
        "Pack multiple commands into the Redis protocol"
        output = []
        pieces = []
        buffer_length = 0

        for cmd in commands:
            for chunk in self.pack_command(*cmd):
                pieces.append(chunk)
                buffer_length += len(chunk)

            # flush the accumulated pieces once they grow past 6000 bytes
            if buffer_length > 6000:
                output.append(SYM_EMPTY.join(pieces))
                buffer_length = 0
                pieces = []

        if pieces:
            output.append(SYM_EMPTY.join(pieces))
        return output
|
||
|
|
|
||
|
|
|
||
|
|
class SSLConnection(Connection):
    "TCP connection to a Redis server wrapped in SSL/TLS"
    description_format = "SSLConnection<host=%(host)s,port=%(port)s,db=%(db)s>"

    def __init__(self, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
                 ssl_ca_certs=None, **kwargs):
        """
        ``ssl_cert_reqs`` may be an ``ssl.CERT_*`` constant or one of the
        strings 'none', 'optional' or 'required'. Remaining keyword
        arguments are passed to Connection. Raises RedisError when the ssl
        module is unavailable or the requirements string is not recognised.
        """
        if not ssl_available:
            raise RedisError("Python wasn't built with SSL support")

        super(SSLConnection, self).__init__(**kwargs)

        self.keyfile = ssl_keyfile
        self.certfile = ssl_certfile
        if ssl_cert_reqs is None:
            # NOTE: when nothing is specified, ssl.CERT_NONE is used
            ssl_cert_reqs = ssl.CERT_NONE
        elif isinstance(ssl_cert_reqs, basestring):
            CERT_REQS = {
                'none': ssl.CERT_NONE,
                'optional': ssl.CERT_OPTIONAL,
                'required': ssl.CERT_REQUIRED
            }
            if ssl_cert_reqs not in CERT_REQS:
                raise RedisError(
                    "Invalid SSL Certificate Requirements Flag: %s" %
                    ssl_cert_reqs)
            ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
        self.cert_reqs = ssl_cert_reqs
        self.ca_certs = ssl_ca_certs

    def _connect(self):
        "Wrap the socket with SSL support"
        sock = super(SSLConnection, self)._connect()
        try:
            if hasattr(ssl, 'wrap_socket'):
                return ssl.wrap_socket(sock,
                                       cert_reqs=self.cert_reqs,
                                       keyfile=self.keyfile,
                                       certfile=self.certfile,
                                       ca_certs=self.ca_certs)
            # ssl.wrap_socket is absent on newer interpreters; build an
            # equivalent context by hand. check_hostname must be disabled
            # before verify_mode may be relaxed.
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            context.check_hostname = False
            context.verify_mode = self.cert_reqs
            if self.certfile:
                context.load_cert_chain(self.certfile, self.keyfile)
            if self.ca_certs:
                context.load_verify_locations(self.ca_certs)
            return context.wrap_socket(sock)
        except Exception:
            # the TCP socket is already connected; don't leak it when the
            # TLS setup or handshake fails
            sock.close()
            raise
|
||
|
|
|
||
|
|
|
||
|
|
class UnixDomainSocketConnection(Connection):
    "Connection to a Redis server over a Unix domain socket"
    description_format = "UnixDomainSocketConnection<path=%(path)s,db=%(db)s>"

    def __init__(self, path='', db=0, password=None,
                 socket_timeout=None, encoding='utf-8',
                 encoding_errors='strict', decode_responses=False,
                 retry_on_timeout=False,
                 parser_class=DefaultParser, socket_read_size=65536):
        """
        Store the connection settings; no socket is opened here.

        ``path`` is the filesystem path of the socket. ``parser_class`` is
        instantiated with ``socket_read_size``.
        """
        self.pid = os.getpid()
        self.path = path
        self.db = db
        self.password = password
        self.socket_timeout = socket_timeout
        self.retry_on_timeout = retry_on_timeout
        self.encoder = Encoder(encoding, encoding_errors, decode_responses)
        self._sock = None
        self._parser = parser_class(socket_read_size=socket_read_size)
        # values interpolated into description_format by __repr__
        self._description_args = {
            'path': self.path,
            'db': self.db,
        }
        self._connect_callbacks = []

    def _connect(self):
        "Create a Unix domain socket connection"
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.socket_timeout)
            sock.connect(self.path)
        except Exception:
            # don't leak the descriptor when the connect attempt fails
            sock.close()
            raise
        return sock

    def _error_message(self, exception):
        "Build the ConnectionError text for a failed connect attempt"
        # args for socket.error can either be (errno, "message")
        # or just "message"
        if len(exception.args) == 1:
            return "Error connecting to unix socket: %s. %s." % \
                (self.path, exception.args[0])
        else:
            return "Error %s connecting to unix socket: %s. %s." % \
                (exception.args[0], self.path, exception.args[1])
|
||
|
|
|
||
|
|
|
||
|
|
# Upper-cased spellings that to_bool() treats as False
FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO')


def to_bool(value):
    """
    Interpret ``value`` as a tri-state flag.

    None and the empty string yield None, strings matching FALSE_STRINGS
    (case-insensitively) yield False, anything else yields ``bool(value)``.
    """
    if value is None or value == '':
        return None
    spelled_false = (isinstance(value, basestring) and
                     value.upper() in FALSE_STRINGS)
    return False if spelled_false else bool(value)
|
||
|
|
|
||
|
|
|
||
|
|
# Converters applied by ConnectionPool.from_url to querystring values,
# keyed by argument name. Arguments not listed here are passed through
# as plain strings.
URL_QUERY_ARGUMENT_PARSERS = {
    'socket_timeout': float,
    'socket_connect_timeout': float,
    'socket_keepalive': to_bool,
    'retry_on_timeout': to_bool
}
|
||
|
|
|
||
|
|
|
||
|
|
class ConnectionPool(object):
|
||
|
|
"Generic connection pool"
|
||
|
|
@classmethod
def from_url(cls, url, db=None, decode_components=False, **kwargs):
    """
    Return a connection pool configured from the given URL.

    For example::

        redis://[:password]@localhost:6379/0
        rediss://[:password]@localhost:6379/0
        unix://[:password]@/path/to/socket.sock?db=0

    Three URL schemes are supported:

    - ```redis://``
      <http://www.iana.org/assignments/uri-schemes/prov/redis>`_ creates a
      normal TCP socket connection
    - ```rediss://``
      <http://www.iana.org/assignments/uri-schemes/prov/rediss>`_ creates a
      SSL wrapped TCP socket connection
    - ``unix://`` creates a Unix Domain Socket connection

    There are several ways to specify a database number. The parse function
    will return the first specified option:
        1. A ``db`` querystring option, e.g. redis://localhost?db=0
        2. If using the redis:// scheme, the path argument of the url, e.g.
           redis://localhost/0
        3. The ``db`` argument to this function.

    If none of these options are specified, db=0 is used.

    The ``decode_components`` argument allows this function to work with
    percent-encoded URLs. If this argument is set to ``True`` all ``%xx``
    escapes will be replaced by their single-character equivalents after
    the URL has been parsed. This only applies to the ``hostname``,
    ``path``, and ``password`` components.

    Any additional querystring arguments and keyword arguments will be
    passed along to the ConnectionPool class's initializer. The querystring
    arguments ``socket_connect_timeout`` and ``socket_timeout`` if supplied
    are parsed as float values. The arguments ``socket_keepalive`` and
    ``retry_on_timeout`` are parsed to boolean values that accept
    True/False, Yes/No values to indicate state. Invalid values cause a
    ``UserWarning`` to be issued and the argument to be dropped. In the
    case of conflicting arguments, querystring arguments always win.
    """
    url_string = url
    url = urlparse(url)
    qs = ''

    # in python2.6, custom URL schemes don't recognize querystring values
    # they're left as part of the url.path.
    if '?' in url.path and not url.query:
        # chop the querystring including the ? off the end of the url
        # and reparse it.
        qs = url.path.split('?', 1)[1]
        url = urlparse(url_string[:-(len(qs) + 1)])
    else:
        qs = url.query

    url_options = {}

    # only the first value of a repeated querystring argument is used
    for name, value in iteritems(parse_qs(qs)):
        if value and len(value) > 0:
            parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
            if parser:
                try:
                    url_options[name] = parser(value[0])
                except (TypeError, ValueError):
                    warnings.warn(UserWarning(
                        "Invalid value for `%s` in connection URL." % name
                    ))
            else:
                url_options[name] = value[0]

    if decode_components:
        password = unquote(url.password) if url.password else None
        path = unquote(url.path) if url.path else None
        hostname = unquote(url.hostname) if url.hostname else None
    else:
        password = url.password
        path = url.path
        hostname = url.hostname

    # unix:// selects a Unix domain socket connection; every other scheme
    # is treated as TCP, with rediss:// switching to SSLConnection below.
    if url.scheme == 'unix':
        url_options.update({
            'password': password,
            'path': path,
            'connection_class': UnixDomainSocketConnection,
        })

    else:
        url_options.update({
            'host': hostname,
            'port': int(url.port or 6379),
            'password': password,
        })

        # If there's a path argument, use it as the db argument if a
        # querystring value wasn't specified
        if 'db' not in url_options and path:
            try:
                url_options['db'] = int(path.replace('/', ''))
            except (AttributeError, ValueError):
                pass

        if url.scheme == 'rediss':
            url_options['connection_class'] = SSLConnection

    # last shot at the db value
    url_options['db'] = int(url_options.get('db', db or 0))

    # update the arguments from the URL values
    kwargs.update(url_options)

    # backwards compatibility
    if 'charset' in kwargs:
        warnings.warn(DeprecationWarning(
            '"charset" is deprecated. Use "encoding" instead'))
        kwargs['encoding'] = kwargs.pop('charset')
    if 'errors' in kwargs:
        warnings.warn(DeprecationWarning(
            '"errors" is deprecated. Use "encoding_errors" instead'))
        kwargs['encoding_errors'] = kwargs.pop('errors')

    return cls(**kwargs)
|
||
|
|
|
||
|
|
def __init__(self, connection_class=Connection, max_connections=None,
             **connection_kwargs):
    """
    Set up a pool that hands out ``connection_class`` instances.

    ``max_connections`` caps how many connections the pool will create;
    once the cap is reached the pool raises redis.ConnectionError instead
    of creating another one. A falsy value (``None`` or ``0``) selects
    the default cap of ``2 ** 31``.

    ``connection_class`` defaults to a TCP connection. Pass
    redis.UnixDomainSocketConnection to talk over a unix socket.

    Every remaining keyword argument is stored and forwarded verbatim to
    the ``connection_class`` constructor.

    Raises ValueError when ``max_connections`` is not an integer or is
    negative.
    """
    limit = max_connections if max_connections else 2 ** 31
    is_integer = isinstance(limit, (int, long))
    # the comparison is only evaluated for integers (short-circuit)
    if not is_integer or limit < 0:
        raise ValueError('"max_connections" must be a positive integer')

    self.connection_class = connection_class
    self.connection_kwargs = connection_kwargs
    self.max_connections = limit

    # start with empty bookkeeping for the current process
    self.reset()
|
||
|
|
|
||
|
|
def __repr__(self):
|
||
|
|
return "%s<%s>" % (
|
||
|
|
type(self).__name__,
|
||
|
|
self.connection_class.description_format % self.connection_kwargs,
|
||
|
|
)
|
||
|
|
|
||
|
|
def reset(self):
    "Discard all bookkeeping and bind the pool to the current process"
    # Remember which process owns the pool so _checkpid() can notice
    # when the pool is being used from a different one.
    self.pid = os.getpid()

    # Nothing created, nothing idle, nothing checked out.
    self._created_connections = 0
    self._available_connections = list()
    self._in_use_connections = set()

    # Serialises the disconnect/reset sequence performed by _checkpid().
    self._check_lock = threading.Lock()
|
||
|
|
|
||
|
|
def _checkpid(self):
|
||
|
|
if self.pid != os.getpid():
|
||
|
|
with self._check_lock:
|
||
|
|
if self.pid == os.getpid():
|
||
|
|
# another thread already did the work while we waited
|
||
|
|
# on the lock.
|
||
|
|
return
|
||
|
|
self.disconnect()
|
||
|
|
self.reset()
|
||
|
|
|
||
|
|
def get_connection(self, command_name, *keys, **options):
    """
    Check a connection out of the pool.

    An idle connection is reused when one exists (most recently released
    first); otherwise a new one is built with ``make_connection()``.
    ``command_name``, ``keys`` and ``options`` are accepted but not used
    by this implementation.
    """
    self._checkpid()
    idle = self._available_connections
    try:
        # pop() both tests for and takes an idle connection in one step
        conn = idle.pop()
    except IndexError:
        conn = self.make_connection()
    self._in_use_connections.add(conn)
    return conn
|
||
|
|
|
||
|
|
def get_encoder(self):
    "Build an Encoder from the pool's encoding-related connection kwargs"
    settings = self.connection_kwargs
    # Fall back to the same defaults for any option that was not supplied
    encoding = settings.get('encoding', 'utf-8')
    encoding_errors = settings.get('encoding_errors', 'strict')
    decode_responses = settings.get('decode_responses', False)
    return Encoder(
        encoding=encoding,
        encoding_errors=encoding_errors,
        decode_responses=decode_responses
    )
|
||
|
|
|
||
|
|
def make_connection(self):
    """
    Create a new connection and count it against ``max_connections``.

    Returns a fresh ``connection_class`` instance built from the stored
    connection kwargs.

    Raises ConnectionError("Too many connections") when the pool has
    already created ``max_connections`` connections. Any exception raised
    by the ``connection_class`` constructor propagates unchanged.
    """
    if self._created_connections >= self.max_connections:
        raise ConnectionError("Too many connections")

    # Claim the slot before constructing (same ordering as before) ...
    self._created_connections += 1
    try:
        return self.connection_class(**self.connection_kwargs)
    except BaseException:
        # ... but give it back if no connection was actually produced,
        # otherwise repeated constructor failures would exhaust the pool
        # without a single live connection.
        self._created_connections -= 1
        raise
|
||
|
|
|
||
|
|
def release(self, connection):
    """
    Hand a checked-out connection back to the pool.

    A connection whose ``pid`` differs from the pool's is ignored: it is
    neither removed from the in-use set nor made available again.
    """
    self._checkpid()
    if connection.pid == self.pid:
        # move it from the in-use set to the idle list
        self._in_use_connections.remove(connection)
        self._available_connections.append(connection)
|
||
|
|
|
||
|
|
def disconnect(self):
    "Disconnect every connection the pool tracks, idle ones first"
    groups = (self._available_connections, self._in_use_connections)
    for group in groups:
        for member in group:
            member.disconnect()
|
||
|
|
|
||
|
|
|
||
|
|
class BlockingConnectionPool(ConnectionPool):
    """
    Thread-safe blocking connection pool::

        >>> from redis.client import Redis
        >>> client = Redis(connection_pool=BlockingConnectionPool())

    It performs the same function as the default
    :py:class:`~redis.connection.ConnectionPool` implementation, in that
    it maintains a pool of reusable connections that can be shared by
    multiple redis clients (safely across threads if required).

    The difference is that, in the event that a client tries to get a
    connection from the pool when all connections are in use, rather than
    raising a :py:class:`~redis.exceptions.ConnectionError` (as the default
    :py:class:`~redis.connection.ConnectionPool` implementation does), it
    makes the client wait ("blocks") for a specified number of seconds until
    a connection becomes available.

    Use ``max_connections`` to increase / decrease the pool size::

        >>> pool = BlockingConnectionPool(max_connections=10)

    Use ``timeout`` to tell it either how many seconds to wait for a
    connection to become available, or to block forever::

        # Block forever.
        >>> pool = BlockingConnectionPool(timeout=None)

        # Raise a ``ConnectionError`` after five seconds if a connection is
        # not available.
        >>> pool = BlockingConnectionPool(timeout=5)
    """
    def __init__(self, max_connections=50, timeout=20,
                 connection_class=Connection, queue_class=LifoQueue,
                 **connection_kwargs):
        """
        Create a blocking pool.

        ``max_connections`` is the queue size, ``timeout`` the number of
        seconds ``get_connection`` waits (``None`` waits forever) and
        ``queue_class`` the queue implementation holding the slots. All
        other keyword arguments are forwarded to ``connection_class``.
        """
        # These must be set before the parent constructor runs, because
        # it calls reset(), which reads self.queue_class.
        self.queue_class = queue_class
        self.timeout = timeout
        super(BlockingConnectionPool, self).__init__(
            connection_class=connection_class,
            max_connections=max_connections,
            **connection_kwargs)

    def reset(self):
        "Rebuild the slot queue and bind the pool to the current process"
        self.pid = os.getpid()
        self._check_lock = threading.Lock()

        # Create and fill up a thread safe queue with ``None`` values.
        # Each ``None`` is a placeholder slot for a connection that has
        # not been created yet.
        self.pool = self.queue_class(self.max_connections)
        while True:
            try:
                self.pool.put_nowait(None)
            except Full:
                break

        # Keep a list of actual connection instances so that we can
        # disconnect them later.
        self._connections = []

    def make_connection(self):
        "Make a fresh connection and remember it for disconnect()."
        connection = self.connection_class(**self.connection_kwargs)
        self._connections.append(connection)
        return connection

    def get_connection(self, command_name, *keys, **options):
        """
        Get a connection, blocking for ``self.timeout`` until a connection
        is available from the pool.

        If the connection returned is ``None`` then creates a new connection.
        Because we use a last-in first-out queue, the existing connections
        (having been returned to the pool after the initial ``None`` values
        were added) will be returned before ``None`` values. This means we
        only create new connections when we need to, i.e.: the actual number
        of connections will only increase in response to demand.

        Raises ``ConnectionError`` if no slot becomes free within
        ``self.timeout`` seconds.
        """
        # Make sure we haven't changed process.
        self._checkpid()

        # Try and get a connection from the pool. If one isn't available
        # within self.timeout then raise a ``ConnectionError``.
        connection = None
        try:
            connection = self.pool.get(block=True, timeout=self.timeout)
        except Empty:
            # Note that this is not caught by the redis client and will be
            # raised unless handled by application code. If you never want
            # this error, create the pool with ``timeout=None`` so that the
            # call blocks until a connection is released.
            raise ConnectionError("No connection available.")

        # If the ``connection`` is actually ``None`` then that's a cue to
        # make a new connection to add to the pool.
        if connection is None:
            try:
                connection = self.make_connection()
            finally:
                if connection is None:
                    # Construction failed: hand the placeholder slot back,
                    # otherwise every failure would permanently shrink the
                    # pool by one.
                    try:
                        self.pool.put_nowait(None)
                    except Full:
                        # the queue was rebuilt in the meantime; nothing
                        # to give back
                        pass

        return connection

    def release(self, connection):
        "Releases the connection back to the pool."
        # Make sure we haven't changed process.
        self._checkpid()
        if connection.pid != self.pid:
            return

        # Put the connection back into the pool.
        try:
            self.pool.put_nowait(connection)
        except Full:
            # perhaps the pool has been reset() after a fork? regardless,
            # we don't want this connection
            pass

    def disconnect(self):
        "Disconnects all connections in the pool."
        for connection in self._connections:
            connection.disconnect()
|