]> Git Repo - qemu.git/blob - python/tests/protocol.py
python/aqmp: remove _new_session and _establish_connection
[qemu.git] / python / tests / protocol.py
1 import asyncio
2 from contextlib import contextmanager
3 import os
4 import socket
5 from tempfile import TemporaryDirectory
6
7 import avocado
8
9 from qemu.aqmp import ConnectError, Runstate
10 from qemu.aqmp.protocol import AsyncProtocol, StateError
11 from qemu.aqmp.util import asyncio_run, create_task
12
13
class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that prohibits the reader from
    incessantly being able to yield None; this event can be poked to
    simulate an incoming message.

    For testing symmetry with _do_recv, a send_msg() method is added to
    "send" a Null message.

    For testing purposes, a simulate_disconnect() method is also
    added which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, _do_connect()/_do_accept() skip the real socket work.
        self.fake_session = False
        # Declared here, but only created in _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # A fresh event per session; created before delegating to the base
        # class, presumably so it exists before the reader can call
        # _do_recv().
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_accept(self, address, ssl=None):
        if self.fake_session:
            # Pretend to accept: enter CONNECTING and yield to the loop once,
            # without touching `address` at all.
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_accept(address, ssl)

    async def _do_connect(self, address, ssl=None):
        if self.fake_session:
            # Pretend to connect: enter CONNECTING and yield to the loop once,
            # without touching `address` at all.
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test pokes trigger_input, then re-arm it.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        # There is no peer; sending is a no-op.
        pass

    async def send_msg(self) -> None:
        """Queue a Null message for the writer to "send"."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()
85
86
class LineProtocol(AsyncProtocol[str]):
    """
    A minimal text protocol used as a real peer in tests.

    Outgoing messages are UTF-8 encoded and terminated with a newline;
    every incoming line is decoded and recorded in ``rx_history``.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Every decoded message received so far, in arrival order.
        self.rx_history = []

    async def _do_recv(self) -> str:
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Queue ``msg`` for the writer task to send."""
        await self._outgoing.put(msg)
104
105
def run_as_task(coro, allow_cancellation=False):
    """
    Run a given coroutine as a task.

    Optionally, wrap it in a try..except block that allows this
    coroutine to be canceled gracefully.

    :param coro: The coroutine to run.
    :param allow_cancellation:
        If True, a CancelledError raised inside the coroutine is swallowed
        and the task finishes normally with a result of None. If False,
        the CancelledError propagates out of the task.
    :return: The scheduled task. Its result is the coroutine's own
        return value (or None when a cancellation was swallowed).
    """
    async def _runner():
        try:
            # Propagate the coroutine's result instead of discarding it.
            return await coro
        except asyncio.CancelledError:
            if allow_cancellation:
                return None
            raise
    return create_task(_runner())
121
122
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    Yields the ``(host, port)`` address of a listening socket whose
    backlog has been filled with connections that are never accepted.
    Every socket created here is closed on exit, including when setup
    fails partway through.
    """
    socks = []

    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Track the socket right away so it is closed even if any of the
        # setup calls below raise.
        socks.append(sock)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('127.0.0.1', 0))
        sock.listen(1)
        address = sock.getsockname()

        # It takes *two* un-accepted connections to start jamming the
        # socket. Presumably this is because (on Linux) the accept queue is
        # only considered full once it holds more than `backlog` (here: 1)
        # pending connections.
        for _ in range(2):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            socks.append(sock)
            sock.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
151
152
class Smoke(avocado.Test):
    """
    Synchronous sanity checks on a NullProtocol that is never connected.
    """

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        """An unnamed protocol's repr shows only its class and runstate."""
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        """A freshly created protocol is IDLE."""
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        """Without an explicit name, .name is None."""
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        """An unnamed protocol logs under the base protocol logger."""
        self.assertEqual(self.proto.logger.name, 'qemu.aqmp.protocol')

    def testName(self):
        """A custom name shows up in .name, the logger name, and repr()."""
        self.proto = NullProtocol('Steve')

        self.assertEqual(self.proto.name, 'Steve')
        self.assertEqual(self.proto.logger.name, 'qemu.aqmp.protocol.Steve')
        self.assertEqual(
            repr(self.proto),
            "<NullProtocol name='Steve' runstate=IDLE>",
        )
199
200
class TestBase(avocado.Test):
    """
    Common fixture for the asynchronous NullProtocol tests.

    Each test gets a fresh, IDLE NullProtocol named after the test class,
    and must leave it IDLE again when it finishes.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        self.runstate_watcher = None

    def tearDown(self):
        # Every test must leave the protocol fully quiesced.
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        # Surface any runstate-sequence assertion failure from the watcher.
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function enables asyncio debug mode on the
        current event loop, then runs ``_asyncSetUp()``, the test body, and
        ``_asyncTearDown()`` in that order. If the test body raises, the
        async teardown is skipped.

        ``functools.wraps`` keeps the original test's name and docstring
        on the wrapper, so reports show the real test instead of
        ``_wrapper``.
        """
        # Local import: the module header does not import functools.
        import functools

        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
