]> Git Repo - qemu.git/blob - python/qemu/qmp/protocol.py
python: rename qemu.aqmp to qemu.qmp
[qemu.git] / python / qemu / qmp / protocol.py
1 """
2 Generic Asynchronous Message-based Protocol Support
3
4 This module provides a generic framework for sending and receiving
5 messages over an asyncio stream. `AsyncProtocol` is an abstract class
6 that implements the core mechanisms of a simple send/receive protocol,
7 and is designed to be extended.
8
9 In this package, it is used as the implementation for the `QMPClient`
10 class.
11 """
12
13 # It's all the docstrings ... ! It's long for a good reason ^_^;
14 # pylint: disable=too-many-lines
15
16 import asyncio
17 from asyncio import StreamReader, StreamWriter
18 from enum import Enum
19 from functools import wraps
20 import logging
21 from ssl import SSLContext
22 from typing import (
23     Any,
24     Awaitable,
25     Callable,
26     Generic,
27     List,
28     Optional,
29     Tuple,
30     TypeVar,
31     Union,
32     cast,
33 )
34
35 from .error import QMPError
36 from .util import (
37     bottom_half,
38     create_task,
39     exception_summary,
40     flush,
41     is_closing,
42     pretty_traceback,
43     upper_half,
44     wait_closed,
45 )
46
47
48 T = TypeVar('T')
49 _U = TypeVar('_U')
50 _TaskFN = Callable[[], Awaitable[None]]  # aka ``async def func() -> None``
51
52 InternetAddrT = Tuple[str, int]
53 UnixAddrT = str
54 SocketAddrT = Union[UnixAddrT, InternetAddrT]
55
56
57 class Runstate(Enum):
58     """Protocol session runstate."""
59
60     #: Fully quiesced and disconnected.
61     IDLE = 0
62     #: In the process of connecting or establishing a session.
63     CONNECTING = 1
64     #: Fully connected and active session.
65     RUNNING = 2
66     #: In the process of disconnecting.
67     #: Runstate may be returned to `IDLE` by calling `disconnect()`.
68     DISCONNECTING = 3
69
70
71 class ConnectError(QMPError):
72     """
73     Raised when the initial connection process has failed.
74
75     This Exception always wraps a "root cause" exception that can be
76     interrogated for additional information.
77
78     :param error_message: Human-readable string describing the error.
79     :param exc: The root-cause exception.
80     """
81     def __init__(self, error_message: str, exc: Exception):
82         super().__init__(error_message)
83         #: Human-readable error string
84         self.error_message: str = error_message
85         #: Wrapped root cause exception
86         self.exc: Exception = exc
87
88     def __str__(self) -> str:
89         cause = str(self.exc)
90         if not cause:
91             # If there's no error string, use the exception name.
92             cause = exception_summary(self.exc)
93         return f"{self.error_message}: {cause}"
94
95
96 class StateError(QMPError):
97     """
98     An API command (connect, execute, etc) was issued at an inappropriate time.
99
100     This error is raised when a command like
101     :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
102     time.
103
104     :param error_message: Human-readable string describing the state violation.
105     :param state: The actual `Runstate` seen at the time of the violation.
106     :param required: The `Runstate` required to process this command.
107     """
108     def __init__(self, error_message: str,
109                  state: Runstate, required: Runstate):
110         super().__init__(error_message)
111         self.error_message = error_message
112         self.state = state
113         self.required = required
114
115
116 F = TypeVar('F', bound=Callable[..., Any])  # pylint: disable=invalid-name
117
118
119 # Don't Panic.
120 def require(required_state: Runstate) -> Callable[[F], F]:
121     """
122     Decorator: protect a method so it can only be run in a certain `Runstate`.
123
124     :param required_state: The `Runstate` required to invoke this method.
125     :raise StateError: When the required `Runstate` is not met.
126     """
127     def _decorator(func: F) -> F:
128         # _decorator is the decorator that is built by calling the
129         # require() decorator factory; e.g.:
130         #
131         # @require(Runstate.IDLE) def foo(): ...
132         # will replace 'foo' with the result of '_decorator(foo)'.
133
134         @wraps(func)
135         def _wrapper(proto: 'AsyncProtocol[Any]',
136                      *args: Any, **kwargs: Any) -> Any:
137             # _wrapper is the function that gets executed prior to the
138             # decorated method.
139
140             name = type(proto).__name__
141
142             if proto.runstate != required_state:
143                 if proto.runstate == Runstate.CONNECTING:
144                     emsg = f"{name} is currently connecting."
145                 elif proto.runstate == Runstate.DISCONNECTING:
146                     emsg = (f"{name} is disconnecting."
147                             " Call disconnect() to return to IDLE state.")
148                 elif proto.runstate == Runstate.RUNNING:
149                     emsg = f"{name} is already connected and running."
150                 elif proto.runstate == Runstate.IDLE:
151                     emsg = f"{name} is disconnected and idle."
152                 else:
153                     assert False
154                 raise StateError(emsg, proto.runstate, required_state)
155             # No StateError, so call the wrapped method.
156             return func(proto, *args, **kwargs)
157
158         # Return the decorated method;
159         # Transforming Func to Decorated[Func].
160         return cast(F, _wrapper)
161
162     # Return the decorator instance from the decorator factory. Phew!
163     return _decorator
164
165
166 class AsyncProtocol(Generic[T]):
167     """
168     AsyncProtocol implements a generic async message-based protocol.
169
170     This protocol assumes the basic unit of information transfer between
171     client and server is a "message", the details of which are left up
172     to the implementation. It assumes the sending and receiving of these
173     messages is full-duplex and not necessarily correlated; i.e. it
174     supports asynchronous inbound messages.
175
176     It is designed to be extended by a specific protocol which provides
177     the implementations for how to read and send messages. These must be
178     defined in `_do_recv()` and `_do_send()`, respectively.
179
180     Other callbacks have a default implementation, but are intended to be
181     either extended or overridden:
182
183      - `_establish_session`:
184          The base implementation starts the reader/writer tasks.
185          A protocol implementation can override this call, inserting
186          actions to be taken prior to starting the reader/writer tasks
187          before the super() call; actions needing to occur afterwards
188          can be written after the super() call.
189      - `_on_message`:
190          Actions to be performed when a message is received.
191      - `_cb_outbound`:
192          Logging/Filtering hook for all outbound messages.
193      - `_cb_inbound`:
194          Logging/Filtering hook for all inbound messages.
195          This hook runs *before* `_on_message()`.
196
197     :param name:
198         Name used for logging messages, if any. By default, messages
199         will log to 'qemu.qmp.protocol', but each individual connection
200         can be given its own logger by giving it a name; messages will
201         then log to 'qemu.qmp.protocol.${name}'.
202     """
203     # pylint: disable=too-many-instance-attributes
204
205     #: Logger object for debugging messages from this connection.
206     logger = logging.getLogger(__name__)
207
208     # Maximum allowable size of read buffer
209     _limit = (64 * 1024)
210
211     # -------------------------
212     # Section: Public interface
213     # -------------------------
214
215     def __init__(self, name: Optional[str] = None) -> None:
216         #: The nickname for this connection, if any.
217         self.name: Optional[str] = name
218         if self.name is not None:
219             self.logger = self.logger.getChild(self.name)
220
221         # stream I/O
222         self._reader: Optional[StreamReader] = None
223         self._writer: Optional[StreamWriter] = None
224
225         # Outbound Message queue
226         self._outgoing: asyncio.Queue[T]
227
228         # Special, long-running tasks:
229         self._reader_task: Optional[asyncio.Future[None]] = None
230         self._writer_task: Optional[asyncio.Future[None]] = None
231
232         # Aggregate of the above two tasks, used for Exception management.
233         self._bh_tasks: Optional[asyncio.Future[Tuple[None, None]]] = None
234
235         #: Disconnect task. The disconnect implementation runs in a task
236         #: so that asynchronous disconnects (initiated by the
237         #: reader/writer) are allowed to wait for the reader/writers to
238         #: exit.
239         self._dc_task: Optional[asyncio.Future[None]] = None
240
241         self._runstate = Runstate.IDLE
242         self._runstate_changed: Optional[asyncio.Event] = None
243
244         # Server state for start_server() and _incoming()
245         self._server: Optional[asyncio.AbstractServer] = None
246         self._accepted: Optional[asyncio.Event] = None
247
248     def __repr__(self) -> str:
249         cls_name = type(self).__name__
250         tokens = []
251         if self.name is not None:
252             tokens.append(f"name={self.name!r}")
253         tokens.append(f"runstate={self.runstate.name}")
254         return f"<{cls_name} {' '.join(tokens)}>"
255
256     @property  # @upper_half
257     def runstate(self) -> Runstate:
258         """The current `Runstate` of the connection."""
259         return self._runstate
260
261     @upper_half
262     async def runstate_changed(self) -> Runstate:
263         """
264         Wait for the `runstate` to change, then return that runstate.
265         """
266         await self._runstate_event.wait()
267         return self.runstate
268
269     @upper_half
270     @require(Runstate.IDLE)
271     async def start_server_and_accept(
272             self, address: SocketAddrT,
273             ssl: Optional[SSLContext] = None
274     ) -> None:
275         """
276         Accept a connection and begin processing message queues.
277
278         If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
279         This method is precisely equivalent to calling `start_server()`
280         followed by `accept()`.
281
282         :param address:
283             Address to listen on; UNIX socket path or TCP address/port.
284         :param ssl: SSL context to use, if any.
285
286         :raise StateError: When the `Runstate` is not `IDLE`.
287         :raise ConnectError:
288             When a connection or session cannot be established.
289
290             This exception will wrap a more concrete one. In most cases,
291             the wrapped exception will be `OSError` or `EOFError`. If a
292             protocol-level failure occurs while establishing a new
293             session, the wrapped error may also be an `QMPError`.
294         """
295         await self.start_server(address, ssl)
296         await self.accept()
297         assert self.runstate == Runstate.RUNNING
298
299     @upper_half
300     @require(Runstate.IDLE)
301     async def start_server(self, address: SocketAddrT,
302                            ssl: Optional[SSLContext] = None) -> None:
303         """
304         Start listening for an incoming connection, but do not wait for a peer.
305
306         This method starts listening for an incoming connection, but
307         does not block waiting for a peer. This call will return
308         immediately after binding and listening on a socket. A later
309         call to `accept()` must be made in order to finalize the
310         incoming connection.
311
312         :param address:
313             Address to listen on; UNIX socket path or TCP address/port.
314         :param ssl: SSL context to use, if any.
315
316         :raise StateError: When the `Runstate` is not `IDLE`.
317         :raise ConnectError:
318             When the server could not start listening on this address.
319
320             This exception will wrap a more concrete one. In most cases,
321             the wrapped exception will be `OSError`.
322         """
323         await self._session_guard(
324             self._do_start_server(address, ssl),
325             'Failed to establish connection')
326         assert self.runstate == Runstate.CONNECTING
327
328     @upper_half
329     @require(Runstate.CONNECTING)
330     async def accept(self) -> None:
331         """
332         Accept an incoming connection and begin processing message queues.
333
334         If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
335
336         :raise StateError: When the `Runstate` is not `CONNECTING`.
337         :raise QMPError: When `start_server()` was not called yet.
338         :raise ConnectError:
339             When a connection or session cannot be established.
340
341             This exception will wrap a more concrete one. In most cases,
342             the wrapped exception will be `OSError` or `EOFError`. If a
343             protocol-level failure occurs while establishing a new
344             session, the wrapped error may also be an `QMPError`.
345         """
346         if self._accepted is None:
347             raise QMPError("Cannot call accept() before start_server().")
348         await self._session_guard(
349             self._do_accept(),
350             'Failed to establish connection')
351         await self._session_guard(
352             self._establish_session(),
353             'Failed to establish session')
354         assert self.runstate == Runstate.RUNNING
355
356     @upper_half
357     @require(Runstate.IDLE)
358     async def connect(self, address: SocketAddrT,
359                       ssl: Optional[SSLContext] = None) -> None:
360         """
361         Connect to the server and begin processing message queues.
362
363         If this call fails, `runstate` is guaranteed to be set back to `IDLE`.
364
365         :param address:
366             Address to connect to; UNIX socket path or TCP address/port.
367         :param ssl: SSL context to use, if any.
368
369         :raise StateError: When the `Runstate` is not `IDLE`.
370         :raise ConnectError:
371             When a connection or session cannot be established.
372
373             This exception will wrap a more concrete one. In most cases,
374             the wrapped exception will be `OSError` or `EOFError`. If a
375             protocol-level failure occurs while establishing a new
376             session, the wrapped error may also be an `QMPError`.
377         """
378         await self._session_guard(
379             self._do_connect(address, ssl),
380             'Failed to establish connection')
381         await self._session_guard(
382             self._establish_session(),
383             'Failed to establish session')
384         assert self.runstate == Runstate.RUNNING
385
386     @upper_half
387     async def disconnect(self) -> None:
388         """
389         Disconnect and wait for all tasks to fully stop.
390
391         If there was an exception that caused the reader/writers to
392         terminate prematurely, it will be raised here.
393
394         :raise Exception: When the reader or writer terminate unexpectedly.
395         """
396         self.logger.debug("disconnect() called.")
397         self._schedule_disconnect()
398         await self._wait_disconnect()
399
400     # --------------------------
401     # Section: Session machinery
402     # --------------------------
403
404     async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None:
405         """
406         Async guard function used to roll back to `IDLE` on any error.
407
408         On any Exception, the state machine will be reset back to
409         `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but
410         `BaseException` events will be left alone (This includes
411         asyncio.CancelledError, even prior to Python 3.8).
412
413         :param error_message:
414             Human-readable string describing what connection phase failed.
415
416         :raise BaseException:
417             When `BaseException` occurs in the guarded block.
418         :raise ConnectError:
419             When any other error is encountered in the guarded block.
420         """
421         # Note: After Python 3.6 support is removed, this should be an
422         # @asynccontextmanager instead of accepting a callback.
423         try:
424             await coro
425         except BaseException as err:
426             self.logger.error("%s: %s", emsg, exception_summary(err))
427             self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
428             try:
429                 # Reset the runstate back to IDLE.
430                 await self.disconnect()
431             except:
432                 # We don't expect any Exceptions from the disconnect function
433                 # here, because we failed to connect in the first place.
434                 # The disconnect() function is intended to perform
435                 # only cannot-fail cleanup here, but you never know.
436                 emsg = (
437                     "Unexpected bottom half exception. "
438                     "This is a bug in the QMP library. "
439                     "Please report it to <[email protected]> and "
440                     "CC: John Snow <[email protected]>."
441                 )
442                 self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
443                 raise
444
445             # CancelledError is an Exception with special semantic meaning;
446             # We do NOT want to wrap it up under ConnectError.
447             # NB: CancelledError is not a BaseException before Python 3.8
448             if isinstance(err, asyncio.CancelledError):
449                 raise
450
451             # Any other kind of error can be treated as some kind of connection
452             # failure broadly. Inspect the 'exc' field to explore the root
453             # cause in greater detail.
454             if isinstance(err, Exception):
455                 raise ConnectError(emsg, err) from err
456
457             # Raise BaseExceptions un-wrapped, they're more important.
458             raise
459
460     @property
461     def _runstate_event(self) -> asyncio.Event:
462         # asyncio.Event() objects should not be created prior to entrance into
463         # an event loop, so we can ensure we create it in the correct context.
464         # Create it on-demand *only* at the behest of an 'async def' method.
465         if not self._runstate_changed:
466             self._runstate_changed = asyncio.Event()
467         return self._runstate_changed
468
469     @upper_half
470     @bottom_half
471     def _set_state(self, state: Runstate) -> None:
472         """
473         Change the `Runstate` of the protocol connection.
474
475         Signals the `runstate_changed` event.
476         """
477         if state == self._runstate:
478             return
479
480         self.logger.debug("Transitioning from '%s' to '%s'.",
481                           str(self._runstate), str(state))
482         self._runstate = state
483         self._runstate_event.set()
484         self._runstate_event.clear()
485
486     @bottom_half
487     async def _stop_server(self) -> None:
488         """
489         Stop listening for / accepting new incoming connections.
490         """
491         if self._server is None:
492             return
493
494         try:
495             self.logger.debug("Stopping server.")
496             self._server.close()
497             await self._server.wait_closed()
498             self.logger.debug("Server stopped.")
499         finally:
500             self._server = None
501
502     @bottom_half  # However, it does not run from the R/W tasks.
503     async def _incoming(self,
504                         reader: asyncio.StreamReader,
505                         writer: asyncio.StreamWriter) -> None:
506         """
507         Accept an incoming connection and signal the upper_half.
508
509         This method does the minimum necessary to accept a single
510         incoming connection. It signals back to the upper_half ASAP so
511         that any errors during session initialization can occur
512         naturally in the caller's stack.
513
514         :param reader: Incoming `asyncio.StreamReader`
515         :param writer: Incoming `asyncio.StreamWriter`
516         """
517         peer = writer.get_extra_info('peername', 'Unknown peer')
518         self.logger.debug("Incoming connection from %s", peer)
519
520         if self._reader or self._writer:
521             # Sadly, we can have more than one pending connection
522             # because of https://bugs.python.org/issue46715
523             # Close any extra connections we don't actually want.
524             self.logger.warning("Extraneous connection inadvertently accepted")
525             writer.close()
526             return
527
528         # A connection has been accepted; stop listening for new ones.
529         assert self._accepted is not None
530         await self._stop_server()
531         self._reader, self._writer = (reader, writer)
532         self._accepted.set()
533
534     @upper_half
535     async def _do_start_server(self, address: SocketAddrT,
536                                ssl: Optional[SSLContext] = None) -> None:
537         """
538         Start listening for an incoming connection, but do not wait for a peer.
539
540         This method starts listening for an incoming connection, but does not
541         block waiting for a peer. This call will return immediately after
542         binding and listening to a socket. A later call to accept() must be
543         made in order to finalize the incoming connection.
544
545         :param address:
546             Address to listen on; UNIX socket path or TCP address/port.
547         :param ssl: SSL context to use, if any.
548
549         :raise OSError: For stream-related errors.
550         """
551         assert self.runstate == Runstate.IDLE
552         self._set_state(Runstate.CONNECTING)
553
554         self.logger.debug("Awaiting connection on %s ...", address)
555         self._accepted = asyncio.Event()
556
557         if isinstance(address, tuple):
558             coro = asyncio.start_server(
559                 self._incoming,
560                 host=address[0],
561                 port=address[1],
562                 ssl=ssl,
563                 backlog=1,
564                 limit=self._limit,
565             )
566         else:
567             coro = asyncio.start_unix_server(
568                 self._incoming,
569                 path=address,
570                 ssl=ssl,
571                 backlog=1,
572                 limit=self._limit,
573             )
574
575         # Allow runstate watchers to witness 'CONNECTING' state; some
576         # failures in the streaming layer are synchronous and will not
577         # otherwise yield.
578         await asyncio.sleep(0)
579
580         # This will start the server (bind(2), listen(2)). It will also
581         # call accept(2) if we yield, but we don't block on that here.
582         self._server = await coro
583         self.logger.debug("Server listening on %s", address)
584
585     @upper_half
586     async def _do_accept(self) -> None:
587         """
588         Wait for and accept an incoming connection.
589
590         Requires that we have not yet accepted an incoming connection
591         from the upper_half, but it's OK if the server is no longer
592         running because the bottom_half has already accepted the
593         connection.
594         """
595         assert self._accepted is not None
596         await self._accepted.wait()
597         assert self._server is None
598         self._accepted = None
599
600         self.logger.debug("Connection accepted.")
601
602     @upper_half
603     async def _do_connect(self, address: SocketAddrT,
604                           ssl: Optional[SSLContext] = None) -> None:
605         """
606         Acting as the transport client, initiate a connection to a server.
607
608         :param address:
609             Address to connect to; UNIX socket path or TCP address/port.
610         :param ssl: SSL context to use, if any.
611
612         :raise OSError: For stream-related errors.
613         """
614         assert self.runstate == Runstate.IDLE
615         self._set_state(Runstate.CONNECTING)
616
617         # Allow runstate watchers to witness 'CONNECTING' state; some
618         # failures in the streaming layer are synchronous and will not
619         # otherwise yield.
620         await asyncio.sleep(0)
621
622         self.logger.debug("Connecting to %s ...", address)
623
624         if isinstance(address, tuple):
625             connect = asyncio.open_connection(
626                 address[0],
627                 address[1],
628                 ssl=ssl,
629                 limit=self._limit,
630             )
631         else:
632             connect = asyncio.open_unix_connection(
633                 path=address,
634                 ssl=ssl,
635                 limit=self._limit,
636             )
637         self._reader, self._writer = await connect
638
639         self.logger.debug("Connected.")
640
641     @upper_half
642     async def _establish_session(self) -> None:
643         """
644         Establish a new session.
645
646         Starts the readers/writer tasks; subclasses may perform their
647         own negotiations here. The Runstate will be RUNNING upon
648         successful conclusion.
649         """
650         assert self.runstate == Runstate.CONNECTING
651
652         self._outgoing = asyncio.Queue()
653
654         reader_coro = self._bh_loop_forever(self._bh_recv_message, 'Reader')
655         writer_coro = self._bh_loop_forever(self._bh_send_message, 'Writer')
656
657         self._reader_task = create_task(reader_coro)
658         self._writer_task = create_task(writer_coro)
659
660         self._bh_tasks = asyncio.gather(
661             self._reader_task,
662             self._writer_task,
663         )
664
665         self._set_state(Runstate.RUNNING)
666         await asyncio.sleep(0)  # Allow runstate_event to process
667
668     @upper_half
669     @bottom_half
670     def _schedule_disconnect(self) -> None:
671         """
672         Initiate a disconnect; idempotent.
673
674         This method is used both in the upper-half as a direct
675         consequence of `disconnect()`, and in the bottom-half in the
676         case of unhandled exceptions in the reader/writer tasks.
677
678         It can be invoked no matter what the `runstate` is.
679         """
680         if not self._dc_task:
681             self._set_state(Runstate.DISCONNECTING)
682             self.logger.debug("Scheduling disconnect.")
683             self._dc_task = create_task(self._bh_disconnect())
684
685     @upper_half
686     async def _wait_disconnect(self) -> None:
687         """
688         Waits for a previously scheduled disconnect to finish.
689
690         This method will gather any bottom half exceptions and re-raise
691         the one that occurred first; presuming it to be the root cause
692         of any subsequent Exceptions. It is intended to be used in the
693         upper half of the call chain.
694
695         :raise Exception:
696             Arbitrary exception re-raised on behalf of the reader/writer.
697         """
698         assert self.runstate == Runstate.DISCONNECTING
699         assert self._dc_task
700
701         aws: List[Awaitable[object]] = [self._dc_task]
702         if self._bh_tasks:
703             aws.insert(0, self._bh_tasks)
704         all_defined_tasks = asyncio.gather(*aws)
705
706         # Ensure disconnect is done; Exception (if any) is not raised here:
707         await asyncio.wait((self._dc_task,))
708
709         try:
710             await all_defined_tasks  # Raise Exceptions from the bottom half.
711         finally:
712             self._cleanup()
713             self._set_state(Runstate.IDLE)
714
715     @upper_half
716     def _cleanup(self) -> None:
717         """
718         Fully reset this object to a clean state and return to `IDLE`.
719         """
720         def _paranoid_task_erase(task: Optional['asyncio.Future[_U]']
721                                  ) -> Optional['asyncio.Future[_U]']:
722             # Help to erase a task, ENSURING it is fully quiesced first.
723             assert (task is None) or task.done()
724             return None if (task and task.done()) else task
725
726         assert self.runstate == Runstate.DISCONNECTING
727         self._dc_task = _paranoid_task_erase(self._dc_task)
728         self._reader_task = _paranoid_task_erase(self._reader_task)
729         self._writer_task = _paranoid_task_erase(self._writer_task)
730         self._bh_tasks = _paranoid_task_erase(self._bh_tasks)
731
732         self._reader = None
733         self._writer = None
734         self._accepted = None
735
736         # NB: _runstate_changed cannot be cleared because we still need it to
737         # send the final runstate changed event ...!
738
739     # ----------------------------
740     # Section: Bottom Half methods
741     # ----------------------------
742
743     @bottom_half
744     async def _bh_disconnect(self) -> None:
745         """
746         Disconnect and cancel all outstanding tasks.
747
748         It is designed to be called from its task context,
749         :py:obj:`~AsyncProtocol._dc_task`. By running in its own task,
750         it is free to wait on any pending actions that may still need to
751         occur in either the reader or writer tasks.
752         """
753         assert self.runstate == Runstate.DISCONNECTING
754
755         def _done(task: Optional['asyncio.Future[Any]']) -> bool:
756             return task is not None and task.done()
757
758         # If the server is running, stop it.
759         await self._stop_server()
760
761         # Are we already in an error pathway? If either of the tasks are
762         # already done, or if we have no tasks but a reader/writer; we
763         # must be.
764         #
765         # NB: We can't use _bh_tasks to check for premature task
766         # completion, because it may not yet have had a chance to run
767         # and gather itself.
768         tasks = tuple(filter(None, (self._writer_task, self._reader_task)))
769         error_pathway = _done(self._reader_task) or _done(self._writer_task)
770         if not tasks:
771             error_pathway |= bool(self._reader) or bool(self._writer)
772
773         try:
774             # Try to flush the writer, if possible.
775             # This *may* cause an error and force us over into the error path.
776             if not error_pathway:
777                 await self._bh_flush_writer()
778         except BaseException as err:
779             error_pathway = True
780             emsg = "Failed to flush the writer"
781             self.logger.error("%s: %s", emsg, exception_summary(err))
782             self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
783             raise
784         finally:
785             # Cancel any still-running tasks (Won't raise):
786             if self._writer_task is not None and not self._writer_task.done():
787                 self.logger.debug("Cancelling writer task.")
788                 self._writer_task.cancel()
789             if self._reader_task is not None and not self._reader_task.done():
790                 self.logger.debug("Cancelling reader task.")
791                 self._reader_task.cancel()
792
793             # Close out the tasks entirely (Won't raise):
794             if tasks:
795                 self.logger.debug("Waiting for tasks to complete ...")
796                 await asyncio.wait(tasks)
797
798             # Lastly, close the stream itself. (*May raise*!):
799             await self._bh_close_stream(error_pathway)
800             self.logger.debug("Disconnected.")
801
802     @bottom_half
803     async def _bh_flush_writer(self) -> None:
804         if not self._writer_task:
805             return
806
807         self.logger.debug("Draining the outbound queue ...")
808         await self._outgoing.join()
809         if self._writer is not None:
810             self.logger.debug("Flushing the StreamWriter ...")
811             await flush(self._writer)
812
813     @bottom_half
814     async def _bh_close_stream(self, error_pathway: bool = False) -> None:
815         # NB: Closing the writer also implcitly closes the reader.
816         if not self._writer:
817             return
818
819         if not is_closing(self._writer):
820             self.logger.debug("Closing StreamWriter.")
821             self._writer.close()
822
823         self.logger.debug("Waiting for StreamWriter to close ...")
824         try:
825             await wait_closed(self._writer)
826         except Exception:  # pylint: disable=broad-except
827             # It's hard to tell if the Stream is already closed or
828             # not. Even if one of the tasks has failed, it may have
829             # failed for a higher-layered protocol reason. The
830             # stream could still be open and perfectly fine.
831             # I don't know how to discern its health here.
832
833             if error_pathway:
834                 # We already know that *something* went wrong. Let's
835                 # just trust that the Exception we already have is the
836                 # better one to present to the user, even if we don't
837                 # genuinely *know* the relationship between the two.
838                 self.logger.debug(
839                     "Discarding Exception from wait_closed:\n%s\n",
840                     pretty_traceback(),
841                 )
842             else:
843                 # Oops, this is a brand-new error!
844                 raise
845         finally:
846             self.logger.debug("StreamWriter closed.")
847
848     @bottom_half
849     async def _bh_loop_forever(self, async_fn: _TaskFN, name: str) -> None:
850         """
851         Run one of the bottom-half methods in a loop forever.
852
853         If the bottom half ever raises any exception, schedule a
854         disconnect that will terminate the entire loop.
855
856         :param async_fn: The bottom-half method to run in a loop.
857         :param name: The name of this task, used for logging.
858         """
859         try:
860             while True:
861                 await async_fn()
862         except asyncio.CancelledError:
863             # We have been cancelled by _bh_disconnect, exit gracefully.
864             self.logger.debug("Task.%s: cancelled.", name)
865             return
866         except BaseException as err:
867             self.logger.log(
868                 logging.INFO if isinstance(err, EOFError) else logging.ERROR,
869                 "Task.%s: %s",
870                 name, exception_summary(err)
871             )
872             self.logger.debug("Task.%s: failure:\n%s\n",
873                               name, pretty_traceback())
874             self._schedule_disconnect()
875             raise
876         finally:
877             self.logger.debug("Task.%s: exiting.", name)
878
879     @bottom_half
880     async def _bh_send_message(self) -> None:
881         """
882         Wait for an outgoing message, then send it.
883
884         Designed to be run in `_bh_loop_forever()`.
885         """
886         msg = await self._outgoing.get()
887         try:
888             await self._send(msg)
889         finally:
890             self._outgoing.task_done()
891
892     @bottom_half
893     async def _bh_recv_message(self) -> None:
894         """
895         Wait for an incoming message and call `_on_message` to route it.
896
897         Designed to be run in `_bh_loop_forever()`.
898         """
899         msg = await self._recv()
900         await self._on_message(msg)
901
902     # --------------------
903     # Section: Message I/O
904     # --------------------
905
906     @upper_half
907     @bottom_half
908     def _cb_outbound(self, msg: T) -> T:
909         """
910         Callback: outbound message hook.
911
912         This is intended for subclasses to be able to add arbitrary
913         hooks to filter or manipulate outgoing messages. The base
914         implementation does nothing but log the message without any
915         manipulation of the message.
916
917         :param msg: raw outbound message
918         :return: final outbound message
919         """
920         self.logger.debug("--> %s", str(msg))
921         return msg
922
923     @upper_half
924     @bottom_half
925     def _cb_inbound(self, msg: T) -> T:
926         """
927         Callback: inbound message hook.
928
929         This is intended for subclasses to be able to add arbitrary
930         hooks to filter or manipulate incoming messages. The base
931         implementation does nothing but log the message without any
932         manipulation of the message.
933
934         This method does not "handle" incoming messages; it is a filter.
935         The actual "endpoint" for incoming messages is `_on_message()`.
936
937         :param msg: raw inbound message
938         :return: processed inbound message
939         """
940         self.logger.debug("<-- %s", str(msg))
941         return msg
942
943     @upper_half
944     @bottom_half
945     async def _readline(self) -> bytes:
946         """
947         Wait for a newline from the incoming reader.
948
949         This method is provided as a convenience for upper-layer
950         protocols, as many are line-based.
951
952         This method *may* return a sequence of bytes without a trailing
953         newline if EOF occurs, but *some* bytes were received. In this
954         case, the next call will raise `EOFError`. It is assumed that
955         the layer 5 protocol will decide if there is anything meaningful
956         to be done with a partial message.
957
958         :raise OSError: For stream-related errors.
959         :raise EOFError:
960             If the reader stream is at EOF and there are no bytes to return.
961         :return: bytes, including the newline.
962         """
963         assert self._reader is not None
964         msg_bytes = await self._reader.readline()
965
966         if not msg_bytes:
967             if self._reader.at_eof():
968                 raise EOFError
969
970         return msg_bytes
971
972     @upper_half
973     @bottom_half
974     async def _do_recv(self) -> T:
975         """
976         Abstract: Read from the stream and return a message.
977
978         Very low-level; intended to only be called by `_recv()`.
979         """
980         raise NotImplementedError
981
982     @upper_half
983     @bottom_half
984     async def _recv(self) -> T:
985         """
986         Read an arbitrary protocol message.
987
988         .. warning::
989             This method is intended primarily for `_bh_recv_message()`
990             to use in an asynchronous task loop. Using it outside of
991             this loop will "steal" messages from the normal routing
992             mechanism. It is safe to use prior to `_establish_session()`,
993             but should not be used otherwise.
994
995         This method uses `_do_recv()` to retrieve the raw message, and
996         then transforms it using `_cb_inbound()`.
997
998         :return: A single (filtered, processed) protocol message.
999         """
1000         message = await self._do_recv()
1001         return self._cb_inbound(message)
1002
1003     @upper_half
1004     @bottom_half
1005     def _do_send(self, msg: T) -> None:
1006         """
1007         Abstract: Write a message to the stream.
1008
1009         Very low-level; intended to only be called by `_send()`.
1010         """
1011         raise NotImplementedError
1012
1013     @upper_half
1014     @bottom_half
1015     async def _send(self, msg: T) -> None:
1016         """
1017         Send an arbitrary protocol message.
1018
1019         This method will transform any outgoing messages according to
1020         `_cb_outbound()`.
1021
1022         .. warning::
1023             Like `_recv()`, this method is intended to be called by
1024             the writer task loop that processes outgoing
1025             messages. Calling it directly may circumvent logic
1026             implemented by the caller meant to correlate outgoing and
1027             incoming messages.
1028
1029         :raise OSError: For problems with the underlying stream.
1030         """
1031         msg = self._cb_outbound(msg)
1032         self._do_send(msg)
1033
1034     @bottom_half
1035     async def _on_message(self, msg: T) -> None:
1036         """
1037         Called to handle the receipt of a new message.
1038
1039         .. caution::
1040             This is executed from within the reader loop, so be advised
1041             that waiting on either the reader or writer task will lead
1042             to deadlock. Additionally, any unhandled exceptions will
1043             directly cause the loop to halt, so logic may be best-kept
1044             to a minimum if at all possible.
1045
1046         :param msg: The incoming message, already logged/filtered.
1047         """
1048         # Nothing to do in the abstract case.
This page took 0.086043 seconds and 4 git commands to generate.