# vim: sw=4:expandtab:foldmethod=marker # # Copyright (c) 2007-2009, Mathieu Fenniak # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright notice, # this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # * The name of the author may not be used to endorse or promote products # derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. __author__ = "Mathieu Fenniak" import socket try: import ssl as sslmodule except ImportError: sslmodule = None import select import threading import struct import hashlib from cStringIO import StringIO from errors import * from util import MulticastDelegate import types ## # An SSLRequest message. To initiate an SSL-encrypted connection, an # SSLRequest message is used rather than a {@link StartupMessage # StartupMessage}. 
A StartupMessage is still sent, but only after SSL # negotiation (if accepted). #

# Stability: This is an internal class. No stability guarantee is made.
class SSLRequest(object):
    # Wire layout (no leading type byte is written for this message):
    #   Int32(8)        - Message length, including self.
    #   Int32(80877103) - The SSL request code.
    _LENGTH = 8
    _SSL_REQUEST_CODE = 80877103

    def __init__(self):
        pass

    def serialize(self):
        """Return the packed 8-byte SSLRequest message."""
        return struct.pack("!ii", self._LENGTH, self._SSL_REQUEST_CODE)


##
# A StartupMessage message. Begins a DB session, identifying the user to be
# authenticated as and the database to connect to.
#

# Stability: This is an internal class. No stability guarantee is made.
class StartupMessage(object):
    # Int32         - Message length, including self.
    # Int32(196608) - Protocol version number. Version 3.0.
    # Any number of key/value pairs, terminated by a zero byte:
    #   String - A parameter name (user, database, or options)
    #   String - Parameter value
    def __init__(self, user, database=None):
        self.user = user
        self.database = database

    def serialize(self):
        """Return the packed startup packet for this user/database."""
        # 196608 == 3 << 16: major protocol version 3, minor version 0.
        pieces = [struct.pack("!i", 196608), "user\x00", self.user, "\x00"]
        # The database pair is omitted entirely when database is empty/None.
        if self.database:
            pieces.extend(("database\x00", self.database, "\x00"))
        # A lone zero byte terminates the key/value list.
        pieces.append("\x00")
        payload = "".join(pieces)
        return struct.pack("!i", len(payload) + 4) + payload


##
# Parse message. Creates a prepared statement in the DB session.
#

# Stability: This is an internal class. No stability guarantee is made.
#
# @param ps Name of the prepared statement to create.
# @param qs Query string (already-encoded byte data; unicode is rejected).
# @param type_oids A sized iterable (len() is taken) of the PostgreSQL type
#                  OIDs for parameters in the query string.
class Parse(object):
    def __init__(self, ps, qs, type_oids):
        # Unicode query text is refused; callers must encode it first.
        if isinstance(qs, unicode):
            raise TypeError("qs must be encoded byte data")
        self.ps = ps
        self.qs = qs
        self.type_oids = type_oids

    def __repr__(self):
        # A real format string: an empty one made repr() raise TypeError.
        return "<Parse ps=%r qs=%r>" % (self.ps, self.qs)

    # Byte1('P') - Identifies the message as a Parse command.
    # Int32      - Message length, including self.
    # String     - Prepared statement name. An empty string selects the
    #              unnamed prepared statement.
    # String     - The query string.
    # Int16      - Number of parameter data types specified (can be zero).
    # For each parameter:
    #   Int32    - The OID of the parameter data type.
    def serialize(self):
        val = self.ps + "\x00" + self.qs + "\x00"
        val = val + struct.pack("!h", len(self.type_oids))
        for oid in self.type_oids:
            # Parse message doesn't seem to handle the -1 type_oid for NULL
            # values that other messages handle. So we'll provide type_oid
            # 705, the PG "unknown" type.
            if oid == -1:
                oid = 705
            val = val + struct.pack("!i", oid)
        val = struct.pack("!i", len(val) + 4) + val
        return "P" + val


##
# Bind message. Readies a prepared statement for execution.
#

# Stability: This is an internal class. No stability guarantee is made.
#
# @param portal Name of the destination portal.
# @param ps Name of the source prepared statement.
# @param in_fc A sized, indexable sequence containing the format codes for
#              input parameters. 0 = Text, 1 = Binary.
# @param params The parameters.
# @param out_fc An iterable containing the format codes for output
#               parameters. 0 = Text, 1 = Binary.
# @param kwargs Additional arguments to pass to the type conversion
#               methods.
class Bind(object):
    def __init__(self, portal, ps, in_fc, params, out_fc, **kwargs):
        self.portal = portal
        self.ps = ps
        self.in_fc = in_fc
        self.params = []
        for i, param in enumerate(params):
            # No input format codes -> every parameter is text (0); a single
            # code applies to all parameters; otherwise one code per param.
            if len(self.in_fc) == 0:
                fc = 0
            elif len(self.in_fc) == 1:
                fc = self.in_fc[0]
            else:
                fc = self.in_fc[i]
            self.params.append(types.pg_value(param, fc, **kwargs))
        self.out_fc = out_fc

    def __repr__(self):
        # A real format string: an empty one made repr() raise TypeError.
        return "<Bind portal=%r ps=%r>" % (self.portal, self.ps)

    # Byte1('B') - Identifies the Bind command.
    # Int32      - Message length, including self.
    # String     - Name of the destination portal.
    # String     - Name of the source prepared statement.
    # Int16      - Number of parameter format codes.
    # For each parameter format code:
    #   Int16    - The parameter format code.
    # Int16      - Number of parameter values.
    # For each parameter value:
    #   Int32    - The length of the parameter value, in bytes, not including
    #              this length. -1 indicates a NULL parameter value, in which
    #              no value bytes follow.
    #   Byte[n]  - Value of the parameter.
    # Int16      - The number of result-column format codes.
    # For each result-column format code:
    #   Int16    - The format code.
    def serialize(self):
        retval = StringIO()
        retval.write(self.portal + "\x00")
        retval.write(self.ps + "\x00")
        retval.write(struct.pack("!h", len(self.in_fc)))
        for fc in self.in_fc:
            retval.write(struct.pack("!h", fc))
        retval.write(struct.pack("!h", len(self.params)))
        for param in self.params:
            if param is None:
                # special case, NULL value: length -1 and no value bytes
                retval.write(struct.pack("!i", -1))
            else:
                retval.write(struct.pack("!i", len(param)))
                retval.write(param)
        retval.write(struct.pack("!h", len(self.out_fc)))
        for fc in self.out_fc:
            retval.write(struct.pack("!h", fc))
        val = retval.getvalue()
        val = struct.pack("!i", len(val) + 4) + val
        return "B" + val


