# Asyncron worker module.
from django.db import IntegrityError, models, close_old_connections
|
|
from django.utils import timezone
|
|
from django.db.utils import OperationalError
|
|
from asgiref.sync import sync_to_async
|
|
|
|
import os, signal
|
|
import time
|
|
import threading
|
|
import logging, traceback
|
|
import asyncio
|
|
import collections, functools
|
|
import random
|
|
|
|
class AsyncronWorker:
    """Background task worker that runs registered tasks on a private asyncio
    event loop inside a daemon thread, coordinating with sibling workers
    through the database (Worker/Task/Trace models)."""

    # Logger shared by class- and instance-level methods.  The rest of the
    # class references ``cls.log`` / ``self.log``, but no ``log`` attribute
    # was ever defined, which raised AttributeError on first use.
    log = logging.getLogger(__name__)

    INSTANCES = []  # Every constructed AsyncronWorker instance (see __init__).
    MAX_COUNT = 0   # init() only spawns workers while len(INSTANCES) < MAX_COUNT.

    # Signals whose handlers get wrapped so workers shut down cleanly.
    # NOTE(review): name looks like a typo for EXIT_SIGNALS, but it is part of
    # the class's public interface, so it is kept as-is.
    EXIST_SIGNALS = [
        signal.SIGABRT,
        signal.SIGHUP,
        signal.SIGQUIT,
        signal.SIGINT,
        signal.SIGTERM,
    ]
@classmethod
|
|
def override_exit_signals( cls ):
|
|
for sig in cls.EXIST_SIGNALS:
|
|
to_override = signal.getsignal(sig)
|
|
if getattr(to_override, "already_wrapped", False):
|
|
cls.log.warning(
|
|
f"An attempt was made to wrap around the {signal.strsignal(sig)} signal again!"
|
|
" Make sure you only call asyncron.AsyncronWorker.override_exit_signals once per process."
|
|
)
|
|
continue
|
|
|
|
if to_override and callable(to_override):
|
|
def wrapped( signum, frame ):
|
|
cls.sigcatch( signum, frame )
|
|
return to_override( signum, frame )
|
|
wrapped.already_wrapped = True
|
|
cls.log.debug(f"Wrapped {to_override} inside sigcatch for {signal.strsignal(sig)}")
|
|
signal.signal(sig, wrapped)
|
|
else:
|
|
cls.log.debug(f"Direct sigcatch for {signal.strsignal(sig)}")
|
|
signal.signal(sig, cls.sigcatch)
|
|
|
|
@classmethod
|
|
def sigcatch( cls, signum, frame ):
|
|
cls.stop(f"Signal {signal.strsignal(signum)}")
|
|
|
|
@classmethod
|
|
def stop( cls, reason = None ):
|
|
cls.log.info(f"[Asyncron] Stopping Worker(s): {reason}")
|
|
for worker in cls.INSTANCES:
|
|
if worker.is_stopping: continue
|
|
worker.is_stopping = True
|
|
worker.loop.call_soon_threadsafe(worker.loop.stop)
|
|
|
|
for worker in cls.INSTANCES:
|
|
if worker.thread.is_alive():
|
|
worker.thread.join()
|
|
|
|
@classmethod
|
|
def init( cls ):
|
|
if len(cls.INSTANCES) < cls.MAX_COUNT: cls()
|
|
|
|
#TODO: Use this to skip the 1 second delay in the self.start method on higher traffic servers.
|
|
#from django.db.backends.signals import connection_created
|
|
#from django.db.backends.postgresql.base import DatabaseWrapper
|
|
#from django.dispatch import receiver
|
|
#@receiver(connection_created, sender=DatabaseWrapper)
|
|
#def initial_connection_to_db(sender, **kwargs):
|
|
# if len(cls.INSTANCES) < cls.MAX_COUNT: cls()
|
|
|
|
|
|
|
|
##
|
|
## Start of instance methods
|
|
##
|
|
def __init__( self, daemon = True ):
|
|
self.INSTANCES.append(self)
|
|
self.is_stopping = False
|
|
self.clearing_dead_workers = False
|
|
self.watching_models = collections.defaultdict( set ) # Model -> Set of key name of the tasks
|
|
self.work_loop_over = asyncio.Event()
|
|
|
|
if daemon:
|
|
self.thread = threading.Thread( target = self.start )
|
|
self.thread.start()
|
|
|
|
|
|
    def start( self, is_robust = False ):
        """Run the worker: create an event loop, register this worker in the
        DB, then block in run_forever() until stop()/a signal halts the loop.

        This is the thread target when constructed with daemon=True, but can
        also be called directly to run a worker on the current thread.
        """
        assert not hasattr(self, "loop"), "This worker is already running!"
        from .models import Worker, Task, Trace

        # Row representing this worker process/thread in the database.
        self.model = Worker( pid = os.getpid(), thread_id = threading.get_ident(), is_robust = is_robust )
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop( self.loop )

        #Fight over who's gonna be the master, prove your health in the process!
        self.loop.create_task( self.master_loop() )
        main_task = self.loop.create_task( self.work_loop() )

        time.sleep(0.3) #To avoid the django initialization warning!
        self.model.save()
        self.model.refresh_from_db()

        #Fill in the ID fields of the tasks we didn't dare to check with db until now
        from .models import Task  # NOTE(review): redundant, already imported above.
        for func in Task.registered_tasks.values():
            task = func.task
            if not task.pk:
                try: task.pk = Task.objects.get( name = task.name ).pk
                except Task.DoesNotExist: pass #It's a new one, it's fine.

        self.attach_django_signals()

        try:
            self.loop.run_forever() #This is the lifetime of this worker
        except KeyboardInterrupt: self.log.info(f"[Asyncron][W{self.model.id}] Worker Received KeyboardInterrupt, exiting...")
        else: self.log.info(f"[Asyncron][W{self.model.id}] Worker exiting...")

        # Loop stopped: give in-flight traces a moment to finish.
        self.loop.run_until_complete( self.graceful_shutdown() )

        # Release any trace this worker still held so another worker can retry it.
        count = Trace.objects.filter( status__in = "SWRP", worker_lock = self.model ).update(
            status_reason = "Worker died during execution",
            status = "A", worker_lock = None
        )
        #DONT print anything in here!
        #if count: print(f"Had to cancel {count} task(s).") #cls.log.warning
        self.model.delete()
        #self.loop.call_soon(self.started.set)