270
271
class State(TestBase):
    """
    Runstate tests that do not need any connection at all.
    """

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        expected = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected)
        await self.proto.disconnect()
284
285
class Connect(TestBase):
    """
    Tests primarily related to calling connect().
    """
    async def _bad_connection(self, family: str):
        """
        Make a connection attempt that is expected to fail immediately.

        :param family: 'INET' to target a TCP address, 'UNIX' for a path.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Make a connection attempt against a jammed TCP port."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """
        Check that a rejected attempt raises ConnectError wrapping an
        OSError, while the watcher expects BAD_CONNECTION_STATES.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that connect()/accept() cannot be cancelled outright, as they
        # aren't tasks. However, we can wrap them in a task and cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use a unittest assertion: a bare assert is stripped under -O.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
412
413
class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        """
        Make a listen/accept attempt that is expected to fail immediately.

        :param family: 'INET' to bind a TCP address, 'UNIX' for a path.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            # 192.0.2.1 is reserved for documentation (TEST-NET-1, RFC 5737)
            # and should not be assigned to any local interface, so binding
            # a listener to it fails. Unlike a hostname, it needs no DNS
            # lookup, which keeps the test hermetic and fast offline.
            await self.proto.start_server_and_accept(('192.0.2.1', 1))
        elif family == 'UNIX':
            # '/dev/null' already exists, so no UNIX socket can be bound there.
            await self.proto.start_server_and_accept('/dev/null')

    async def _hanging_connection(self):
        """Listen on a fresh UNIX socket that no client will ever dial."""
        with TemporaryDirectory(suffix='.aqmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.start_server_and_accept(sock)
430
431
class FakeSession(TestBase):
    """
    Lifecycle tests using NullProtocol's fake_session mode, in which
    connect()/accept() never open a real socket.
    """

    # Any address works in fake_session mode; it is never used.
    _FAKE_ADDRESS = '/not/a/real/path'
    _RUNNING_ERROR = "NullProtocol is already connected and running."
    _DISCONNECTING_ERROR = ("NullProtocol is disconnecting."
                            " Call disconnect() to return to IDLE state.")

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    async def _fake_accept(self):
        """Bring the fake session up through the accept() path."""
        await self.proto.start_server_and_accept(self._FAKE_ADDRESS)

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect(self._FAKE_ADDRESS)
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self._fake_accept()
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self._fake_accept()

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # Wake the waiting reader once, then re-arm the event.
            self.proto.trigger_input.set()
            self.proto.trigger_input.clear()
            await asyncio.sleep(0)  # Kick reader.

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:<-- None"])

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self._fake_accept()

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:--> None"])

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that starting a new session right now raises StateError.

        :param current_state: The runstate the error should report.
        :param error_message: The exact expected error message.
        :param accept: Use accept() if True, otherwise connect().
        """
        begin = (self.proto.start_server_and_accept if accept
                 else self.proto.connect)

        with self.assertRaises(StateError) as context:
            await begin(self._FAKE_ADDRESS)

        err = context.exception
        self.assertEqual(err.error_message, error_message)
        self.assertEqual(err.state, current_state)
        self.assertEqual(err.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self._fake_accept()

        await self._prod_session_api(
            Runstate.RUNNING, self._RUNNING_ERROR, accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self._fake_accept()

        await self._prod_session_api(
            Runstate.RUNNING, self._RUNNING_ERROR, accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self._fake_accept()

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING, self._DISCONNECTING_ERROR, accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self._fake_accept()

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING, self._DISCONNECTING_ERROR, accept=False,
        )
561
562
class SimpleSession(TestBase):
    """
    Tests that connect the NullProtocol client to a real LineProtocol
    server listening on a UNIX socket path.
    """

    def setUp(self):
        super().setUp()
        server_name = type(self).__name__ + '-server'
        self.server = LineProtocol(server_name)

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            # Tolerated; presumably the client has already hung up.
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Connect the client to a listening server."""
        with TemporaryDirectory(suffix='.aqmp') as tmp:
            sock_path = os.path.join(tmp, type(self.proto).__name__ + ".sock")
            # Hold a reference to the accept task while the test runs.
            server_task = create_task(
                self.server.start_server_and_accept(sock_path)
            )

            # Yield once so the server can start listening before the
            # client tries to connect.
            await asyncio.sleep(0)
            await self.proto.connect(sock_path)
This page took 0.059973 seconds and 4 git commands to generate.