##
# A Close message, used for closing prepared statements and portals.
#

# Stability: This is an internal class. No stability guarantee is made.
#
# @param typ 'S' for prepared statement, 'P' for portal.
# @param name The name of the item to close.
class Close(object):
    def __init__(self, typ, name):
        # typ becomes the single Byte1 field of the message.
        if len(typ) != 1:
            raise InternalError("Close typ must be 1 char")
        self.typ, self.name = typ, name

    # Byte1('C') - Identifies the message as a close command.
    # Int32      - Message length, including self.
    # Byte1      - 'S' for prepared statement, 'P' for portal.
    # String     - The name of the item to close.
    def serialize(self):
        body = self.typ + self.name + "\x00"
        return "C" + struct.pack("!i", len(body) + 4) + body


##
# A specialized Close message for a portal.
#

# Stability: This is an internal class. No stability guarantee is made.
class ClosePortal(Close):
    def __init__(self, name):
        # Portals are selected with the 'P' type byte.
        super(ClosePortal, self).__init__("P", name)


##
# A specialized Close message for a prepared statement.
#

# Stability: This is an internal class. No stability guarantee is made.
class ClosePreparedStatement(Close):
    def __init__(self, name):
        # Prepared statements are selected with the 'S' type byte.
        super(ClosePreparedStatement, self).__init__("S", name)


##
# A Describe message, used for obtaining information on prepared statements
# and portals.
#

# Stability: This is an internal class. No stability guarantee is made.
#
# @param typ 'S' for prepared statement, 'P' for portal.
# @param name The name of the item to describe.
class Describe(object):
    def __init__(self, typ, name):
        # typ becomes the single Byte1 field of the message.
        if len(typ) != 1:
            raise InternalError("Describe typ must be 1 char")
        self.typ, self.name = typ, name

    # Byte1('D') - Identifies the message as a describe command.
    # Int32      - Message length, including self.
    # Byte1      - 'S' for prepared statement, 'P' for portal.
    # String     - The name of the item to describe.
    def serialize(self):
        body = self.typ + self.name + "\x00"
        return "D" + struct.pack("!i", len(body) + 4) + body


##
# A specialized Describe message for a portal.
#

# Stability: This is an internal class. No stability guarantee is made.
class DescribePortal(Describe):
    def __init__(self, name):
        Describe.__init__(self, "P", name)

    def __repr__(self):
        # A real format string: an empty one made repr() raise TypeError.
        return "<DescribePortal name=%r>" % (self.name,)


##
# A specialized Describe message for a prepared statement.
#

# Stability: This is an internal class. No stability guarantee is made.
class DescribePreparedStatement(Describe):
    def __init__(self, name):
        Describe.__init__(self, "S", name)

    def __repr__(self):
        # A real format string: an empty one made repr() raise TypeError.
        return "<DescribePreparedStatement name=%r>" % (self.name,)


##
# A Flush message forces the backend to deliver any data pending in its
# output buffers.
#

# Stability: This is an internal class. No stability guarantee is made.
class Flush(object):
    # Byte1('H') - Identifies the message as a flush command.
    # Int32(4)   - Length of message, including self.
    def serialize(self):
        return 'H\x00\x00\x00\x04'

    def __repr__(self):
        # Identify the message in logs instead of returning an empty string.
        return "<Flush>"


##
# Causes the backend to close the current transaction (if not in a BEGIN/COMMIT
# block), and issue ReadyForQuery.
#

# Stability: This is an internal class. No stability guarantee is made.
class Sync(object):
    # Byte1('S') - Identifies the message as a sync command.
    # Int32(4)   - Length of message, including self.
    def serialize(self):
        return 'S\x00\x00\x00\x04'

    def __repr__(self):
        # Identify the message in logs instead of returning an empty string.
        return "<Sync>"


##
# Transmits a password.
#

# Stability: This is an internal class. No stability guarantee is made.
class PasswordMessage(object):
    # Byte1('p') - Identifies the message as a password message.
    # Int32      - Message length including self.
    # String     - The password. Password may be encrypted.
    def __init__(self, pwd):
        self.pwd = pwd

    def serialize(self):
        body = self.pwd + "\x00"
        return "p" + struct.pack("!i", len(body) + 4) + body


##
# Requests that the backend execute a portal and retrieve any number of rows.
#

