Source code for wcraas_common.wcraas_common

# -*- coding: utf-8 -*-

"""The WCraaS Common module aims to single-source reused code across WCraaS the platform."""

import asyncio
import logging
import aio_pika

from abc import ABC, abstractmethod

from aio_pika import connect_robust, ExchangeType
from aio_pika.patterns import RPC
from aio_pika.pool import Pool

from wcraas_common.config import AMQPConfig


__all__ = ("WcraasWorker",)


[docs]class WcraasWorker(ABC): """ Base class for WCraaS Worker classes, aiming to single-source AMQP boilerplate. """ __slots__ = ( "amqp", "logger", "loglevel", "_amqp_pool", "_close", ) def __init__(self, amqp: AMQPConfig, loglevel: int, *args, **kwargs): self.amqp = amqp self.logger = logging.getLogger("wcraas.common") self.logger.setLevel(loglevel) self.loglevel = loglevel self._amqp_pool = self.create_channel_pool() self._close = asyncio.Event() def _discover_callable(self): for attr in dir(self): if attr.startswith("__"): continue obj = getattr(self, attr) if not callable(obj): continue yield obj def _discover(self, attribute): return [ obj for obj in self._discover_callable() if getattr(obj, attribute, False) ]
[docs] def create_channel_pool(self, pool_size: int = 2, channel_size: int = 10) -> Pool: """ Given the max connection pool size and the max channel size create a channel Pool. :param pool_size: Max size for the underlying connection Pool. :type pool_size: integer :param channel_size: Max size for the channel Pool. :type channel_size: integer """ async def get_connection(): return await connect_robust( host=self.amqp.host, port=self.amqp.port, login=self.amqp.user, password=self.amqp.password, ) connection_pool = Pool(get_connection, max_size=pool_size) async def get_channel() -> aio_pika.Channel: async with connection_pool.acquire() as connection: return await connection.channel() return Pool(get_channel, max_size=channel_size)
[docs] async def start_rpc(self) -> None: """ Asynchronous runtime for the worker, responsible of managing and maintaining async context open. """ async with self._amqp_pool.acquire() as rpc_channel: rpc = await RPC.create(rpc_channel) for func in self._discover("is_rpc"): await rpc.register(func.rpc_command, func, auto_delete=True) await self._close.wait()
[docs] async def start_consume(self): async with self._amqp_pool.acquire() as sub_channel: await sub_channel.set_qos(prefetch_count=1) for func in self._discover("is_consume"): queue_name = func.consume_queue await self.register_consumer(sub_channel, func, queue_name) self.logger.info(f"Registered {queue_name} ...") await self._close.wait()
[docs] @staticmethod async def register_consumer(sub_channel, consumer, queue_name): """ Given a channel, a consumer function and a queue name register & start the consumption. :param sub_channel: An aio-pika Channel used for the subscriotion. :type sub_channel: aio_pika.Channel :param consumer: Consumer function that will handle incoming messages in the queue. :type consumer: Callable :param queue_name: Name of the queue to subscribe to. :type queue_name: string """ exchange = await sub_channel.declare_exchange(queue_name, ExchangeType.FANOUT) queue = await sub_channel.declare_queue(exclusive=True) await queue.bind(exchange) await queue.consume(consumer)
[docs] @abstractmethod async def start(self): """ Asynchronous runtime for the worker, responsible of managing and maintaining async context open. """ pass
[docs] def run(self) -> None: """ Helper function implementing the synchronous boilerplate for initilization and teardown. """ loop = asyncio.get_event_loop() try: loop.run_until_complete(self.start()) except KeyboardInterrupt: self.logger.info("[x] Received ^C ! Exiting ...") finally: self._close.set() loop.shutdown_asyncgens()
def __repr__(self): return f"{self.__class__.__name__}(amqp={self.amqp}, loglevel={self.loglevel})"