]> Git Repo - qemu.git/blob - python/qemu/aqmp/protocol.py
python/aqmp: refactor _do_accept() into two distinct steps
[qemu.git] / python / qemu / aqmp / protocol.py
1 """
2 Generic Asynchronous Message-based Protocol Support
3
4 This module provides a generic framework for sending and receiving
5 messages over an asyncio stream. `AsyncProtocol` is an abstract class
6 that implements the core mechanisms of a simple send/receive protocol,
7 and is designed to be extended.
8
9 In this package, it is used as the implementation for the `QMPClient`
10 class.
11 """
12
13 # It's all the docstrings ... ! It's long for a good reason ^_^;
14 # pylint: disable=too-many-lines
15
16 import asyncio
17 from asyncio import StreamReader, StreamWriter
18 from enum import Enum
19 from functools import wraps
20 import logging
21 import socket
22 from ssl import SSLContext
23 from typing import (
24     Any,
25     Awaitable,
26     Callable,
27     Generic,
28     List,
29     Optional,
30     Tuple,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from .error import QMPError
37 from .util import (
38     bottom_half,
39     create_task,
40     exception_summary,
41     flush,
42     is_closing,
43     pretty_traceback,
44     upper_half,
45     wait_closed,
46 )
47
48
49 T = TypeVar('T')
50 _U = TypeVar('_U')
51 _TaskFN = Callable[[], Awaitable[None]]  # aka ``async def func() -> None``
52
53 InternetAddrT = Tuple[str, int]
54 UnixAddrT = str
55 SocketAddrT = Union[UnixAddrT, InternetAddrT]
56
57
58 class Runstate(Enum):
59     """Protocol session runstate."""
60
61     #: Fully quiesced and disconnected.
62     IDLE = 0
63     #: In the process of connecting or establishing a session.
64     CONNECTING = 1
65     #: Fully connected and active session.
66     RUNNING = 2
67     #: In the process of disconnecting.
68     #: Runstate may be returned to `IDLE` by calling `disconnect()`.
69     DISCONNECTING = 3
70
71
72 class ConnectError(QMPError):
73     """
74     Raised when the initial connection process has failed.
75
76     This Exception always wraps a "root cause" exception that can be
77     interrogated for additional information.
78
79     :param error_message: Human-readable string describing the error.
80     :param exc: The root-cause exception.
81     """
82     def __init__(self, error_message: str, exc: Exception):
83         super().__init__(error_message)
84         #: Human-readable error string
85         self.error_message: str = error_message
86         #: Wrapped root cause exception
87         self.exc: Exception = exc
88
89     def __str__(self) -> str:
90         cause = str(self.exc)
91         if not cause:
92             # If there's no error string, use the exception name.
93             cause = exception_summary(self.exc)
94         return f"{self.error_message}: {cause}"
95
96
97 class StateError(QMPError):
98     """
99     An API command (connect, execute, etc) was issued at an inappropriate time.
100
101     This error is raised when a command like
102     :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
103     time.
104
105     :param error_message: Human-readable string describing the state violation.
106     :param state: The actual `Runstate` seen at the time of the violation.
107     :param required: The `Runstate` required to process this command.
108     """
109     def __init__(self, error_message: str,
110                  state: Runstate, required: Runstate):
111         super().__init__(error_message)
112         self.error_message = error_message
113         self.state = state
114         self.required = required
115
116
117 F = TypeVar('F', bound=Callable[..., Any])  # pylint: disable=invalid-name
118
119
120 # Don't Panic.
121 def require(required_state: Runstate) -> Callable[[F], F]:
122     """
123     Decorator: protect a method so it can only be run in a certain `Runstate`.
124
125     :param required_state: The `Runstate` required to invoke this method.
126     :raise StateError: When the required `Runstate` is not met.
127     """
128     def _decorator(func: F) -> F:
129         # _decorator is the decorator that is built by calling the
130         # require() decorator factory; e.g.:
131         #
132         # @require(Runstate.IDLE) def foo(): ...
133         # will replace 'foo' with the result of '_decorator(foo)'.
134
135         @wraps(func)
136         def _wrapper(proto: 'AsyncProtocol[Any]',
137                      *args: Any, **kwargs: Any) -> Any:
138             # _wrapper is the function that gets executed prior to the
139             # decorated method.
140
141             name = type(proto).__name__
142
143             if proto.runstate != required_state:
144                 if proto.runstate == Runstate.CONNECTING:
145                     emsg = f"{name} is currently connecting."
146                 elif proto.runstate == Runstate.DISCONNECTING:
147                     emsg = (f"{name} is disconnecting."
148                             " Call disconnect() to return to IDLE state.")
149                 elif proto.runstate == Runstate.RUNNING:
150                     emsg = f"{name} is already connected and running."
151                 elif proto.runstate == Runstate.IDLE:
152                     emsg = f"{name} is disconnected and idle."
153                 else:
154                     assert False
155                 raise StateError(emsg, proto.runstate, required_state)
156             # No StateError, so call the wrapped method.
157             return func(proto, *args, **kwargs)
158
159         # Return the decorated method;
160         # Transforming Func to Decorated[Func].
161         return cast(F, _wrapper)
162
163     # Return the decorator instance from the decorator factory. Phew!
164     return _decorator
165
166
167 class AsyncProtocol(Generic[T]):
168     """
169     AsyncProtocol implements a generic async message-based protocol.
170
171     This protocol assumes the basic unit of information transfer between
172     client and server is a "message", the details of which are left up
173     to the implementation. It assumes the sending and receiving of these
174     messages is full-duplex and not necessarily correlated; i.e. it
175     supports asynchronous inbound messages.
176
177     It is designed to be extended by a specific protocol which provides
178     the implementations for how to read and send messages. These must be
179     defined in `_do_recv()` and `_do_send()`, respectively.
180
181     Other callbacks have a default implementation, but are intended to be
182     either extended or overridden:
183
184      - `_establish_session`:
185          The base implementation starts the reader/writer tasks.
186          A protocol implementation can override this call, inserting
187          actions to be taken prior to starting the reader/writer tasks
188          before the super() call; actions needing to occur afterwards
189          can be written after the super() call.
190      - `_on_message`:
191          Actions to be performed when a message is received.
192      - `_cb_outbound`:
193          Logging/Filtering hook for all outbound messages.
194      - `_cb_inbound`:
195          Logging/Filtering hook for all inbound messages.
196          This hook runs *before* `_on_message()`.
197
198     :param name:
199         Name used for logging messages, if any. By default, messages
200         will log to 'qemu.aqmp.protocol', but each individual connection
201         can be given its own logger by giving it a name; messages will
202         then log to 'qemu.aqmp.protocol.${name}'.
203     """
204     # pylint: disable=too-many-instance-attributes
205
206     #: Logger object for debugging messages from this connection.
207     logger = logging.getLogger(__name__)
208
209     # Maximum allowable size of read buffer
210     _limit = (64 * 1024)
211
212     # -------------------------
213     # Section: Public interface
214     # -------------------------
215
216     def __init__(self, name: Optional[str] = None) -> None:
217         #: The nickname for this connection, if any.
218         self.name: Optional[str] = name
219         if self.name is not None:
220             self.logger = self.logger.getChild(self.name)
221
222         # stream I/O
223         self._reader: Optional[StreamReader] = None
224         self._writer: Optional[StreamWriter] = None
225
226         # Outbound Message queue
227         self._outgoing: asyncio.Queue[T]
228
229         # Special, long-running tasks:
230         self._reader_task: Optional[asyncio.Future[None]] = None
231         self._writer_task: Optional[asyncio.Future[None]] = None
232
233         # Aggregate of the above two tasks, used for Exception management.
234         self._bh_tasks: Optional[asyncio.Future[Tuple[None, None]]] = None
235
236         #: Disconnect task. The disconnect implementation runs in a task
237         #: so that asynchronous disconnects (initiated by the
238         #: reader/writer) are allowed to wait for the reader/writers to
239         #: exit.
240         self._dc_task: Optional[asyncio.Future[None]] = None
241
242         self._runstate = Runstate.IDLE
243         self._runstate_changed: Optional[asyncio.Event] = None
244
245         # Workaround for bind()
246         self._sock: Optional[socket.socket] = None
247
248         # Server state for start_server() and _incoming()
249         self._server: Optional[asyncio.AbstractServer] = None
250         self._accepted: Optional[asyncio.Event] = None
251
252     def __repr__(self) -> str:
253         cls_name = type(self).__name__
254         tokens = []
255         if self.name is not None:
256             tokens.append(f"name={self.name!r}")
257         tokens.append(f"runstate={self.runstate.name}")
258         return f"<{cls_name} {' '.join(tokens)}>"
259
260     @property  # @upper_half
261     def runstate(self) -> Runstate:
262         """The current `Runstate` of the connection."""
263         return self._runstate
264
265     @upper_half
266     async def runstate_changed(self) -> Runstate:
267         """
268         Wait for the `runstate` to change, then return that runstate.
269         """
270         await self._runstate_event.wait()
271         return self.runstate
272
273     @upper_half
274     @require(Runstate.IDLE)
275     async def start_server_and_accept(
276             self, address: SocketAddrT,
277             ssl: Optional[SSLContext] = None
278     ) -> None:
279         """
280         Accept a connection and begin processing message queues.
281
282         If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
283
284         :param address:
285             Address to listen on; UNIX socket path or TCP address/port.
286         :param ssl: SSL context to use, if any.
287
288         :raise StateError: When the `Runstate` is not `IDLE`.
289         :raise ConnectError:
290             When a connection or session cannot be established.
291
292             This exception will wrap a more concrete one. In most cases,
293             the wrapped exception will be `OSError` or `EOFError`. If a
294             protocol-level failure occurs while establishing a new
295             session, the wrapped error may also be an `QMPError`.
296         """
297         await self._session_guard(
298             self._do_start_server(address, ssl),
299             'Failed to establish connection')
300         await self._session_guard(
301             self._establish_session(),
302             'Failed to establish session')
303         assert self.runstate == Runstate.RUNNING
304
305     @upper_half
306     @require(Runstate.IDLE)
307     async def connect(self, address: SocketAddrT,
308                       ssl: Optional[SSLContext] = None) -> None:
309         """
310         Connect to the server and begin processing message queues.
311
312         If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
313
314         :param address:
315             Address to connect to; UNIX socket path or TCP address/port.
316         :param ssl: SSL context to use, if any.
317
318         :raise StateError: When the `Runstate` is not `IDLE`.
319         :raise ConnectError:
320             When a connection or session cannot be established.
321
322             This exception will wrap a more concrete one. In most cases,
323             the wrapped exception will be `OSError` or `EOFError`. If a
324             protocol-level failure occurs while establishing a new
325             session, the wrapped error may also be an `QMPError`.
326         """
327         await self._session_guard(
328             self._do_connect(address, ssl),
329             'Failed to establish connection')
330         await self._session_guard(
331             self._establish_session(),
332             'Failed to establish session')
333         assert self.runstate == Runstate.RUNNING
334
335     @upper_half
336     async def disconnect(self) -> None:
337         """
338         Disconnect and wait for all tasks to fully stop.
339
340         If there was an exception that caused the reader/writers to
341         terminate prematurely, it will be raised here.
342
343         :raise Exception: When the reader or writer terminate unexpectedly.
344         """
345         self.logger.debug("disconnect() called.")
346         self._schedule_disconnect()
347         await self._wait_disconnect()
348
349     # --------------------------
350     # Section: Session machinery
351     # --------------------------
352
353     async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
354         """
355         Async guard function used to roll back to `IDLE` on any error.
356
357         On any Exception, the state machine will be reset back to
358         `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
359         `BaseException` events will be left alone (This includes
360         asyncio.CancelledError, even prior to Python 3.8).
361
362         :param error_message:
363             Human-readable string describing what connection phase failed.
364
365         :raise BaseException:
366             When `BaseException` occurs in the guarded block.
367         :raise ConnectError:
368             When any other error is encountered in the guarded block.
369         """
370         # Note: After Python 3.6 support is removed, this should be an
371         # @asynccontextmanager instead of accepting a callback.
372         try:
373             await coro
374         except BaseException as err:
375             self.logger.error("%s: %s", emsg, exception_summary(err))
376             self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
377             try:
378                 # Reset the runstate back to IDLE.
379                 await self.disconnect()
380             except:
381                 # We don't expect any Exceptions from the disconnect function
382                 # here, because we failed to connect in the first place.
383                 # The disconnect() function is intended to perform
384                 # only cannot-fail cleanup here, but you never know.
385                 emsg = (
386                     "Unexpected bottom half exception. "
387                     "This is a bug in the QMP library. "
388                     "Please report it to <[email protected]> and "
389                     "CC: John Snow <[email protected]>."
390                 )
391                 self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
392                 raise
393
394             # CancelledError is an Exception with special semantic meaning;
395             # We do NOT want to wrap it up under ConnectError.
396             # NB: CancelledError is not a BaseException before Python 3.8
397             if isinstance(err, asyncio.CancelledError):
398                 raise
399
400             # Any other kind of error can be treated as some kind of connection
401             # failure broadly. Inspect the 'exc' field to explore the root
402             # cause in greater detail.
403             if isinstance(err, Exception):
404                 raise ConnectError(emsg, err) from err
405
406             # Raise BaseExceptions un-wrapped, they're more important.
407             raise
408
409     @property
410     def _runstate_event(self) -> asyncio.Event:
411         # asyncio.Event() objects should not be created prior to entrance into
412         # an event loop, so we can ensure we create it in the correct context.
413         # Create it on-demand *only* at the behest of an 'async def' method.
414         if not self._runstate_changed:
415             self._runstate_changed = asyncio.Event()
416         return self._runstate_changed
417
418     @upper_half
419     @bottom_half
420     def _set_state(self, state: Runstate) -> None:
421         """
422         Change the `Runstate` of the protocol connection.
423
424         Signals the `runstate_changed` event.
425         """
426         if state == self._runstate:
427             return
428
429         self.logger.debug("Transitioning from '%s' to '%s'.",
430                           str(self._runstate), str(state))
431         self._runstate = state
432         self._runstate_event.set()
433         self._runstate_event.clear()
434
435     @bottom_half  # However, it does not run from the R/W tasks.
436     async def _stop_server(self) -> None:
437         """
438         Stop listening for / accepting new incoming connections.
439         """
440         if self._server is None:
441             return
442
443         try:
444             self.logger.debug("Stopping server.")
445             self._server.close()
446             await self._server.wait_closed()
447             self.logger.debug("Server stopped.")
448         finally:
449             self._server = None
450
451     @bottom_half  # However, it does not run from the R/W tasks.
452     async def _incoming(self,
453                         reader: asyncio.StreamReader,
454                         writer: asyncio.StreamWriter) -> None:
455         """
456         Accept an incoming connection and signal the upper_half.
457
458         This method does the minimum necessary to accept a single
459         incoming connection. It signals back to the upper_half ASAP so
460         that any errors during session initialization can occur
461         naturally in the caller's stack.
462
463         :param reader: Incoming `asyncio.StreamReader`
464         :param writer: Incoming `asyncio.StreamWriter`
465         """
466         peer = writer.get_extra_info('peername', 'Unknown peer')
467         self.logger.debug("Incoming connection from %s", peer)
468
469         if self._reader or self._writer:
470             # Sadly, we can have more than one pending connection
471             # because of https://bugs.python.org/issue46715
472             # Close any extra connections we don't actually want.
473             self.logger.warning("Extraneous connection inadvertently accepted")
474             writer.close()
475             return
476
477         # A connection has been accepted; stop listening for new ones.
478         assert self._accepted is not None
479         await self._stop_server()
480         self._reader, self._writer = (reader, writer)
481         self._accepted.set()
482
483     def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None:
484         """
485         Used to create a socket in advance of accept().
486
487         This is a workaround to ensure that we can guarantee timing of
488         precisely when a socket exists to avoid a connection attempt
489         bouncing off of nothing.
490
491         Python 3.7+ adds a feature to separate the server creation and
492         listening phases instead, and should be used instead of this
493         hack.
494         """
495         if isinstance(address, tuple):
496             family = socket.AF_INET
497         else:
498             family = socket.AF_UNIX
499
500         sock = socket.socket(family, socket.SOCK_STREAM)
501         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
502
503         try:
504             sock.bind(address)
505         except:
506             sock.close()
507             raise
508
509         self._sock = sock
510
511     @upper_half
512     async def _do_start_server(self, address: SocketAddrT,
513                                ssl: Optional[SSLContext] = None) -> None:
514         """
515         Acting as the transport server, accept a single connection.
516
517         :param address:
518             Address to listen on; UNIX socket path or TCP address/port.
519         :param ssl: SSL context to use, if any.
520
521         :raise OSError: For stream-related errors.
522         """
523         assert self.runstate == Runstate.IDLE
524         self._set_state(Runstate.CONNECTING)
525
526         self.logger.debug("Awaiting connection on %s ...", address)
527         self._accepted = asyncio.Event()
528
529         if isinstance(address, tuple):
530             coro = asyncio.start_server(
531                 self._incoming,
532                 host=None if self._sock else address[0],
533                 port=None if self._sock else address[1],
534                 ssl=ssl,
535                 backlog=1,
536                 limit=self._limit,
537                 sock=self._sock,
538             )
539         else:
540             coro = asyncio.start_unix_server(
541                 self._incoming,
542                 path=None if self._sock else address,
543                 ssl=ssl,
544                 backlog=1,
545                 limit=self._limit,
546                 sock=self._sock,
547             )
548
549         # Allow runstate watchers to witness 'CONNECTING' state; some
550         # failures in the streaming layer are synchronous and will not
551         # otherwise yield.
552         await asyncio.sleep(0)
553
554         # This will start the server (bind(2), listen(2)). It will also
555         # call accept(2) if we yield, but we don't block on that here.
556         self._server = await coro
557
558         # Just for this one commit, wait for a peer.
559         # This gets split out in the next patch.
560         await self._do_accept()
561
562     @upper_half
563     async def _do_accept(self) -> None:
564         """
565         Wait for and accept an incoming connection.
566
567         Requires that we have not yet accepted an incoming connection
568         from the upper_half, but it's OK if the server is no longer
569         running because the bottom_half has already accepted the
570         connection.
571         """
572         assert self._accepted is not None
573         await self._accepted.wait()
574         assert self._server is None
575         self._accepted = None
576         self._sock = None
577
578         self.logger.debug("Connection accepted.")
579
580     @upper_half
581     async def _do_connect(self, address: SocketAddrT,
582                           ssl: Optional[SSLContext] = None) -> None:
583         """
584         Acting as the transport client, initiate a connection to a server.
585
586         :param address:
587             Address to connect to; UNIX socket path or TCP address/port.
588         :param ssl: SSL context to use, if any.
589
590         :raise OSError: For stream-related errors.
591         """
592         assert self.runstate == Runstate.IDLE
593         self._set_state(Runstate.CONNECTING)
594
595         # Allow runstate watchers to witness 'CONNECTING' state; some
596         # failures in the streaming layer are synchronous and will not
597         # otherwise yield.
598         await asyncio.sleep(0)
599
600         self.logger.debug("Connecting to %s ...", address)
601
602         if isinstance(address, tuple):
603             connect = asyncio.open_connection(
604                 address[0],
605                 address[1],
606                 ssl=ssl,
607                 limit=self._limit,
608             )
609         else:
610             connect = asyncio.open_unix_connection(
611                 path=address,
612                 ssl=ssl,
613                 limit=self._limit,
614             )
615         self._reader, self._writer = await connect
616
617         self.logger.debug("Connected.")
618
619     @upper_half
620     async def _establish_session(self) -> None:
621         """
622         Establish a new session.
623
624         Starts the readers/writer tasks; subclasses may perform their
625         own negotiations here. The Runstate will be RUNNING upon
626         successful conclusion.
627         """
628         assert self.runstate == Runstate.CONNECTING
629
630         self._outgoing = asyncio.Queue()
631
632         reader_coro = self._bh_loop_forever(self._bh_recv_message, 'Reader')
633         writer_coro = self._bh_loop_forever(self._bh_send_message, 'Writer')
634
635         self._reader_task = create_task(reader_coro)
636         self._writer_task = create_task(writer_coro)
637
638         self._bh_tasks = asyncio.gather(
639             self._reader_task,
640             self._writer_task,
641         )
642
643         self._set_state(Runstate.RUNNING)
644         await asyncio.sleep(0)  # Allow runstate_event to process
645
646     @upper_half
647     @bottom_half
648     def _schedule_disconnect(self) -> None:
649         """
650         Initiate a disconnect; idempotent.
651
652         This method is used both in the upper-half as a direct
653         consequence of `disconnect()`, and in the bottom-half in the
654         case of unhandled exceptions in the reader/writer tasks.
655
656         It can be invoked no matter what the `runstate` is.
657         """
658         if not self._dc_task:
659             self._set_state(Runstate.DISCONNECTING)
660             self.logger.debug("Scheduling disconnect.")
661             self._dc_task = create_task(self._bh_disconnect())
662
663     @upper_half
664     async def _wait_disconnect(self) -> None:
665         """
666         Waits for a previously scheduled disconnect to finish.
667
668         This method will gather any bottom half exceptions and re-raise
669         the one that occurred first; presuming it to be the root cause
670         of any subsequent Exceptions. It is intended to be used in the
671         upper half of the call chain.
672
673         :raise Exception:
674             Arbitrary exception re-raised on behalf of the reader/writer.
675         """
676         assert self.runstate == Runstate.DISCONNECTING
677         assert self._dc_task
678
679         aws: List[Awaitable[object]] = [self._dc_task]
680         if self._bh_tasks:
681             aws.insert(0, self._bh_tasks)
682         all_defined_tasks = asyncio.gather(*aws)
683
684         # Ensure disconnect is done; Exception (if any) is not raised here:
685         await asyncio.wait((self._dc_task,))
686
687         try:
688             await all_defined_tasks  # Raise Exceptions from the bottom half.
689         finally:
690             self._cleanup()
691             self._set_state(Runstate.IDLE)
692
693     @upper_half
694     def _cleanup(self) -> None:
695         """
696         Fully reset this object to a clean state and return to `IDLE`.
697         """
698         def _paranoid_task_erase(task: Optional['asyncio.Future[_U]']
699                                  ) -> Optional['asyncio.Future[_U]']:
700             # Help to erase a task, ENSURING it is fully quiesced first.
701             assert (task is None) or task.done()
702             return None if (task and task.done()) else task
703
704         assert self.runstate == Runstate.DISCONNECTING
705         self._dc_task = _paranoid_task_erase(self._dc_task)
706         self._reader_task = _paranoid_task_erase(self._reader_task)
707         self._writer_task = _paranoid_task_erase(self._writer_task)
708         self._bh_tasks = _paranoid_task_erase(self._bh_tasks)
709
710         self._reader = None
711         self._writer = None
712
713         # NB: _runstate_changed cannot be cleared because we still need it to
714         # send the final runstate changed event ...!
715
716     # ----------------------------
717     # Section: Bottom Half methods
718     # ----------------------------
719
720     @bottom_half
721     async def _bh_disconnect(self) -> None:
722         """
723         Disconnect and cancel all outstanding tasks.
724
725         It is designed to be called from its task context,
726         :py:obj:`~AsyncProtocol._dc_task`. By running in its own task,
727         it is free to wait on any pending actions that may still need to
728         occur in either the reader or writer tasks.
729         """
730         assert self.runstate == Runstate.DISCONNECTING
731
732         def _done(task: Optional['asyncio.Future[Any]']) -> bool:
733             return task is not None and task.done()
734
735         # Are we already in an error pathway? If either of the tasks are
736         # already done, or if we have no tasks but a reader/writer; we
737         # must be.
738         #
739         # NB: We can't use _bh_tasks to check for premature task
740         # completion, because it may not yet have had a chance to run
741         # and gather itself.
742         tasks = tuple(filter(None, (self._writer_task, self._reader_task)))
743         error_pathway = _done(self._reader_task) or _done(self._writer_task)
744         if not tasks:
745             error_pathway |= bool(self._reader) or bool(self._writer)
746
747         try:
748             # Try to flush the writer, if possible.
749             # This *may* cause an error and force us over into the error path.
750             if not error_pathway:
751                 await self._bh_flush_writer()
752         except BaseException as err:
753             error_pathway = True
754             emsg = "Failed to flush the writer"
755             self.logger.error("%s: %s", emsg, exception_summary(err))
756             self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
757             raise
758         finally:
759             # Cancel any still-running tasks (Won't raise):
760             if self._writer_task is not None and not self._writer_task.done():
761                 self.logger.debug("Cancelling writer task.")
762                 self._writer_task.cancel()
763             if self._reader_task is not None and not self._reader_task.done():
764                 self.logger.debug("Cancelling reader task.")
765                 self._reader_task.cancel()
766
767             # Close out the tasks entirely (Won't raise):
768             if tasks:
769                 self.logger.debug("Waiting for tasks to complete ...")
770                 await asyncio.wait(tasks)
771
772             # Lastly, close the stream itself. (*May raise*!):
773             await self._bh_close_stream(error_pathway)
774             self.logger.debug("Disconnected.")
775
776     @bottom_half
777     async def _bh_flush_writer(self) -> None:
778         if not self._writer_task:
779             return
780
781         self.logger.debug("Draining the outbound queue ...")
782         await self._outgoing.join()
783         if self._writer is not None:
784             self.logger.debug("Flushing the StreamWriter ...")
785             await flush(self._writer)
786
787     @bottom_half
788     async def _bh_close_stream(self, error_pathway: bool = False) -> None:
789         # NB: Closing the writer also implcitly closes the reader.
790         if not self._writer:
791             return
792
793         if not is_closing(self._writer):
794             self.logger.debug("Closing StreamWriter.")
795             self._writer.close()
796
797         self.logger.debug("Waiting for StreamWriter to close ...")
798         try:
799             await wait_closed(self._writer)
800         except Exception:  # pylint: disable=broad-except
801             # It's hard to tell if the Stream is already closed or
802             # not. Even if one of the tasks has failed, it may have
803             # failed for a higher-layered protocol reason. The
804             # stream could still be open and perfectly fine.
805             # I don't know how to discern its health here.
806
807             if error_pathway:
808                 # We already know that *something* went wrong. Let's
809                 # just trust that the Exception we already have is the
810                 # better one to present to the user, even if we don't
811                 # genuinely *know* the relationship between the two.
812                 self.logger.debug(
813                     "Discarding Exception from wait_closed:\n%s\n",
814                     pretty_traceback(),
815                 )
816             else:
817                 # Oops, this is a brand-new error!
818                 raise
819         finally:
820             self.logger.debug("StreamWriter closed.")
821
822     @bottom_half
823     async def _bh_loop_forever(self, async_fn: _TaskFN, name: str) -> None:
824         """
825         Run one of the bottom-half methods in a loop forever.
826
827         If the bottom half ever raises any exception, schedule a
828         disconnect that will terminate the entire loop.
829
830         :param async_fn: The bottom-half method to run in a loop.
831         :param name: The name of this task, used for logging.
832         """
833         try:
834             while True:
835                 await async_fn()
836         except asyncio.CancelledError:
837             # We have been cancelled by _bh_disconnect, exit gracefully.
838             self.logger.debug("Task.%s: cancelled.", name)
839             return
840         except BaseException as err:
841             self.logger.log(
842                 logging.INFO if isinstance(err, EOFError) else logging.ERROR,
843                 "Task.%s: %s",
844                 name, exception_summary(err)
845             )
846             self.logger.debug("Task.%s: failure:\n%s\n",
847                               name, pretty_traceback())
848             self._schedule_disconnect()
849             raise
850         finally:
851             self.logger.debug("Task.%s: exiting.", name)
852
853     @bottom_half
854     async def _bh_send_message(self) -> None:
855         """
856         Wait for an outgoing message, then send it.
857
858         Designed to be run in `_bh_loop_forever()`.
859         """
860         msg = await self._outgoing.get()
861         try:
862             await self._send(msg)
863         finally:
864             self._outgoing.task_done()
865
866     @bottom_half
867     async def _bh_recv_message(self) -> None:
868         """
869         Wait for an incoming message and call `_on_message` to route it.
870
871         Designed to be run in `_bh_loop_forever()`.
872         """
873         msg = await self._recv()
874         await self._on_message(msg)
875
876     # --------------------
877     # Section: Message I/O
878     # --------------------
879
880     @upper_half
881     @bottom_half
882     def _cb_outbound(self, msg: T) -> T:
883         """
884         Callback: outbound message hook.
885
886         This is intended for subclasses to be able to add arbitrary
887         hooks to filter or manipulate outgoing messages. The base
888         implementation does nothing but log the message without any
889         manipulation of the message.
890
891         :param msg: raw outbound message
892         :return: final outbound message
893         """
894         self.logger.debug("--> %s", str(msg))
895         return msg
896
897     @upper_half
898     @bottom_half
899     def _cb_inbound(self, msg: T) -> T:
900         """
901         Callback: inbound message hook.
902
903         This is intended for subclasses to be able to add arbitrary
904         hooks to filter or manipulate incoming messages. The base
905         implementation does nothing but log the message without any
906         manipulation of the message.
907
908         This method does not "handle" incoming messages; it is a filter.
909         The actual "endpoint" for incoming messages is `_on_message()`.
910
911         :param msg: raw inbound message
912         :return: processed inbound message
913         """
914         self.logger.debug("<-- %s", str(msg))
915         return msg
916
917     @upper_half
918     @bottom_half
919     async def _readline(self) -> bytes:
920         """
921         Wait for a newline from the incoming reader.
922
923         This method is provided as a convenience for upper-layer
924         protocols, as many are line-based.
925
926         This method *may* return a sequence of bytes without a trailing
927         newline if EOF occurs, but *some* bytes were received. In this
928         case, the next call will raise `EOFError`. It is assumed that
929         the layer 5 protocol will decide if there is anything meaningful
930         to be done with a partial message.
931
932         :raise OSError: For stream-related errors.
933         :raise EOFError:
934             If the reader stream is at EOF and there are no bytes to return.
935         :return: bytes, including the newline.
936         """
937         assert self._reader is not None
938         msg_bytes = await self._reader.readline()
939
940         if not msg_bytes:
941             if self._reader.at_eof():
942                 raise EOFError
943
944         return msg_bytes
945
946     @upper_half
947     @bottom_half
948     async def _do_recv(self) -> T:
949         """
950         Abstract: Read from the stream and return a message.
951
952         Very low-level; intended to only be called by `_recv()`.
953         """
954         raise NotImplementedError
955
956     @upper_half
957     @bottom_half
958     async def _recv(self) -> T:
959         """
960         Read an arbitrary protocol message.
961
962         .. warning::
963             This method is intended primarily for `_bh_recv_message()`
964             to use in an asynchronous task loop. Using it outside of
965             this loop will "steal" messages from the normal routing
966             mechanism. It is safe to use prior to `_establish_session()`,
967             but should not be used otherwise.
968
969         This method uses `_do_recv()` to retrieve the raw message, and
970         then transforms it using `_cb_inbound()`.
971
972         :return: A single (filtered, processed) protocol message.
973         """
974         message = await self._do_recv()
975         return self._cb_inbound(message)
976
977     @upper_half
978     @bottom_half
979     def _do_send(self, msg: T) -> None:
980         """
981         Abstract: Write a message to the stream.
982
983         Very low-level; intended to only be called by `_send()`.
984         """
985         raise NotImplementedError
986
987     @upper_half
988     @bottom_half
989     async def _send(self, msg: T) -> None:
990         """
991         Send an arbitrary protocol message.
992
993         This method will transform any outgoing messages according to
994         `_cb_outbound()`.
995
996         .. warning::
997             Like `_recv()`, this method is intended to be called by
998             the writer task loop that processes outgoing
999             messages. Calling it directly may circumvent logic
1000             implemented by the caller meant to correlate outgoing and
1001             incoming messages.
1002
1003         :raise OSError: For problems with the underlying stream.
1004         """
1005         msg = self._cb_outbound(msg)
1006         self._do_send(msg)
1007
1008     @bottom_half
1009     async def _on_message(self, msg: T) -> None:
1010         """
1011         Called to handle the receipt of a new message.
1012
1013         .. caution::
1014             This is executed from within the reader loop, so be advised
1015             that waiting on either the reader or writer task will lead
1016             to deadlock. Additionally, any unhandled exceptions will
1017             directly cause the loop to halt, so logic may be best-kept
1018             to a minimum if at all possible.
1019
1020         :param msg: The incoming message, already logged/filtered.
1021         """
1022         # Nothing to do in the abstract case.
This page took 0.083922 seconds and 4 git commands to generate.