# Stability: This is an internal class. No stability guarantee is made.
# @param portal The name of the portal to execute.
# @param row_count The number of rows to return. Can be zero to indicate the
#                  backend should return all rows. If the portal represents a
#                  query that does not return rows, no rows will be returned
#                  no matter what the row_count.
class Execute(object):
    def __init__(self, portal, row_count):
        self.portal = portal
        self.row_count = row_count

    # Byte1('E') - Identifies the message as an execute message.
    # Int32      - Message length, including self.
    # String     - The name of the portal to execute.
    # Int32      - Maximum number of rows to return, if portal contains a query
    #              that returns rows. 0 = no limit.
    def serialize(self):
        val = self.portal + "\x00" + struct.pack("!i", self.row_count)
        val = struct.pack("!i", len(val) + 4) + val
        return "E" + val


class SimpleQuery(object):
    "Requests that the backend execute a Simple Query (SQL string)"

    def __init__(self, query_string):
        self.query_string = query_string

    # Byte1('Q') - Identifies the message as a query message.
    # Int32      - Message length, including self.
    # String     - The query string itself.
    def serialize(self):
        val = self.query_string + "\x00"
        val = struct.pack("!i", len(val) + 4) + val
        return "Q" + val

    def __repr__(self):
        # A real format string: an empty one made repr() raise TypeError.
        return "<SimpleQuery query_string=%r>" % (self.query_string,)


##
# Informs the backend that the connection is being closed.
#

# Stability: This is an internal class. No stability guarantee is made.
class Terminate(object):
    # Byte1('X') - Identifies the message as a terminate message.
    # Int32(4)   - Message length, including self.
    # The message has no payload, so its wire form is a fixed string.
    _WIRE_FORM = 'X\x00\x00\x00\x04'

    def __init__(self):
        pass

    def serialize(self):
        return self._WIRE_FORM


##
# Base class of all Authentication[*] messages.
#

# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationRequest(object):
    # Byte1('R') - Identifies the message as an authentication request.
    # Int32(8)   - Message length, including self.
    # Int32      - An authentication code that represents different
    #              authentication messages:
    #                0 = AuthenticationOk
    #                2 = Kerberos v5 (not supported by pg8000)
    #                3 = Cleartext pwd (not supported by pg8000)
    #                4 = crypt() pwd (not supported by pg8000)
    #                5 = MD5 pwd
    #                6 = SCM credential (not supported by pg8000)
    #                7 = GSSAPI (not supported by pg8000)
    #                8 = GSSAPI data (not supported by pg8000)
    #                9 = SSPI (not supported by pg8000)
    # Some authentication messages have additional data following the
    # authentication code. That data is documented in the appropriate class.
    def __init__(self, data):
        # The base class ignores the payload; subclasses parse what they need.
        pass

    @staticmethod
    def createFromData(data):
        """Build the handler for the authentication code at the start of data.

        Looks the code up in authentication_codes and passes the remaining
        bytes to that class. Raises NotSupportedError for unknown codes.
        """
        code = struct.unpack("!i", data[:4])[0]
        handler_class = authentication_codes.get(code)
        if handler_class is None:
            raise NotSupportedError("authentication method %r not supported" % (code,))
        return handler_class(data[4:])

    def ok(self, conn, user, **kwargs):
        raise InternalError("ok method should be overridden on AuthenticationRequest instance")


##
# A message representing that the backend accepted the provided username
# without any challenge.
#

# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationOk(AuthenticationRequest):
    def ok(self, conn, user, **kwargs):
        """Authentication succeeded; no response needs to be sent."""
        return True


##
# A message representing the backend requesting an MD5 hashed password
# response. The response will be sent as
# "md5" + hex(md5(hex(md5(pwd + login)) + salt)).
#

# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationMD5Password(AuthenticationRequest):
    # Additional message data:
    #   Byte4 - Hash salt.
    def __init__(self, data):
        # Keep the four salt bytes as one 4-character string.
        self.salt = "".join(struct.unpack("4c", data))

    def ok(self, conn, user, password=None, **kwargs):
        """Send the salted MD5 password and wait for the server's verdict."""
        if password is None:
            raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
        inner_digest = hashlib.md5(password + user).hexdigest()
        pwd = "md5" + hashlib.md5(inner_digest + self.salt).hexdigest()
        conn._send(PasswordMessage(pwd))
        conn._flush()

        reader = MessageReader(conn)

        # Any AuthenticationRequest reply is answered by its own ok(), and
        # that result is handed back through the reader.
        def on_auth_reply(msg, rdr):
            return rdr.return_value(msg.ok(conn, user))

        reader.add_message(AuthenticationRequest, on_auth_reply, reader)
        reader.add_message(ErrorResponse, self._ok_error)
        return reader.handle_messages()

    def _ok_error(self, msg):
        # Error code 28000 is reported as a failed password; anything else is
        # converted to the error's own exception.
        if msg.code != "28000":
            raise msg.createException()
        raise InterfaceError("md5 password authentication failed")


# Authentication code -> handler class. Codes missing here are rejected by
# AuthenticationRequest.createFromData with NotSupportedError.
authentication_codes = {
    0: AuthenticationOk,
    5: AuthenticationMD5Password,
}


##
# ParameterStatus message sent from backend, used to inform the frontend of
# runtime configuration parameter changes.
#

# Stability: This is an internal class. No stability guarantee is made.
class ParameterStatus(object):
    # Byte1('S') - Identifies ParameterStatus
    # Int32      - Message length, including self.
    # String     - Runtime parameter name.
    # String     - Runtime parameter value.
    def __init__(self, key, value):
        self.key = key
        self.value = value

    @staticmethod
    def createFromData(data):
        # Payload is name NUL value NUL; the final terminator is dropped.
        sep = data.find("\x00")
        return ParameterStatus(data[:sep], data[sep + 1:-1])


