# Source code for wcraas_common.wcraas_common
# -*- coding: utf-8 -*-
"""The WCraaS Common module aims to single-source code reused across the WCraaS platform."""
import asyncio
import logging
import aio_pika
from abc import ABC, abstractmethod
from aio_pika import connect_robust, ExchangeType
from aio_pika.patterns import RPC
from aio_pika.pool import Pool
from wcraas_common.config import AMQPConfig
# Public API of this module: only the worker base class is exported on star-import.
__all__ = ("WcraasWorker",)
class WcraasWorker(ABC):
    """
    Base class for WCraaS Worker classes, aiming to single-source AMQP boilerplate.

    Subclasses must implement :meth:`start`, which :meth:`run` drives to
    completion. Handlers are discovered on the instance by marker attributes:

    * callables with a truthy ``is_rpc`` attribute are registered by
      :meth:`start_rpc` as RPC handlers under their ``rpc_command`` name;
    * callables with a truthy ``is_consume`` attribute are subscribed by
      :meth:`start_consume` to the exchange named by their ``consume_queue``.
    """

    __slots__ = (
        "amqp",
        "logger",
        "loglevel",
        "_amqp_pool",
        "_close",
    )

    def __init__(self, amqp: AMQPConfig, loglevel: int, *args, **kwargs):
        """
        :param amqp: AMQP settings; ``host``, ``port``, ``user`` and ``password`` are read.
        :type amqp: AMQPConfig
        :param loglevel: Level applied to the ``wcraas.common`` logger.
        :type loglevel: integer

        Extra positional/keyword arguments are accepted and ignored.
        """
        self.amqp = amqp
        # NOTE: ``getLogger`` returns the same logger for every worker in the
        # process, so the level set by the most recently built worker wins.
        self.logger = logging.getLogger("wcraas.common")
        self.logger.setLevel(loglevel)
        self.loglevel = loglevel
        self._amqp_pool = self.create_channel_pool()
        # Set on shutdown; the long-running start_* coroutines wait on it.
        self._close = asyncio.Event()

    def _discover_callable(self):
        """Yield every callable attribute of the instance, skipping dunder names."""
        for attr in dir(self):
            if attr.startswith("__"):
                continue
            # The default keeps a declared-but-unassigned ``__slots__`` entry
            # (e.g. one added by a subclass) from aborting discovery with
            # AttributeError; ``None`` is not callable and is skipped below.
            obj = getattr(self, attr, None)
            if callable(obj):
                yield obj

    def _discover(self, attribute):
        """Return the callables of this instance whose ``attribute`` is truthy."""
        return [
            obj
            for obj in self._discover_callable()
            if getattr(obj, attribute, False)
        ]

    def create_channel_pool(self, pool_size: int = 2, channel_size: int = 10) -> Pool:
        """
        Given the max connection pool size and the max channel size create a channel Pool.

        :param pool_size: Max size for the underlying connection Pool.
        :type pool_size: integer
        :param channel_size: Max size for the channel Pool.
        :type channel_size: integer
        :returns: A Pool handing out channels opened on pooled robust connections.
        """

        async def get_connection():
            return await connect_robust(
                host=self.amqp.host,
                port=self.amqp.port,
                login=self.amqp.user,
                password=self.amqp.password,
            )

        connection_pool = Pool(get_connection, max_size=pool_size)

        async def get_channel() -> aio_pika.Channel:
            # The connection goes back to its pool once the channel is opened;
            # the channel itself stays bound to that connection.
            async with connection_pool.acquire() as connection:
                return await connection.channel()

        return Pool(get_channel, max_size=channel_size)

    async def start_rpc(self) -> None:
        """
        Register every ``is_rpc`` callable as an RPC handler and keep serving.

        Each handler is registered under its ``rpc_command`` name with
        ``auto_delete=True``. The pooled channel is held until the worker's
        close event is set.
        """
        async with self._amqp_pool.acquire() as rpc_channel:
            rpc = await RPC.create(rpc_channel)
            for func in self._discover("is_rpc"):
                await rpc.register(func.rpc_command, func, auto_delete=True)
            await self._close.wait()

    async def start_consume(self):
        """
        Subscribe every ``is_consume`` callable to its ``consume_queue`` and keep consuming.

        The shared channel is limited to one unacknowledged message at a time
        (``prefetch_count=1``) and is held until the worker's close event is set.
        """
        async with self._amqp_pool.acquire() as sub_channel:
            await sub_channel.set_qos(prefetch_count=1)
            for func in self._discover("is_consume"):
                queue_name = func.consume_queue
                await self.register_consumer(sub_channel, func, queue_name)
                self.logger.info("Registered %s ...", queue_name)
            await self._close.wait()

    @staticmethod
    async def register_consumer(sub_channel, consumer, queue_name):
        """
        Given a channel, a consumer function and a queue name register & start the consumption.

        A fanout exchange named ``queue_name`` is declared, and an exclusive
        queue (declared without an explicit name) is bound to it and consumed.

        :param sub_channel: An aio-pika Channel used for the subscription.
        :type sub_channel: aio_pika.Channel
        :param consumer: Consumer function that will handle incoming messages in the queue.
        :type consumer: Callable
        :param queue_name: Name of the exchange/queue to subscribe to.
        :type queue_name: string
        """
        exchange = await sub_channel.declare_exchange(queue_name, ExchangeType.FANOUT)
        queue = await sub_channel.declare_queue(exclusive=True)
        await queue.bind(exchange)
        await queue.consume(consumer)

    @abstractmethod
    async def start(self):
        """
        Asynchronous runtime for the worker, responsible for managing and keeping the async context open.
        """

    def run(self) -> None:
        """
        Helper function implementing the synchronous boilerplate for initialization and teardown.

        Runs :meth:`start` on the current event loop until it returns or ^C is
        pressed. It then sets the close event and finalizes async generators.
        """
        loop = asyncio.get_event_loop()
        try:
            loop.run_until_complete(self.start())
        except KeyboardInterrupt:
            self.logger.info("[x] Received ^C ! Exiting ...")
        finally:
            self._close.set()
            # ``shutdown_asyncgens`` is a coroutine and only executes when run
            # on the loop (a bare call just creates a never-awaited coroutine).
            # Running the loop here also gives coroutines waiting on
            # ``self._close`` a chance to resume.
            loop.run_until_complete(loop.shutdown_asyncgens())

    def __repr__(self):
        return f"{self.__class__.__name__}(amqp={self.amqp}, loglevel={self.loglevel})"