def attach_django_signals( self ):
|
|
django_signals = {
|
|
name : attr
|
|
for name in ["post_save", "post_delete"] #TO Expand: dir(models.signals)
|
|
if not name.startswith("_") #Dont get private stuff
|
|
and ( attr := getattr(models.signals, name) ) #Just an assignment
|
|
and isinstance( attr, models.signals.ModelSignal ) #Is a signal related to models!
|
|
}
|
|
for name, signal in django_signals.items():
|
|
signal.connect( functools.partial( self.model_changed, name ) )
|
|
|
|
from .models import Task
|
|
for name, task in Task.registered_tasks.items():
|
|
if not hasattr(task, 'watching_models'): continue
|
|
for model in getattr(task, 'watching_models'):
|
|
self.watching_models[ model ].add( name )
|
|
|
|
|
|
def model_changed( self, signal_name, sender, signal, instance, **kwargs ):
|
|
from .models import Task
|
|
for name in self.watching_models[instance.__class__]:
|
|
asyncio.run_coroutine_threadsafe(
|
|
Task.registered_tasks[name].task.ensure_quick_execution( reason = f"Change ({signal_name}) on {instance}" ),
|
|
self.loop
|
|
)
|
|
|
|
async def graceful_shutdown( self ):
|
|
try:
|
|
for attempt in range( 10 ):
|
|
if not await self.model.trace_set.aexists():
|
|
break
|
|
await asyncio.sleep( attempt / 10 )
|
|
else: self.log.info(f"[Asyncron][W{self.model.id}] Graceful shutdown not graceful enough!")
|
|
except: await asyncio.sleep( 1 )
|
|
|
|
|
|
async def master_loop( self ):
|
|
from .models import Worker, Task, Trace
|
|
|
|
#Delete dead masters every now and then!
|
|
last_overtake_attempt = 0
|
|
current_master = False
|
|
|
|
while True:
|
|
try:
|
|
await Worker.objects.filter( is_master = False ).aupdate( is_master = models.Q(id = self.model.id) )
|
|
|
|
except IntegrityError: # I'm not master!
|
|
loop_wait = 5 + random.random() * 15
|
|
|
|
if current_master: self.log.info(f"[Asyncron][W{self.model.id}] No longer master.")
|
|
current_master = False
|
|
|
|
if last_overtake_attempt + 60 < time.time():
|
|
last_overtake_attempt = time.time()
|
|
took_master = False
|
|
|
|
if self.model.is_robust:
|
|
took_master = await Worker.objects.filter( is_master = True, is_robust = False ).aupdate( is_master = False )
|
|
loop_wait = 0
|
|
|
|
else:
|
|
await Worker.objects.filter(
|
|
is_master = True,
|
|
last_crowning_attempt__lte = timezone.now() - timezone.timedelta( minutes = 5 )
|
|
).aupdate( is_master = False )
|
|
|
|
|
|
else: #I am Master!
|
|
loop_wait = 2 + random.random() * 3
|
|
|
|
if not current_master: self.log.info(f"[Asyncron][W{self.model.id}] Running as master.")
|
|
current_master = True
|
|
|
|
if not self.clearing_dead_workers:
|
|
self.loop.create_task( self.clear_dead_workers() )
|
|
|
|
await self.sync_tasks()
|
|
await self.clear_orphaned_traces()
|
|
|
|
finally:
|
|
await Worker.objects.filter( id = self.model.id ).aupdate( last_crowning_attempt = timezone.now() )
|
|
await asyncio.sleep( loop_wait )
|
|
|
|
    async def clear_orphaned_traces( self ):
        """Delete traces in status R/P/W that no worker holds a lock on —
        their worker died without cleaning up.  Master-only duty."""
        from .models import Worker, Task, Trace
        await Trace.objects.filter( worker_lock = None, status__in = "RPW" ).adelete()