##
# BackendKeyData message sent from backend. Contains a connection's process
# ID and a secret key. Can be used to terminate the connection's current
# actions, such as a long running query. Not supported by pg8000 yet.
#

# Stability: This is an internal class. No stability guarantee is made.
class BackendKeyData(object):
    # Byte1('K') - Identifier.
    # Int32(12)  - Message length, including self.
    # Int32      - Process ID.
    # Int32      - Secret key.
    def __init__(self, process_id, secret_key):
        self.process_id = process_id
        self.secret_key = secret_key

    @staticmethod
    def createFromData(data):
        # The two big-endian Int32s map directly onto the constructor args.
        return BackendKeyData(*struct.unpack("!2i", data))


##
# Message representing a query with no data.
#

# Stability: This is an internal class. No stability guarantee is made.
class NoData(object):
    # Byte1('n') - Identifier.
    # Int32(4)   - Message length, including self.
    @staticmethod
    def createFromData(data):
        """The message carries no payload; data is ignored."""
        return NoData()


##
# Message representing a successful Parse.
#

# Stability: This is an internal class. No stability guarantee is made.
class ParseComplete(object):
    # Byte1('1') - Identifier.
    # Int32(4)   - Message length, including self.
    @staticmethod
    def createFromData(data):
        """The message carries no payload; data is ignored."""
        return ParseComplete()


##
# Message representing a successful Bind.
#

# Stability: This is an internal class. No stability guarantee is made.
class BindComplete(object):
    # Byte1('2') - Identifier.
    # Int32(4)   - Message length, including self.
    @staticmethod
    def createFromData(data):
        """The message carries no payload; data is ignored."""
        return BindComplete()


##
# Message representing a successful Close.
#

# Stability: This is an internal class. No stability guarantee is made.
class CloseComplete(object):
    # Byte1('3') - Identifier.
    # Int32(4)   - Message length, including self.
    @staticmethod
    def createFromData(data):
        """The message carries no payload; data is ignored."""
        return CloseComplete()


##
# Message indicating that data from an Execute has been received, but more
# data exists in the portal.
#

# Stability: This is an internal class. No stability guarantee is made.
class PortalSuspended(object):
    # Byte1('s') - Identifier.
    # Int32(4)   - Message length, including self.
    @staticmethod
    def createFromData(data):
        """The message carries no payload; data is ignored."""
        return PortalSuspended()


##
# Message indicating that the backend is ready to process a new query.
#

# Stability: This is an internal class. No stability guarantee is made.
class ReadyForQuery(object):
    # Human-readable names for the transaction status indicator.
    _STATUS_NAMES = {
        "I": "Idle",
        "T": "Idle in Transaction",
        "E": "Idle in Failed Transaction",
    }

    def __init__(self, status):
        self._status = status

    ##
    # I = Idle, T = Idle in Transaction, E = idle in failed transaction.
    status = property(lambda self: self._status)

    def __repr__(self):
        # A real format string (an empty one made repr() raise TypeError);
        # an unexpected indicator is shown raw instead of raising KeyError.
        name = self._STATUS_NAMES.get(self.status, repr(self.status))
        return "<ReadyForQuery %s>" % (name,)

    # Byte1('Z') - Identifier.
    # Int32(5)   - Message length, including self.
    # Byte1      - Status indicator.
    def createFromData(data):
        return ReadyForQuery(data)
    createFromData = staticmethod(createFromData)


##
# Represents a notice sent from the server. This is not the same as a
# notification. A notice is just additional information about a query, such
# as a notice that a primary key has automatically been created for a table.
#

# A NoticeResponse instance will have properties containing the data sent
# from the server:
#   severity, code, msg - always present
#   detail, hint, position, where, file, line, routine
#                       - set only when the server sent that field
#   _position, _query   - set from server field codes 'p' and 'q' when sent
# Field codes not listed in responseKeys are stored under their raw
# one-character code.
#
# Stability: Added in pg8000 v1.03. Required properties severity, code, and
# msg are guaranteed for v1.xx. Other properties should be checked with
# hasattr before accessing.
class NoticeResponse(object):
    responseKeys = {
        "S": "severity",  # always present
        "C": "code",      # always present
        "M": "msg",       # always present
        "D": "detail",
        "H": "hint",
        "P": "position",
        "p": "_position",
        "q": "_query",
        "W": "where",
        "F": "file",
        "L": "line",
        "R": "routine",
    }

    def __init__(self, **kwargs):
        for arg, value in kwargs.items():
            setattr(self, arg, value)

    def __repr__(self):
        # A real format string: an empty one made repr() raise TypeError.
        return "<NoticeResponse %s %s %r>" % (self.severity, self.code, self.msg)

    def dataIntoDict(data):
        # Each field is a one-character code followed by its value, separated
        # by NULs; empty pieces (e.g. the final terminator) are skipped.
        retval = {}
        for s in data.split("\x00"):
            if not s:
                continue
            key, value = s[0], s[1:]
            key = NoticeResponse.responseKeys.get(key, key)
            retval[key] = value
        return retval
    dataIntoDict = staticmethod(dataIntoDict)

    # Byte1('N') - Identifier
    # Int32      - Message length
    # Any number of these, followed by a zero byte:
    #   Byte1    - code identifying the field type (see responseKeys)
    #   String   - field value
    def createFromData(data):
        return NoticeResponse(**NoticeResponse.dataIntoDict(data))
    createFromData = staticmethod(createFromData)


##
# A message sent in case of a server-side error. Contains the same properties
# that {@link NoticeResponse NoticeResponse} contains.
#

