# asyncron/asyncron/workers.py

from django.db import IntegrityError, models, close_old_connections
from django.utils import timezone
from django.db.utils import OperationalError
from asgiref.sync import sync_to_async
import os, signal
import time
import threading
import logging, traceback
import asyncio
import collections, functools
import random
class AsyncronWorker:
    log = logging.getLogger(__name__) #Used throughout as cls.log / self.log
    INSTANCES = [] #Live AsyncronWorker instances
    MAX_COUNT = 0
    EXIT_SIGNALS = [
        signal.SIGABRT,
        signal.SIGHUP,
        signal.SIGQUIT,
        signal.SIGINT,
        signal.SIGTERM
    ]
    @classmethod
    def override_exit_signals( cls ):
        for sig in cls.EXIT_SIGNALS:
            to_override = signal.getsignal(sig)
            if getattr(to_override, "already_wrapped", False):
                cls.log.warning(
                    f"An attempt was made to wrap the {signal.strsignal(sig)} handler again!"
                    " Make sure you only call asyncron.AsyncronWorker.override_exit_signals once per process."
                )
                continue
            if to_override and callable(to_override):
                #Bind to_override as a default argument; otherwise every wrapper
                #would close over the loop variable and call the last handler only.
                def wrapped( signum, frame, to_override = to_override ):
                    cls.sigcatch( signum, frame )
                    return to_override( signum, frame )
                wrapped.already_wrapped = True
                cls.log.debug(f"Wrapped {to_override} inside sigcatch for {signal.strsignal(sig)}")
                signal.signal(sig, wrapped)
            else:
                cls.log.debug(f"Direct sigcatch for {signal.strsignal(sig)}")
                signal.signal(sig, cls.sigcatch)
    @classmethod
    def sigcatch( cls, signum, frame ):
        cls.stop(f"Signal {signal.strsignal(signum)}")
    @classmethod
    def stop( cls, reason = None ):
        cls.log.info(f"[Asyncron] Stopping Worker(s): {reason}")
        for worker in cls.INSTANCES:
            if worker.is_stopping: continue
            worker.is_stopping = True
            worker.loop.call_soon_threadsafe(worker.loop.stop)
        for worker in cls.INSTANCES:
            if worker.thread.is_alive():
                worker.thread.join()
    @classmethod
    def init( cls ):
        if len(cls.INSTANCES) < cls.MAX_COUNT: cls()
        #TODO: Use this to skip the 1 second delay in the self.start method on higher-traffic servers.
        #from django.db.backends.signals import connection_created
        #from django.db.backends.postgresql.base import DatabaseWrapper
        #from django.dispatch import receiver
        #@receiver(connection_created, sender=DatabaseWrapper)
        #def initial_connection_to_db(sender, **kwargs):
        #    if len(cls.INSTANCES) < cls.MAX_COUNT: cls()
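    # A minimal wiring sketch (the AppConfig below is hypothetical; MAX_COUNT
    # defaults to 0, so no worker starts until the host application raises it):
    #
    # from django.apps import AppConfig
    # class MyAppConfig( AppConfig ):
    #     name = "myapp"
    #     def ready( self ):
    #         from asyncron.workers import AsyncronWorker
    #         AsyncronWorker.MAX_COUNT = 1
    #         AsyncronWorker.override_exit_signals()
    #         AsyncronWorker.init()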
    ##
    ## Start of instance methods
    ##
    def __init__( self, daemon = True ):
        self.INSTANCES.append(self)
        self.is_stopping = False
        self.clearing_dead_workers = False
        self.watching_models = collections.defaultdict( set ) #Model -> set of task key names watching it
        self.work_loop_over = asyncio.Event()
        if daemon:
            self.thread = threading.Thread( target = self.start )
            self.thread.start()
    def start( self, is_robust = False ):
        assert not hasattr(self, "loop"), "This worker is already running!"
        from .models import Worker, Task, Trace
        self.model = Worker( pid = os.getpid(), thread_id = threading.get_ident(), is_robust = is_robust )
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop( self.loop )
        #Fight over who's gonna be the master, prove your health in the process!
        self.loop.create_task( self.master_loop() )
        main_task = self.loop.create_task( self.work_loop() )
        time.sleep(0.3) #To avoid the django initialization warning!
        self.model.save()
        self.model.refresh_from_db()
        #Fill in the ID fields of the tasks we didn't dare to check against the db until now
        for func in Task.registered_tasks.values():
            task = func.task
            if not task.pk:
                try: task.pk = Task.objects.get( name = task.name ).pk
                except Task.DoesNotExist: pass #It's a new one, that's fine.
        self.attach_django_signals()
        try:
            self.loop.run_forever() #This is the lifetime of this worker
        except KeyboardInterrupt: self.log.info(f"[Asyncron][W{self.model.id}] Worker received KeyboardInterrupt, exiting...")
        else: self.log.info(f"[Asyncron][W{self.model.id}] Worker exiting...")
        self.loop.run_until_complete( self.graceful_shutdown() )
        count = Trace.objects.filter( status__in = "SWRP", worker_lock = self.model ).update(
            status_reason = "Worker died during execution",
            status = "A", worker_lock = None
        )
        #DONT print anything in here!
        #if count: print(f"Had to cancel {count} task(s).") #cls.log.warning
        self.model.delete()
        #self.loop.call_soon(self.started.set)
    def attach_django_signals( self ):
        django_signals = {
            name : attr
            for name in ["post_save", "post_delete"] #To expand: dir(models.signals)
            if not name.startswith("_") #Skip private attributes
            and ( attr := getattr(models.signals, name) ) #Just an assignment
            and isinstance( attr, models.signals.ModelSignal ) #Only signals related to models!
        }
        #Named django_signal (not signal) to avoid shadowing the signal module
        for name, django_signal in django_signals.items():
            django_signal.connect( functools.partial( self.model_changed, name ) )
        from .models import Task
        for name, task in Task.registered_tasks.items():
            if not hasattr(task, 'watching_models'): continue
            for model in getattr(task, 'watching_models'):
                self.watching_models[ model ].add( name )
    def model_changed( self, signal_name, sender, signal, instance, **kwargs ):
        from .models import Task
        for name in self.watching_models[instance.__class__]:
            asyncio.run_coroutine_threadsafe(
                Task.registered_tasks[name].task.ensure_quick_execution( reason = f"Change ({signal_name}) on {instance}" ),
                self.loop
            )
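    # A sketch of the registration side this hookup consumes (the decorator and
    # the Article model are hypothetical; the real API lives in .models.Task):
    #
    # @register_task( interval = 60 ) #assumed registration decorator
    # async def refresh_search_index():
    #     ...
    # #Saves/deletes of Article now trigger ensure_quick_execution for this task:
    # refresh_search_index.watching_models = [ Article ]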
    async def graceful_shutdown( self ):
        try:
            for attempt in range( 10 ):
                if not await self.model.trace_set.aexists():
                    break
                await asyncio.sleep( attempt / 10 )
            else: self.log.info(f"[Asyncron][W{self.model.id}] Graceful shutdown not graceful enough!")
        except Exception: await asyncio.sleep( 1 )
    async def master_loop( self ):
        from .models import Worker, Task, Trace
        #Delete dead masters every now and then!
        last_overtake_attempt = 0
        current_master = False
        loop_wait = 5 #Fallback so the finally block never hits an unbound name
        while True:
            try:
                await Worker.objects.filter( is_master = False ).aupdate( is_master = models.Q(id = self.model.id) )
            except IntegrityError: #I'm not master!
                loop_wait = 5 + random.random() * 15
                if current_master: self.log.info(f"[Asyncron][W{self.model.id}] No longer master.")
                current_master = False
                if last_overtake_attempt + 60 < time.time():
                    last_overtake_attempt = time.time()
                    took_master = False
                    if self.model.is_robust:
                        #Robust workers may dethrone non-robust masters immediately.
                        took_master = await Worker.objects.filter( is_master = True, is_robust = False ).aupdate( is_master = False )
                        loop_wait = 0
                    else:
                        #Dethrone masters that haven't renewed their crown for 5 minutes.
                        await Worker.objects.filter(
                            is_master = True,
                            last_crowning_attempt__lte = timezone.now() - timezone.timedelta( minutes = 5 )
                        ).aupdate( is_master = False )
            else: #I am master!
                loop_wait = 2 + random.random() * 3
                if not current_master: self.log.info(f"[Asyncron][W{self.model.id}] Running as master.")
                current_master = True
                if not self.clearing_dead_workers:
                    self.loop.create_task( self.clear_dead_workers() )
                await self.sync_tasks()
                await self.clear_orphaned_traces()
            finally:
                await Worker.objects.filter( id = self.model.id ).aupdate( last_crowning_attempt = timezone.now() )
                await asyncio.sleep( loop_wait )
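    # Roughly what the election aupdate above compiles to (a sketch; the exact SQL
    # depends on the backend, and it assumes a unique constraint that permits only
    # one is_master = TRUE row):
    #
    #   UPDATE asyncron_worker
    #      SET is_master = (id = <this worker's id>)
    #    WHERE NOT is_master;
    #
    # Every contender flips its own row to TRUE in one statement; all but the first
    # to commit violate the constraint and land in the IntegrityError branch.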
    async def clear_orphaned_traces( self ):
        from .models import Worker, Task, Trace
        await Trace.objects.filter( worker_lock = None, status__in = "RPW" ).adelete()
    async def clear_dead_workers( self ):
        self.clearing_dead_workers = True
        from .models import Worker, Task, Trace
        #Put workers that haven't renewed their crowning recently into a grace period.
        await Worker.objects.filter(
            last_crowning_attempt__lte = timezone.now() - timezone.timedelta( seconds = 30 ),
            in_grace = False
        ).aupdate( in_grace = True )
        #Workers that never attempted a crowning: check whether their process is alive at all.
        async for worker in Worker.objects.filter( in_grace = False, last_crowning_attempt = None ):
            if not await sync_to_async( worker.is_proc_alive )():
                await worker.adelete()
        await asyncio.sleep( 30 )
        await Worker.objects.filter( in_grace = True ).adelete()
        self.clearing_dead_workers = False
    async def sync_tasks( self ):
        from .models import Task
        for name, func in Task.registered_tasks.items():
            init_task = func.task
            try:
                func.task = await Task.objects.aget( name = name )
            except Task.DoesNotExist:
                await func.task.asave()
                await func.task.arefresh_from_db()
            else: #For now, to commit changes to db
                init_task.id = func.task.id
                await init_task.asave()
                await func.task.arefresh_from_db()
    async def work_loop( self ):
        self.check_interval = 0
        while True:
            await asyncio.sleep( self.check_interval )
            self.check_interval = 10
            try:
                await self.check_scheduled()
                await sync_to_async( close_old_connections )()
            except OperationalError as e:
                self.log.warning(f"[Asyncron] DB Connection Error: {e}")
                self.log.warning( traceback.format_exc() )
                break
            except Exception as e:
                self.log.warning(f"[Asyncron] check_scheduled failed: {e}")
                self.log.warning( traceback.format_exc() )
                self.check_interval = 20
        self.work_loop_over.set()
    async def check_scheduled( self ):
        from .models import Task, Trace
        #Schedule traces for interval tasks that don't have one set yet.
        Ts = Task.objects.exclude( interval = None ).exclude(
            trace__status = "S"
        ).exclude( worker_type = "D" if self.model.is_robust else "R" )
        async for task in Ts:
            trace = task.new_trace()
            await trace.reschedule( reason = "Auto Scheduled" )
            locked = await Task.objects.filter( id = task.id ).filter(
                models.Q(worker_lock = None) |
                models.Q(worker_lock = self.model) #In case the lock was already acquired for some reason.
            ).aupdate( worker_lock = self.model )
            if locked:
                await trace.asave()
                await Task.objects.filter( id = task.id, worker_lock = self.model ).aupdate( worker_lock = None )
        #Pick up traces that are due soon, slightly early, so they can be started exactly on time.
        early_seconds = 5 + self.check_interval * ( 1 + random.random() )
        async for trace in Trace.objects.filter( status = "S", worker_lock = None, scheduled_datetime__lte = timezone.now() + timezone.timedelta( seconds = early_seconds ) ):
            await trace.eval_related()
            #print(f"Checking {trace} to do now: {trace.scheduled_datetime - timezone.now()}")
            count = await Trace.objects.filter( id = trace.id, status = "S" ).aupdate( status = "W", worker_lock = self.model )
            if not count: continue #Lost the race to another worker.
            self.loop.create_task( self.start_trace_on_time( trace ) )
    async def start_trace_on_time( self, trace ):
        from .models import Trace
        #Sleep until the scheduled time; we may have picked the trace up early.
        await asyncio.sleep( max( 0, ( trace.scheduled_datetime - timezone.now() ).total_seconds() ) )
        await trace.arefresh_from_db()
        await trace.start()
        trace.worker_lock = None
        await trace.asave( update_fields = ['worker_lock'] )
        #Traces for the same task that we are done with (Completed, Aborted, Errored)
        done_traces = Trace.objects.filter(
            task_id = trace.task_id, status__in = "CAE", protected = False, worker_lock = None
        ).order_by('-register_datetime')
        #Should be deleted beyond the retention threshold
        max_count = trace.task.max_completed_traces if trace.status == "C" else trace.task.max_failed_traces
        await done_traces.exclude(
            id__in = done_traces[:max_count].values_list( 'id', flat = True )
        ).adelete()
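# A hedged teardown sketch, e.g. from a management command, instead of relying on
# the wrapped exit signals (stop() joins the worker threads before returning):
#
# try:
#     AsyncronWorker.MAX_COUNT = 1
#     AsyncronWorker.init()
#     ...  #do the actual work of the process here
# finally:
#     AsyncronWorker.stop( reason = "process finished" )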