async def clear_dead_workers( self ):
|
|
self.clearing_dead_workers = True
|
|
from .models import Worker, Task, Trace
|
|
await Worker.objects.filter(
|
|
last_crowning_attempt__lte = timezone.now() - timezone.timedelta( seconds = 30 ),
|
|
in_grace = False
|
|
).aupdate( in_grace = True )
|
|
|
|
async for worker in Worker.objects.filter( in_grace = False, last_crowning_attempt = None ):
|
|
if not await sync_to_async( worker.is_proc_alive )():
|
|
await worker.adelete()
|
|
|
|
await asyncio.sleep( 30 )
|
|
await Worker.objects.filter( in_grace = True ).adelete()
|
|
self.clearing_dead_workers = False
|
|
|
|
    async def sync_tasks( self ):
        """Reconcile in-code task definitions with their DB rows (master-only).

        For each registered task, point ``func.task`` at the DB row.  When no
        row exists yet, create it from the in-code definition.  When one does,
        save the in-code definition over it so declaration changes reach the
        DB, then refresh ``func.task`` with the committed state.
        """
        from .models import Task
        for name, func in Task.registered_tasks.items():

            # The in-code (declaration-time) Task instance.
            init_task = func.task
            try:
                func.task = await Task.objects.aget( name = name )
            except Task.DoesNotExist:
                # First sighting: create the row from the declaration.
                await func.task.asave()
                await func.task.arefresh_from_db()
            else: #For now, to commit changes to db
                init_task.id = func.task.id
                await init_task.asave()
                await func.task.arefresh_from_db()
    async def work_loop( self ):
        """Main polling loop: run check_scheduled every check_interval seconds.

        check_interval starts at 0 so the first pass runs immediately, then
        settles at 10s (20s back-off after an unexpected failure).  Only an
        OperationalError ends the loop; work_loop_over is then set so other
        code can observe that this worker no longer processes tasks.
        """
        self.check_interval = 0

        while True:
            await asyncio.sleep( self.check_interval )
            self.check_interval = 10
            try:
                await self.check_scheduled()
                # Drop stale/broken DB connections between passes.
                await sync_to_async( close_old_connections )()
            except OperationalError as e:
                self.log.warning(f"[Asyncron] DB Connection Error: {e}")
                print( traceback.format_exc() )
                break

            except Exception as e:
                self.log.warning(f"[Asyncron] check_scheduled failed: {e}")
                print( traceback.format_exc() )
                self.check_interval = 20

        self.work_loop_over.set()
    async def check_scheduled( self ):
        """One polling pass over scheduled work.

        1. For interval tasks with no Scheduled trace, create and schedule a
           fresh trace, guarded by a task-level worker_lock so only one
           worker persists it.
        2. Claim traces due within ``early_seconds`` — the claim is a
           conditional UPDATE racing other workers — and hand each winner to
           start_trace_on_time.
        """
        from .models import Task, Trace

        #Schedule traces that aren't yet set.
        # Interval tasks with no Scheduled trace, matching this worker's type
        # (robust workers skip "D"-only tasks and vice versa).
        Ts = Task.objects.exclude( interval = None ).exclude(
            trace__status = "S"
        ).exclude( worker_type = "D" if self.model.is_robust else "R" )

        async for task in Ts:
            trace = task.new_trace()
            await trace.reschedule( reason = "Auto Scheduled" )

            # Task-level lock: only the worker that wins this UPDATE saves
            # the new trace, preventing duplicate scheduling.
            locked = await Task.objects.filter( id = task.id ).filter(
                models.Q(worker_lock = None) |
                models.Q(worker_lock = self.model) #This is incase the lock has been aquired for some reason before.
            ).aupdate( worker_lock = self.model )

            if locked:
                await trace.asave()
                await Task.objects.filter( id = task.id, worker_lock = self.model ).aupdate( worker_lock = None )

        # Claim traces due "now-ish"; the head start lets start_trace_on_time
        # sleep until the exact scheduled moment.
        early_seconds = 5 + self.check_interval * ( 1 + random.random() )
        async for trace in Trace.objects.filter( status = "S", worker_lock = None, scheduled_datetime__lte = timezone.now() + timezone.timedelta( seconds = early_seconds ) ):
            await trace.eval_related()
            #print(f"Checking {trace} to do now: {trace.scheduled_datetime - timezone.now()}")

            # The conditional UPDATE *is* the claim; zero rows updated means
            # another worker beat us to it.
            count = await Trace.objects.filter( id = trace.id, status = "S" ).aupdate( status = "W", worker_lock = self.model )
            if not count: continue #Lost the race condition to another worker.

            self.loop.create_task( self.start_trace_on_time( trace ) )
async def start_trace_on_time( self, trace ):
|
|
from .models import Trace
|
|
|
|
await asyncio.sleep( ( timezone.now() - trace.scheduled_datetime ).total_seconds() )
|
|
await trace.arefresh_from_db()
|
|
await trace.start()
|
|
trace.worker_lock = None
|
|
await trace.asave( update_fields = ['worker_lock'] )
|
|
|
|
#Traces for the same task that we are done with (Completed, Aborted, Errored)
|
|
QuerySet = Trace.objects.filter(
|
|
task_id = trace.task_id, status__in = "CAE", protected = False, worker_lock = None
|
|
).order_by('-register_datetime')
|
|
|
|
#Should be deleted after the threashold
|
|
max_count = trace.task.max_completed_traces if trace.status == "C" else trace.task.max_failed_traces
|
|
await QuerySet.exclude(
|
|
id__in = QuerySet[:max_count].values_list( 'id', flat = True )
|
|
).adelete()
|