# Stability: Added in pg8000 v1.03. Required properties severity, code, and
# msg are guaranteed for v1.xx. Other properties should be checked with
# hasattr before accessing.
class ErrorResponse(object):
    def __init__(self, **kwargs):
        for arg, value in kwargs.items():
            setattr(self, arg, value)

    def __repr__(self):
        # A real format string: an empty one made repr() raise TypeError.
        return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg)

    def createException(self):
        """Return (not raise) a ProgrammingError built from this error."""
        return ProgrammingError(self.severity, self.code, self.msg)

    # Same field layout as NoticeResponse, so its parser is reused.
    def createFromData(data):
        return ErrorResponse(**NoticeResponse.dataIntoDict(data))
    createFromData = staticmethod(createFromData)


##
# A message sent if this connection receives a NOTIFY that it was LISTENing for.
#

# Stability: Added in pg8000 v1.03. When limited to accessing properties from
# a notification event dispatch, stability is guaranteed for v1.xx.
class NotificationResponse(object):
    def __init__(self, backend_pid, condition, additional_info):
        self._backend_pid = backend_pid
        self._condition = condition
        self._additional_info = additional_info

    ##
    # An integer representing the process ID of the backend that triggered
    # the NOTIFY.
    #
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    backend_pid = property(lambda self: self._backend_pid)

    ##
    # The name of the notification fired.
    #
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    condition = property(lambda self: self._condition)

    ##
    # Currently unspecified by the PostgreSQL documentation as of v8.3.1.
    #
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    additional_info = property(lambda self: self._additional_info)

    def __repr__(self):
        # A real format string: an empty one made repr() raise TypeError.
        return "<NotificationResponse %s %s %r>" % (
            self.backend_pid, self.condition, self.additional_info)

    # Int32  - Process ID of the notifying backend.
    # String - Condition name.
    # String - Additional information.
    def createFromData(data):
        backend_pid = struct.unpack("!i", data[:4])[0]
        data = data[4:]
        null = data.find("\x00")
        condition = data[:null]
        data = data[null + 1:]
        null = data.find("\x00")
        additional_info = data[:null]
        return NotificationResponse(backend_pid, condition, additional_info)
    createFromData = staticmethod(createFromData)


class ParameterDescription(object):
    def __init__(self, type_oids):
        self.type_oids = type_oids

    # Int16 - Number of parameters.
    # Int32 - Type OID, repeated once per parameter.
    def createFromData(data):
        count = struct.unpack("!h", data[:2])[0]
        type_oids = struct.unpack("!" + "i" * count, data[2:])
        return ParameterDescription(type_oids)
    createFromData = staticmethod(createFromData)


class RowDescription(object):
    def __init__(self, fields):
        # List of dicts, one per result column (see createFromData for keys).
        self.fields = fields

    # Int16 - Number of fields, then per field: a NUL-terminated name followed
    # by Int32 table_oid, Int16 column_attrnum, Int32 type_oid,
    # Int16 type_size, Int32 type_modifier, Int16 format (18 bytes).
    def createFromData(data):
        count = struct.unpack("!h", data[:2])[0]
        data = data[2:]
        fields = []
        for _ in range(count):
            null = data.find("\x00")
            field = {"name": data[:null]}
            data = data[null + 1:]
            (field["table_oid"], field["column_attrnum"], field["type_oid"],
             field["type_size"], field["type_modifier"],
             field["format"]) = struct.unpack("!ihihih", data[:18])
            data = data[18:]
            fields.append(field)
        return RowDescription(fields)
    createFromData = staticmethod(createFromData)


class CommandComplete(object):
    def __init__(self, command, rows=None, oid=None):
        self.command = command
        self.rows = rows
        self.oid = oid

    # String - NUL-terminated command tag, e.g. "INSERT oid rows",
    #          "SELECT rows", or a bare tag such as "CREATE TABLE".
    def createFromData(data):
        tag = data[:-1]  # drop the NUL terminator
        values = tag.split(" ")
        args = {'command': values[0]}
        if args['command'] in ("INSERT", "DELETE", "UPDATE", "MOVE", "FETCH", "COPY", "SELECT"):
            # Older servers send "COPY" with no row count (the count was added
            # in PostgreSQL 8.2); leave rows as None instead of crashing in int().
            if values[-1].isdigit():
                args['rows'] = int(values[-1])
            if args['command'] == "INSERT":
                args['oid'] = int(values[1])
        else:
            args['command'] = tag
        return CommandComplete(**args)
    createFromData = staticmethod(createFromData)


class DataRow(object):
    def __init__(self, fields):
        # Raw column values; None for SQL NULL.
        self.fields = fields

    # Int16 - Number of columns, then per column: Int32 length (-1 = NULL)
    # followed by that many value bytes.
    def createFromData(data):
        count = struct.unpack("!h", data[:2])[0]
        # Walk the buffer with an offset instead of re-slicing the remainder
        # for every column, which copied the rest of the row each time.
        pos = 2
        fields = []
        for _ in range(count):
            val_len = struct.unpack("!i", data[pos:pos + 4])[0]
            pos += 4
            if val_len == -1:
                fields.append(None)
            else:
                fields.append(data[pos:pos + val_len])
                pos += val_len
        return DataRow(fields)
    createFromData = staticmethod(createFromData)


class CopyData(object):
    # Byte1('d') - Identifier.
    # Int32      - Message length, including self.
    # Byte[n]    - Data.
    def __init__(self, data):
        self.data = data

    def createFromData(data):
        return CopyData(data)
    createFromData = staticmethod(createFromData)

    def serialize(self):
        return 'd' + struct.pack('!i', len(self.data) + 4) + self.data


class CopyDone(object):
    # Byte1('c') - Identifier.
    # Int32(4)   - Message length, including self.
    def createFromData(data):
        return CopyDone()
    createFromData = staticmethod(createFromData)

    def serialize(self):
        return 'c\x00\x00\x00\x04'


class CopyOutResponse(object):
    # Byte1('H')
    # Int32    - Length of message contents in bytes, including self.
    # Int8     - Overall format: 0 textual, 1 binary.
    # Int16    - Number of columns (N).
    # Int16[N] - Format codes for each column (0 text, 1 binary).
    def __init__(self, is_binary, column_formats):
        self.is_binary = is_binary
        self.column_formats = column_formats

    def createFromData(data):
        is_binary, num_cols = struct.unpack('!bh', data[:3])
        column_formats = struct.unpack('!' + ('h' * num_cols), data[3:])
        return CopyOutResponse(is_binary, column_formats)
    createFromData = staticmethod(createFromData)


class CopyInResponse(object):
    # Byte1('G')
    # Otherwise the same as CopyOutResponse
    def __init__(self, is_binary, column_formats):
        self.is_binary = is_binary
        self.column_formats = column_formats

    def createFromData(data):
        is_binary, num_cols = struct.unpack('!bh', data[:3])
        column_formats = struct.unpack('!' + ('h' * num_cols), data[3:])
        return CopyInResponse(is_binary, column_formats)
    createFromData = staticmethod(createFromData)


class EmptyQueryResponse(object):
    # Byte1('I')
    # The rest of this class (createFromData) follows this block in the file.
    """Response to an empty query string. (This substitutes for CommandComplete.)"""
def createFromData(data): return EmptyQueryResponse() createFromData = staticmethod(createFromData) class MessageReader(object): def __init__(self, connection): self._conn = connection self._msgs = [] # If true, raise exception from an ErrorResponse after messages are # processed. This can be used to leave the connection in a usable # state after an error response, rather than having unconsumed # messages that won't be understood in another context. self.delay_raising_exception = False self.ignore_unhandled_messages = False def add_message(self, msg_class, handler, *args, **kwargs): self._msgs.append((msg_class, handler, args, kwargs)) def clear_messages(self): self._msgs = [] def return_value(self, value): self._retval = value def handle_messages(self): exc = None while 1: msg = self._conn._read_message() msg_handled = False for (msg_class, handler, args, kwargs) in self._msgs: if isinstance(msg, msg_class): msg_handled = True retval = handler(msg, *args, **kwargs) if retval: # The handler returned a true value, meaning that the # message loop should be aborted. 
if exc != None: raise exc return retval elif hasattr(self, "_retval"): # The handler told us to return -- used for non-true # return values if exc != None: raise exc return self._retval if msg_handled: continue elif isinstance(msg, ErrorResponse): exc = msg.createException() if not self.delay_raising_exception: raise exc elif isinstance(msg, NoticeResponse): self._conn.handleNoticeResponse(msg) elif isinstance(msg, ParameterStatus): self._conn.handleParameterStatus(msg) elif isinstance(msg, NotificationResponse): self._conn.handleNotificationResponse(msg) elif not self.ignore_unhandled_messages: raise InternalError("Unexpected response msg %r" % (msg)) def sync_on_error(fn): def _fn(self, *args, **kwargs): try: self._sock_lock.acquire() return fn(self, *args, **kwargs) except: self._sync() raise finally: self._sock_lock.release() return _fn class Connection(object): def __init__(self, unix_sock=None, host=None, port=5432, socket_timeout=60, ssl=False): self._client_encoding = "ascii" self._integer_datetimes = False self._server_version = None self._sock_buf = "" self._sock_buf_pos = 0 self._send_sock_buf = [] self._block_size = 8192 self._sock_lock = threading.Lock() if unix_sock == None and host != None: self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) elif unix_sock != None: if not hasattr(socket, "AF_UNIX"): raise InterfaceError("attempt to connect to unix socket on unsupported platform") self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) else: raise ProgrammingError("one of host or unix_sock must be provided") if unix_sock == None and host != None: self._sock.connect((host, port)) elif unix_sock != None: self._sock.connect(unix_sock) if ssl: self._sock_lock.acquire() try: self._send(SSLRequest()) self._flush() resp = self._sock.recv(1) if resp == 'S' and sslmodule is not None: self._sock = sslmodule.wrap_socket(self._sock) elif sslmodule is None: raise InterfaceError("SSL required but ssl module not available in this python 
installation") else: raise InterfaceError("server refuses SSL") finally: self._sock_lock.release() else: # settimeout causes ssl failure, on windows. Python bug 1462352. self._sock.settimeout(socket_timeout) self._state = "noauth" self._backend_key_data = None self.NoticeReceived = MulticastDelegate() self.ParameterStatusReceived = MulticastDelegate() self.NotificationReceived = MulticastDelegate() self.ParameterStatusReceived += self._onParameterStatusReceived def verifyState(self, state): if self._state != state: raise InternalError("connection state must be %s, is %s" % (state, self._state)) def _send(self, msg): assert self._sock_lock.locked() ##print "_send(%r)" % msg data = msg.serialize() if not isinstance(data, str): raise TypeError("bytes data expected") self._send_sock_buf.append(data) def _flush(self): assert self._sock_lock.locked() self._sock.sendall("".join(self._send_sock_buf)) del self._send_sock_buf[:] def _read_bytes(self, byte_count): retval = [] bytes_read = 0 while bytes_read < byte_count: if self._sock_buf_pos == len(self._sock_buf): self._sock_buf = self._sock.recv(1024) self._sock_buf_pos = 0 rpos = min(len(self._sock_buf), self._sock_buf_pos + (byte_count - bytes_read)) addt_data = self._sock_buf[self._sock_buf_pos:rpos] bytes_read += (rpos - self._sock_buf_pos) assert bytes_read <= byte_count self._sock_buf_pos = rpos retval.append(addt_data) return "".join(retval) def _read_message(self): assert self._sock_lock.locked() bytes = self._read_bytes(5) message_code = bytes[0] data_len = struct.unpack("!i", bytes[1:])[0] - 4 bytes = self._read_bytes(data_len) assert len(bytes) == data_len msg = message_types[message_code].createFromData(bytes) ##print "_read_message() -> %r" % msg return msg def authenticate(self, user, **kwargs): self.verifyState("noauth") self._sock_lock.acquire() try: self._send(StartupMessage(user, database=kwargs.get("database",None))) self._flush() reader = MessageReader(self) reader.add_message(AuthenticationRequest, 
self._authentication_request(user, **kwargs)) reader.handle_messages() finally: self._sock_lock.release() def _authentication_request(self, user, **kwargs): def _func(msg): assert self._sock_lock.locked() if not msg.ok(self, user, **kwargs): raise InterfaceError("authentication method %s failed" % msg.__class__.__name__) self._state = "auth" reader = MessageReader(self) reader.add_message(ReadyForQuery, self._ready_for_query) reader.add_message(BackendKeyData, self._receive_backend_key_data) reader.handle_messages() return 1 return _func def _ready_for_query(self, msg): self._state = "ready" return True def _receive_backend_key_data(self, msg): self._backend_key_data = msg @sync_on_error def parse(self, statement, qs, param_types): self.verifyState("ready") type_info = [types.pg_type_info(x) for x in param_types] param_types, param_fc = [x[0] for x in type_info], [x[1] for x in type_info] # zip(*type_info) -- fails on empty arr if isinstance(qs, unicode): qs = qs.encode(self._client_encoding) self._send(Parse(statement, qs, param_types)) self._send(DescribePreparedStatement(statement)) self._send(Flush()) self._flush() reader = MessageReader(self) # ParseComplete is good. reader.add_message(ParseComplete, lambda msg: 0) # Well, we don't really care -- we're going to send whatever we # want and let the database deal with it. But thanks anyways! reader.add_message(ParameterDescription, lambda msg: 0) # We're not waiting for a row description. Return something # destinctive to let bind know that there is no output. 
reader.add_message(NoData, lambda msg: (None, param_fc)) # Common row description response reader.add_message(RowDescription, lambda msg: (msg, param_fc)) return reader.handle_messages() @sync_on_error def bind(self, portal, statement, params, parse_data, copy_stream): self.verifyState("ready") row_desc, param_fc = parse_data if row_desc == None: # no data coming out output_fc = () else: # We've got row_desc that allows us to identify what we're going to # get back from this statement. output_fc = [types.py_type_info(f) for f in row_desc.fields] self._send(Bind(portal, statement, param_fc, params, output_fc, client_encoding = self._client_encoding, integer_datetimes = self._integer_datetimes)) # We need to describe the portal after bind, since the return # format codes will be different (hopefully, always what we # requested). self._send(DescribePortal(portal)) self._send(Flush()) self._flush() # Read responses from server... reader = MessageReader(self) # BindComplete is good -- just ignore reader.add_message(BindComplete, lambda msg: 0) # NoData in this case means we're not executing a query. As a # result, we won't be fetching rows, so we'll never execute the # portal we just created... unless we execute it right away, which # we'll do. 
reader.add_message(NoData, self._bind_nodata, portal, reader, copy_stream) # Return the new row desc, since it will have the format types we # asked the server for reader.add_message(RowDescription, lambda msg: (msg, None)) return reader.handle_messages() def _copy_in_response(self, copyin, fileobj, old_reader): if fileobj == None: raise CopyQueryWithoutStreamError() while True: data = fileobj.read(self._block_size) if not data: break self._send(CopyData(data)) self._flush() self._send(CopyDone()) self._send(Sync()) self._flush() def _copy_out_response(self, copyout, fileobj, old_reader): if fileobj == None: raise CopyQueryWithoutStreamError() reader = MessageReader(self) reader.add_message(CopyData, self._copy_data, fileobj) reader.add_message(CopyDone, lambda msg: 1) reader.handle_messages() def _copy_data(self, copydata, fileobj): fileobj.write(copydata.data) def _bind_nodata(self, msg, portal, old_reader, copy_stream): # Bind message returned NoData, causing us to execute the command. 
self._send(Execute(portal, 0)) self._send(Sync()) self._flush() output = {} reader = MessageReader(self) reader.add_message(CopyOutResponse, self._copy_out_response, copy_stream, reader) reader.add_message(CopyInResponse, self._copy_in_response, copy_stream, reader) reader.add_message(CommandComplete, lambda msg, out: out.setdefault('msg', msg) and False, output) reader.add_message(ReadyForQuery, lambda msg: 1) reader.delay_raising_exception = True reader.handle_messages() old_reader.return_value((None, output['msg'])) @sync_on_error def send_simple_query(self, query_string, copy_stream=None): "Submit a simple query (PQsendQuery)" # Only use this for trivial queries, as its use is discouraged because: # CONS: # - Parameter are "injected" (they should be escaped by the app) # - Exesive memory usage (allways returns all rows on completion) # - Inneficient transmission of data in plain text (except for FETCH) # - No Prepared Statement support, each query is parsed every time # - Basic implementation: minimal error recovery and type support # PROS: # - compact: equivalent to Parse, Bind, Describe, Execute, Close, Sync # - doesn't returns ParseComplete, BindComplete, CloseComplete, NoData # - it supports multiple statements in a single query string # - it is available when the Streaming Replication Protocol is actived # NOTE: this is the protocol used by psycopg2 # (they also uses named cursors to overcome some drawbacks) self.verifyState("ready") if isinstance(query_string, unicode): query_string = query_string.encode(self._client_encoding) self._send(SimpleQuery(query_string)) self._flush() # define local storage for message handlers: output = {} rows = [] # create and add handlers for all the possible messages: reader = MessageReader(self) # read row description but continue processing messages... 
(return false) reader.add_message(RowDescription, lambda msg, out: out.setdefault('row_desc', msg) and False, output) reader.add_message(DataRow, lambda msg: self._fetch_datarow(msg, rows, output['row_desc'])) reader.add_message(EmptyQueryResponse, lambda msg: False) reader.add_message(CommandComplete, lambda msg, out: out.setdefault('complete', msg) and False, output) reader.add_message(CopyInResponse, self._copy_in_response, copy_stream, reader) reader.add_message(CopyOutResponse, self._copy_out_response, copy_stream, reader) # messages indicating that we've hit the end of the available data for this command reader.add_message(ReadyForQuery, lambda msg: 1) # process all messages and then raise exceptions (if any) reader.delay_raising_exception = True # start processing the messages from the backend: retval = reader.handle_messages() # return a dict with command complete / row description message values return output.get('row_desc'), output.get('complete'), rows @sync_on_error def fetch_rows(self, portal, row_count, row_desc): self.verifyState("ready") self._send(Execute(portal, row_count)) self._send(Flush()) self._flush() rows = [] reader = MessageReader(self) reader.add_message(DataRow, self._fetch_datarow, rows, row_desc) reader.add_message(PortalSuspended, lambda msg: 1) reader.add_message(CommandComplete, self._fetch_commandcomplete, portal) retval = reader.handle_messages() # retval = 2 when command complete, indicating that we've hit the # end of the available data for this command return (retval == 2), rows def _fetch_datarow(self, msg, rows, row_desc): rows.append( [ types.py_value( msg.fields[i], row_desc.fields[i], client_encoding=self._client_encoding, integer_datetimes=self._integer_datetimes, ) for i in range(len(msg.fields)) ] ) def _fetch_commandcomplete(self, msg, portal): self._send(ClosePortal(portal)) self._send(Sync()) self._flush() reader = MessageReader(self) reader.add_message(ReadyForQuery, self._fetch_commandcomplete_rfq) 
reader.add_message(CloseComplete, lambda msg: False) reader.handle_messages() return 2 # signal end-of-data def _fetch_commandcomplete_rfq(self, msg): self._state = "ready" return True # Send a Sync message, then read and discard all messages until we # receive a ReadyForQuery message. def _sync(self): # it is assumed _sync is called from sync_on_error, which holds # a _sock_lock throughout the call self._send(Sync()) self._flush() reader = MessageReader(self) reader.ignore_unhandled_messages = True reader.add_message(ReadyForQuery, lambda msg: True) reader.handle_messages() def close_statement(self, statement): if self._state == "closed": return self.verifyState("ready") self._sock_lock.acquire() try: self._send(ClosePreparedStatement(statement)) self._send(Sync()) self._flush() reader = MessageReader(self) reader.add_message(CloseComplete, lambda msg: 0) reader.add_message(ReadyForQuery, lambda msg: 1) reader.handle_messages() finally: self._sock_lock.release() def close_portal(self, portal): if self._state == "closed": return self.verifyState("ready") self._sock_lock.acquire() try: self._send(ClosePortal(portal)) self._send(Sync()) self._flush() reader = MessageReader(self) reader.add_message(CloseComplete, lambda msg: 0) reader.add_message(ReadyForQuery, lambda msg: 1) reader.handle_messages() finally: self._sock_lock.release() def close(self): self._sock_lock.acquire() try: self._send(Terminate()) self._flush() self._sock.close() self._state = "closed" finally: self._sock_lock.release() def _onParameterStatusReceived(self, msg): if msg.key == "client_encoding": self._client_encoding = types.encoding_convert(msg.value) ##print "_onParameterStatusReceived client_encoding", self._client_encoding elif msg.key == "integer_datetimes": self._integer_datetimes = (msg.value == "on") elif msg.key == "server_version": self._server_version = msg.value else: ##print "_onParameterStatusReceived ", msg.key, msg.value pass def handleNoticeResponse(self, msg): 
self.NoticeReceived(msg) def handleParameterStatus(self, msg): self.ParameterStatusReceived(msg) def handleNotificationResponse(self, msg): self.NotificationReceived(msg) def fileno(self): # This should be safe to do without a lock return self._sock.fileno() def isready(self): self._sock_lock.acquire() try: rlst, _wlst, _xlst = select.select([self], [], [], 0) if not rlst: return False self._sync() return True finally: self._sock_lock.release() def server_version(self): self.verifyState("ready") if not self._server_version: raise InterfaceError("Server did not provide server_version parameter.") return self._server_version def encoding(self): return self._client_encoding message_types = { "N": NoticeResponse, "R": AuthenticationRequest, "S": ParameterStatus, "K": BackendKeyData, "Z": ReadyForQuery, "T": RowDescription, "E": ErrorResponse, "D": DataRow, "C": CommandComplete, "1": ParseComplete, "2": BindComplete, "3": CloseComplete, "s": PortalSuspended, "n": NoData, "I": EmptyQueryResponse, "t": ParameterDescription, "A": NotificationResponse, "c": CopyDone, "d": CopyData, "G": CopyInResponse, "H": CopyOutResponse, }