Harvested from another project

This commit is contained in:
aliqandil 2024-10-17 06:12:01 +03:30
parent 2dc2339b6a
commit 4ffa8f4907
14 changed files with 1121 additions and 0 deletions

3
asyncron/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .shortcuts import task, run_on_model_change
__all__ = ['task', 'run_on_model_change']

198
asyncron/admin.py Normal file
View File

@ -0,0 +1,198 @@
from django.contrib import admin
from django.utils import timezone
from django.db.models import F
from .base.admin import BaseModelAdmin
from .models import Worker, Task, Trace, Metadata
import os
import asyncio
import humanize
@admin.register( Metadata )
class MetadataAdmin( BaseModelAdmin ):
	"""Read-only admin listing for Metadata entries."""
	order = 3
	list_display = ['name', 'model_type', 'target', 'expiration']

	def has_add_permission( self, request, obj = None ):
		# Metadata rows are created programmatically, never via the admin.
		return False

	def has_change_permission( self, request, obj = None ):
		return False

	def target( self, obj ):
		# Human-readable representation of the generic FK target instance.
		return str(obj.model)

	def expiration( self, obj ):
		# Relative expiry time, or "Never" for entries with no expiry set.
		if not obj.expiration_datetime:
			return "Never"
		return humanize.naturaltime( obj.expiration_datetime )
	expiration.admin_order_field = 'expiration_datetime'
@admin.register( Worker )
class WorkerAdmin( BaseModelAdmin ):
	"""Read-only admin listing of live Asyncron worker processes."""
	order = 4
	list_display = ['pid', 'thread_id', 'is_robust', 'is_master', 'is_running', 'health']

	def has_add_permission( self, request, obj = None ):
		# Worker rows are registered by the worker processes themselves.
		return False

	def is_running( self, obj ):
		# True when the recorded PID still maps to a live process.
		return obj.is_proc_alive()
	is_running.boolean = True

	def health( self, obj ):
		# Relative time of the last heartbeat, prefixed when in grace period.
		prefix = "In Grace " if obj.in_grace else ""
		return prefix + humanize.naturaltime( obj.last_crowning_attempt )
@admin.register( Task )
class TaskAdmin( BaseModelAdmin ):
order = 1
list_display = 'name', 'timeout', 'gracetime', 'jitter', 'type', 'worker_type', 'logged_executions', 'last_execution', 'scheduled'
fields = ["name", "description", "type", "jitter"]
actions = 'schedule_execution', 'execution_now',
def has_add_permission( self, request, obj = None ): return False
def has_delete_permission( self, request, obj = None ): return False
def has_change_permission( self, request, obj = None ): return False
def description( self, obj ):
try:
return obj.registered_tasks[obj.name].__doc__.strip("\n")
except:
return "N/A"
def jitter( self, obj ):
match obj.jitter_pivot:
case "S": sign = '+'
case "M": sign = '±'
case "E": sign = '-'
return f"{sign}{obj.jitter_length}"
def type( self, obj ):
results = []
if obj.timeout is None:
results.append( "Service" )
if obj.interval:
delta = humanize.naturaldelta( obj.interval )
delta = delta.replace("an ", "1 ").replace("a ", "1 ")
results.append( f"Periodic, every {delta}" )
if obj.name not in obj.registered_tasks:
results.append("Script Missing!")
return ", ".join( results ) if results else "Callable"
def logged_executions( self, obj ):
return obj.trace_set.exclude( status = "S" ).count()
def last_execution( self, obj ):
last_trace = obj.trace_set.exclude( status = "S" ).exclude( last_run_datetime = None ).order_by('last_run_datetime').last()
return humanize.naturaltime( last_trace.last_run_datetime ) if last_trace else "Never"
def scheduled( self, obj ):
return obj.trace_set.filter( status = "S" ).exists()
scheduled.boolean = True
@admin.action( description = "(Re)Schedule an execution for periodic tasks" )
def schedule_execution( self, request, qs ):
trace_ids = set()
for task in qs.exclude( interval = None ):
trace = task.new_trace()
trace.save()
trace.refresh_from_db()
trace_ids.add(trace.id)
results = asyncio.run( Trace.objects.filter( id__in = trace_ids ).gather_method( 'reschedule', reason = "Manually Schedule" ) )
self.explain_gather_results( request, results, 5 )
@admin.action( description = "Execute now!" )
def execution_now( self, request, qs ):
trace_ids = set()
for task in qs:
trace = task.new_trace()
trace.status = "W"
trace.save()
trace.refresh_from_db()
trace_ids.add(trace.id)
results = asyncio.run( Trace.objects.filter( id__in = trace_ids ).gather_method( 'start' ) )
self.explain_gather_results( request, results, 5 )
class TraceAppFilter(admin.SimpleListFilter):
	"""Filter traces by the first dotted component (the app) of the task name."""
	title = "app"
	parameter_name = 'app_groups'

	def lookups( self, request, model_admin ):
		# Distinct, non-empty app prefixes across all known tasks.
		prefixes = { task.name.split(".", 1)[0] for task in Task.objects.all() }
		return ( (prefix.lower(), prefix) for prefix in sorted(prefixes) if prefix )

	def queryset( self, request, queryset ):
		value = self.value()
		if value:
			return queryset.filter( task__name__istartswith = value )
		return queryset
class TraceNameFilter(admin.SimpleListFilter):
	"""Filter traces by the last dotted component (the short name) of the task name."""
	title = "task name"
	parameter_name = 'short_name'

	def lookups( self, request, model_admin ):
		# Distinct, non-empty short names across all known tasks.
		shorts = { task.name.rsplit(".", 1)[-1] for task in Task.objects.all() }
		return ( (short.lower(), short) for short in sorted(shorts) if short )

	def queryset( self, request, queryset ):
		value = self.value()
		if value:
			return queryset.filter( task__name__iendswith = value )
		return queryset
@admin.register( Trace )
class TraceAdmin( BaseModelAdmin ):
	#Read-only admin for execution traces; newest scheduled first.
	order = 2
	list_display = 'task', 'execution', 'state', 'worker_lock'
	list_filter = TraceAppFilter, TraceNameFilter, 'task__worker_type', 'status', 'status_reason',
	ordering = F('scheduled_datetime').desc(nulls_last=True),
	#readonly_fields = [ f.name for f in Trace._meta.fields ]
	def has_add_permission( self, request, obj = None ): return False
	def execution( self, obj ):
		"""Human description of when this trace ran, or is (over)due to run."""
		if obj.last_run_datetime:
			return "- Ran " + humanize.naturaltime( obj.last_run_datetime )
		if obj.scheduled_datetime:
			if obj.scheduled_datetime < timezone.now():
				return "- Should've run " + humanize.naturaltime( obj.scheduled_datetime )
			else:
				return "+ In " + humanize.naturaltime( obj.scheduled_datetime )
		return "Never"
	execution.admin_order_field = 'scheduled_datetime'
	def state( self, obj ):
		#Raw status code plus free-form reason when one was recorded.
		return f"{obj.status}: {obj.status_reason}" if obj.status_reason else f"{obj.get_status_display()}"
	state.admin_order_field = 'status'
	actions = 'reschedule_to_now',
	@admin.action( description = "Reschedule to run now" )
	def reschedule_to_now( self, request, qs ):
		#Only periodic (interval-bearing), still-Scheduled traces can be moved.
		results = asyncio.run(
			qs.exclude( task__interval = None ).filter( status = "S" ).gather_method(
				'reschedule',
				reason = "Manually Rescheduled",
				target_datetime = timezone.now(),
			)
		)
		self.explain_gather_results( request, results, 5 )
#

34
asyncron/apps.py Normal file
View File

@ -0,0 +1,34 @@
from django.apps import AppConfig
from django.conf import settings
from django.apps import apps
import pathlib, importlib, types
class AsyncronConfig(AppConfig):
	"""App config: imports per-app task modules, then boots the in-process worker."""
	default_auto_field = 'django.db.models.BigAutoField'
	name = 'asyncron'
	def ready( self ):
		# ASYNCRON['IMPORT_PER_APP'] lists module names (e.g. "tasks") imported
		# from every first-party app so their task decorators register themselves.
		try: names = settings.ASYNCRON['IMPORT_PER_APP']
		except (KeyError, AttributeError): pass
		else: self.import_per_app( names )
		#Init the asyncron worker for this process
		from .workers import AsyncronWorker
		#The worker should not start working until they know we're responding to requests.
		AsyncronWorker.init()
	def import_per_app( self, names ):
		# Only apps living directly under BASE_DIR (first-party apps) are scanned.
		for app in apps.get_app_configs():
			app_dir = pathlib.Path(app.path)
			if app_dir.parent != settings.BASE_DIR: continue
			for name in names:
				import_file = app_dir / f"{name}.py"
				if not import_file.exists() or not import_file.is_file(): continue
				#print( f"Loading {app.name}.{name}:", import_file )
				# NOTE(review): the module is executed but never inserted into
				# sys.modules, so a repeat call would re-execute it — confirm intended.
				loader = importlib.machinery.SourceFileLoader( f"{app.name}.{name}", str(import_file) )
				loader.exec_module( types.ModuleType(loader.name) )

68
asyncron/base/admin.py Normal file
View File

@ -0,0 +1,68 @@
from django.contrib import admin, messages
from django.utils.safestring import mark_safe
from django.utils.html import escape
import traceback
class BaseModelAdmin( admin.ModelAdmin ):
	"""ModelAdmin base shared by the asyncron admin classes."""
	def explain_gather_results( self, request, results, fails_to_show = 2 ):
		"""
		Report the outcome of an AsyncronQuerySet.gather_method() call as admin
		messages: one detailed error per failure (up to fails_to_show), then a
		summary. `results` maps pk -> return value or the raised exception.
		Superusers additionally get a CSS-toggled traceback per failure.
		"""
		failed = 0
		for id, e in results.items():
			if isinstance(e, BaseException):
				failed += 1
				if failed <= fails_to_show:
					if request.user.is_superuser:
						traceback_message = ''.join(traceback.TracebackException.from_exception(e).format())
						self.message_user( request, mark_safe(f"""
							Error For <b>{id}</b>: {e}
							<style>
								pre.tb {{display: none;}}
								a.tb:focus + pre.tb {{display: block;}}
							</style>
							<a class='tb' href='#'>[TraceBack]</a>
							<pre class='tb'>{escape(traceback_message)}</pre>
						"""), messages.ERROR
						)
					else: self.message_user( request, f"Error For {id}: {e}", messages.ERROR)
		if failed == 0: self.message_user( request, f"All {len(results)} Succeeded!", messages.SUCCESS )
		elif failed <= fails_to_show:
			if len(results) - failed > 0:
				self.message_user( request, f"All the rest ({len(results) - failed}) Succeeded!", messages.SUCCESS )
		else: self.message_user( request, f"{len(results) - failed} Succeeded, {failed - fails_to_show} more failed!", messages.WARNING )
import math
from django.contrib import admin
from django.apps import apps
def get_app_list(self, request, app_label=None):
	"""
	Return a sorted list of all the installed apps that have been
	registered in this site.

	Apps follow INSTALLED_APPS order (unknown apps sort last, alphabetically);
	models within each app follow their ModelAdmin's optional `order` attribute.
	"""
	app_dict = self._build_app_dict(request, app_label)
	app_ordering = { app.name: index for index, app in enumerate( apps.get_app_configs() ) }
	# Sort the apps by settings order, then alphabetically.
	app_list = sorted(
		app_dict.values(),
		key = lambda x:
			(
				app_ordering.get( x["name"], math.inf ),
				x["name"].lower()
			)
	)
	# Sort the models admin.order/alphabetically within each app.
	for app in app_list:
		app["models"].sort(
			key=lambda x:
				(
					getattr( admin.site.get_model_admin( x['model'] ), 'order', math.inf ),
					x['name'].lower()
				)
		)
	return app_list
# Monkeypatch: all AdminSite instances now use the ordering above.
admin.AdminSite.get_app_list = get_app_list

56
asyncron/base/models.py Normal file
View File

@ -0,0 +1,56 @@
from django.db import models
from django.contrib.contenttypes.fields import GenericRelation
from asgiref.sync import sync_to_async
from asyncron.utils import rgetattr
import asyncio
class AsyncronQuerySet( models.QuerySet ):
	async def gather_method( self, method, *args, **kwargs ):
		"""
		Await `method(*args, **kwargs)` concurrently on every instance in the
		queryset. Returns a {pk: result} mapping where a failure stores the
		raised exception itself (return_exceptions=True) instead of raising.
		"""
		mapping = {
			instance.pk: getattr( instance, method )( *args, **kwargs )
			async for instance in self
		}
		returns = await asyncio.gather( *list(mapping.values()), return_exceptions = True )
		# dicts preserve insertion order, so results line up with the coroutines.
		for index, pk in enumerate(mapping):
			mapping[pk] = returns[index]
		return mapping
	def to_json( self, *structure ):
		#Serialize every instance through BaseModel.fields_to_dict.
		return [ m.fields_to_dict( *structure ) for m in self ]
class BaseModel( models.Model ):
	"""Abstract base model: async helpers plus a generic Metadata relation."""
	objects = AsyncronQuerySet.as_manager()
	metadata = GenericRelation("asyncron.Metadata", content_type_field = 'model_type', object_id_field = 'model_id')
	async def eval_related( self, *fields ):
		"""
		Pre-fetch the given relation fields (all relations when omitted) so
		later attribute access does not hit the DB from async code.
		"""
		if not fields:
			fields = [ f.name for f in self._meta.fields if f.is_relation ]
		#Since we're using an underscore variable
		#This next line running correctly is optional,
		#but helps reduce or eliminate 'sync_to_async' context switches.
		#BUGFIX: was a bare `except:` which also swallowed SystemExit and
		#KeyboardInterrupt; best-effort behavior is preserved for real errors.
		try: fields = [ f for f in fields if f not in self._state.fields_cache ]
		except Exception: print("WARNING: could not check already cached relations.")
		if fields:
			await sync_to_async(lambda: [ getattr(self, f) for f in fields ])()
	def fields_to_dict( self, *fields ):
		"""
		To create json/dict from fields.
		Each entry is a dotted attribute path, a callable (given the instance),
		or a (name, path-or-callable) tuple; callable results are also invoked.
		"""
		results = {}
		for f in fields:
			name, method = (f[0], f[1]) if isinstance(f, tuple) else (f, f)
			value = method(self) if callable(method) else rgetattr(self, method)
			results[name] = value() if callable(value) else value
		return results
	class Meta:
		abstract = True

38
asyncron/gunicorn.py Normal file
View File

@ -0,0 +1,38 @@
##
## - Gunicorn compatibility
## Add this to gunicorn.py conf file:
## from asyncron.gunicorn import post_fork
##
## adds an asyncron worker in each gunicorn worker process
## Hooks into 'dev reload' and 'exit signals' for graceful termination of tasks
##
def post_fork( server, worker ): #worker and AsyncronWorker, pay attention!
	"""Gunicorn post_fork hook: wire one AsyncronWorker into this worker process."""
	post_fork.server = server
	post_fork.worker = worker
	from .workers import AsyncronWorker
	AsyncronWorker.log = worker.log
	AsyncronWorker.log.info("Asyncron worker attached.")
	init_to_override = AsyncronWorker.init
	def init( *args, **kwargs ):
		# Wrap AsyncronWorker.init: cap workers at one per process, hook exit
		# signals, and stop tasks gracefully before gunicorn's auto-reload fires.
		# NOTE(review): assumes worker.reloader exists — confirm reload is enabled.
		AsyncronWorker.MAX_COUNT = 1
		AsyncronWorker.override_exit_signals()
		to_override = worker.reloader._callback
		def new_callback(*args, **kwargs):
			AsyncronWorker.stop( reason = "Auto Reload" )
			return to_override(*args, **kwargs)
		worker.reloader._callback = new_callback
		return init_to_override( *args, **kwargs )
	AsyncronWorker.init = init
# Keeping the worker in post_fork.worker so we can add extra files for it to track
# TODO: Currently unfinished, since I just realized using the "inotify" support of gunicorn
# makes this redundant, but still here is the relevant code if I want to also support the simpler
# polling system
# Should be in asyncron.app.ready
# -> post_fork.worker.reloader.add_extra_file

View File

@ -0,0 +1,88 @@
##
#
# Command: python manage.py startasyncron
#
##
import logging
import asyncio
import time
from django.core.management.base import BaseCommand, CommandError
from django.conf import settings
from asyncron.workers import AsyncronWorker
from asyncron.models import Task
class bcolors:
	"""ANSI escape sequences for coloring terminal output."""
	HEADER    = '\033[95m'
	OKBLUE    = '\033[94m'
	OKCYAN    = '\033[96m'
	OKGREEN   = '\033[92m'
	WARNING   = '\033[93m'
	FAIL      = '\033[91m'
	ENDC      = '\033[0m'  # reset all attributes
	BOLD      = '\033[1m'
	UNDERLINE = '\033[4m'
class Command(BaseCommand):
	"""Management command: run a robust (non-reloadable) Asyncron worker in the foreground."""
	help = 'Start an Asyncron Worker'  #BUGFIX: was misspelled "Asyncorn"
	def handle( self, *arg, **kwargs ):
		AsyncronWorker.log = logging.getLogger(__name__)
		worker = AsyncronWorker( daemon = False )
		print( "Starting:", worker )
		worker.start( is_robust = True )

	#Older Stuff
	def maintain_tasks( self ):
		#Ensure every registered task function has a db row.
		#BUGFIX: narrowed the bare `except:` to the exception actually expected.
		for name, func in Task.registered_tasks.items():
			try: task = Task.objects.get( name = name )
			except Task.DoesNotExist: task = func.task.save()
	def handle_mgr( self, *arg, **kwargs ):
		#Experimental: a multiprocessing Listener that executes incoming task messages.
		from multiprocessing.connection import Listener
		import multiprocessing
		from asyncron.manager import PoolManager
		PoolManager.init_manager()
		print( "Coordinator:", PoolManager.coordinator )
		address = PoolManager.coordinator.split("unix:", 1)[-1]
		with Listener( address, authkey = settings.SECRET_KEY.encode() ) as listener:
			while True:
				try:
					with listener.accept() as conn:
						print( "New Conn:", conn )
						while msg := conn.recv():
							func, args, kwargs, name, repeat_interval, timeout_after, execution_context, execution_pool = msg
							print( "Msg:", msg )
							asyncio.run( func( *args, **kwargs ) )
							print("Ran.")
				except EOFError:
					print("Connection Closed.")
					continue
				except KeyboardInterrupt:
					print("Stopping...")
					break
#

View File

246
asyncron/models.py Normal file
View File

@ -0,0 +1,246 @@
from django.utils import timezone
from django.db import models
from django.db.models.constraints import UniqueConstraint, Q
from unittest.mock import patch #to mock print, can't use redirect_stdout in async code
import functools, traceback, io
import random
import asyncio
# Create your models here.
from .base.models import BaseModel
class Worker( BaseModel ):
	"""One running asyncron worker (per process/thread pair)."""
	pid = models.IntegerField()
	thread_id = models.PositiveBigIntegerField()
	is_robust = models.BooleanField( default = False ) #Dedicated (management-command) worker, not a reloadable web worker
	is_master = models.BooleanField( default = False )
	in_grace = models.BooleanField( default = False ) #If the worker sees this as True, it should kill itself!
	#Variables with very feel good names! :)
	last_crowning_attempt = models.DateTimeField( null = True, blank = True ) #Heartbeat: touched on every master_loop pass
	consumption_interval_seconds = models.IntegerField( default = 10 )
	consumption_total_active = models.IntegerField( default = 0 )
	def __str__( self ): return f"P{self.pid}W{self.thread_id}" + ("R" if self.is_robust else "D")
	class Meta:
		constraints = [
			#The partial unique constraint lets the database arbitrate master election.
			UniqueConstraint( fields = ('is_master',), condition = Q( is_master = True ), name='only_one_master'),
		]
	def is_proc_alive( self ):
		"""Return True when this worker's PID maps to a live process (signal-0 probe)."""
		import os
		pid = self.pid #Slightly Altered: https://stackoverflow.com/a/20186516
		if pid < 0: return False #NOTE: pid == 0 returns True
		try: os.kill(pid, 0)
		except ProcessLookupError: return False # errno.ESRCH: No such process
		except PermissionError: return True # errno.EPERM: Operation not permitted (i.e., process exists)
		else: return True # no error, we can send a signal to the process
class Task( BaseModel ):
	"""A registered task function with its scheduling and jitter configuration."""
	registered_tasks = {} #Name -> self
	name = models.TextField( unique = True ) #Path to the function
	worker_lock = models.ForeignKey( Worker, null = True, blank = True, on_delete = models.SET_NULL )
	worker_type = models.CharField( default = "A", choices = {
		"A": "Any",
		"R": "Robust", #Only separate Robust workers
		"D": "Dynamic", #Only on potentially reloadable workers
	})
	max_completed_traces = models.IntegerField( default = 10 )
	max_failed_traces = models.IntegerField( default = 1000 )
	timeout = models.DurationField(
		default = timezone.timedelta( minutes = 5 ),
		null = True, blank = True
	) #None will mean it's a "service" like task
	gracetime = models.DurationField( default = timezone.timedelta( minutes = 1 ) )
	#Periodic Tasks
	interval = models.DurationField( null = True, blank = True )
	jitter_length = models.DurationField( default = timezone.timedelta( seconds = 0 ), blank = True )
	jitter_pivot = models.CharField( default = "M", max_length = 1, choices = {
		"S":"Start", "M":"Middle", "E":"End",
	})
	def get_jitter( self ):
		"""Random offset within jitter_length, shifted according to jitter_pivot."""
		jitter = self.jitter_length * random.random()
		match self.jitter_pivot:
			case "M":
				jitter -= self.jitter_length / 2
			case "E":
				jitter *= -1
		return jitter
	def __str__( self ):
		type = "Callable" if self.interval is None else "Periodic"
		mode = "Service" if self.timeout is None else "Task"
		short = self.name.rsplit('.')[-1]
		return " ".join([type, mode, short])
	def register( self, f ):
		"""Decorator target: register `f` under this task and link them both ways."""
		if not self.name: self.name = f"{f.__module__}.{f.__qualname__}"
		self.registered_tasks[self.name] = f
		f.task = self
		return f
	def new_trace( self ):
		#Build (but do not save) a Trace bound to this task.
		trace = Trace( task_id = self.id )
		trace.task = self #Less db hits
		return trace
	async def ensure_quick_execution( self, reason = "Quick Exec" ):
		"""
		Make sure this task runs (approximately) now: reuse the earliest
		scheduled trace, or create one, unless a run is already waiting or due.
		"""
		now = timezone.now()
		if await self.trace_set.filter( status = "W" ).aexists():
			return
		if await self.trace_set.filter( status = "S", scheduled_datetime__lte = now ).aexists():
			return
		trace = await self.trace_set.filter( status = "S" ).order_by('scheduled_datetime').afirst()
		if not trace: trace = self.new_trace()
		await trace.reschedule( reason = reason, target_datetime = now )
		await trace.asave()
class Trace( BaseModel ):
	"""One (scheduled or executed) run of a Task, with captured output and result."""
	task = models.ForeignKey( Task, on_delete = models.CASCADE )
	status_reason = models.TextField( default = "", blank = True )
	status = models.CharField( default = "S", max_length = 1, choices = {
		"S":"Scheduled",
		"W":"Waiting",
		"R":"Running",
		"P":"Paused",
		"C":"Completed",
		"A":"Aborted",
		"E":"Error",
	})
	def set_status( self, status, reason = "" ):
		#Set status and its human-readable reason together (does not save).
		self.status = status
		self.status_reason = reason
	scheduled_datetime = models.DateTimeField( null = True, blank = True )
	register_datetime = models.DateTimeField( auto_now_add = True )
	last_run_datetime = models.DateTimeField( null = True, blank = True )
	last_end_datetime = models.DateTimeField( null = True, blank = True )
	worker_lock = models.ForeignKey( Worker, null = True, blank = True, on_delete = models.SET_NULL )
	protected = models.BooleanField( default = False ) #Do not delete these.
	args = models.JSONField( default = list, blank = True )
	kwargs = models.JSONField( default = dict, blank = True )
	stdout = models.TextField( null = True, blank = True )
	stderr = models.TextField( null = True, blank = True )
	returned = models.JSONField( null = True, blank = True )
	def __str__( self ): return f"Trace of Task {self.task}"
	class Meta:
		constraints = [
			#At most one not-yet-scheduled ("S" with no datetime) trace per task.
			UniqueConstraint(
				fields = ['task_id'],
				condition = models.Q(status = "S", scheduled_datetime = None),
				name = "unique_unscheduled_for_task",
			)
		]
	async def reschedule( self, reason = "", target_datetime = None ):
		"""
		Move this trace back to Scheduled. Without target_datetime the next slot
		is last run (or now) + interval + jitter. Saved only if already persisted.
		"""
		assert self.status in "SAE", f"Cannot reschedule a task that is in {self.get_status_display()} state!"
		await self.eval_related('task')
		assert self.task.interval, "This is not a periodic task! Nothing to reschedule."
		self.set_status( "S", reason )
		if target_datetime:
			self.scheduled_datetime = target_datetime
		else:
			base_time = self.last_run_datetime or timezone.now()
			jitter = self.task.get_jitter()
			self.scheduled_datetime = base_time + self.task.interval + jitter
		if self.id: await self.asave( update_fields = ["status", "status_reason", "scheduled_datetime"] )
	async def start( self ):
		"""
		Execute the task function for this trace, capturing print() output, the
		return value, and any exception into the trace row (saved twice: once
		before execution with R/E status, once afterwards with the outcome).
		"""
		#assert self.status == "W", f"Cannot start a task that is not Waiting ({self.get_status_display()})."
		await self.eval_related('task')
		assert self.status in "SPAWE", f"Cannot start a task that is in {self.get_status_display()} state!"
		self.last_run_datetime = timezone.now()
		self.last_end_datetime = None
		self.returned = None
		self.stderr = ""
		self.stdout = ""
		try:
			func = Task.registered_tasks[self.task.name]
		except KeyError:
			#Script missing: the error state is still persisted by the finally below.
			self.set_status( "E", "Script Missing!" )
			return
		else:
			self.set_status( "R" )
		finally:
			await self.asave()
		#Create an io object to read the print output
		new_lines = asyncio.Event() #TODO: So we can update the db mid task in an async way
		def confined_print( *args, sep = " ", end = "\n", **kwargs ):
			#print() replacement: accumulate output on the trace instead of stdout.
			self.stdout += sep.join( str(i) for i in args ) + end
			new_lines.set()
		try:
			with patch( 'builtins.print', confined_print ):
				output = await func( *self.args, **self.kwargs )
		except Exception as e:
			self.set_status( "E", f"Exception: {e}" )
			self.stderr = traceback.format_exc()
		else:
			self.set_status( "C" )
			self.returned = output
		finally:
			self.last_end_datetime = timezone.now()
			await self.asave()
	#TODO: Cool stuff to add later
	#callee = models.TextField()
	#caller = models.TextField()
	#repeatable = models.BooleanField( default = True ) #Unless something in args or kwargs is unserializable!
	#tags = models.JSONField( default = list )
#
#https://docs.djangoproject.com/en/5.1/ref/contrib/contenttypes/
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
class Metadata( BaseModel ):
	"""Named JSON data attached to any model instance, with optional expiry."""
	model_type = models.ForeignKey( ContentType, on_delete = models.CASCADE )
	model_id = models.PositiveIntegerField()
	model = GenericForeignKey("model_type", "model_id")
	name = models.CharField( max_length = 256 )
	data = models.JSONField( null = True, blank = True )
	expiration_datetime = models.DateTimeField( null = True, blank = True ) #None means the entry never expires
	@property
	def is_expired( self ):
		#True once expiration_datetime has passed; never-expiring entries return False.
		if self.expiration_datetime: return self.expiration_datetime < timezone.now()
		return False
	def __str__(self): return self.name
	class Meta:
		indexes = [
			models.Index(fields=["model_type", "model_id"]),
		]
		verbose_name = verbose_name_plural = 'Metadata'

74
asyncron/shortcuts.py Normal file
View File

@ -0,0 +1,74 @@
##
## decorators / functions to make the task calls easier
##
from django.utils.dateparse import parse_duration
from django.db import models
from django.utils import timezone
from django.apps import apps
import re
# Regular expression pattern with named groups for "1w2d5h30m10s500ms1000us" without spaces
# Group 1 (unnamed) captures an optional leading sign, used to pick the jitter pivot.
# NOTE: every component is optional, so this also matches the empty string —
# callers must check the input is non-empty before trusting a match.
pattern = re.compile(
	r'(\+|-)?'
	r'(?:(?P<weeks>\d+)w)?'
	r'(?:(?P<days>\d+)d)?'
	r'(?:(?P<hours>\d+)h)?'
	r'(?:(?P<minutes>\d+)m)?'
	r'(?:(?P<seconds>\d+)s)?'
	r'(?:(?P<milliseconds>\d+)ms)?'
	r'(?:(?P<microseconds>\d+)us)?'
)
def task( *args, **kwargs ):
	"""
	Decorator factory: build a Task from the given field values and register
	the decorated function under it.

	Accepts an extra 'jitter' kwarg as a compact duration string (see `pattern`),
	optionally signed: '-' pivots at Start, unsigned at Middle, '+' at End.
	Any other DurationField kwarg given as a string is parsed with parse_duration.
	"""
	from .models import Task
	jitter = kwargs.pop('jitter', "")
	match = pattern.match( jitter )
	if jitter and match:
		kwargs['jitter_length'] = timezone.timedelta( **{
			k: int(v)
			for k, v in match.groupdict().items()
			if v is not None
		} )
		kwargs['jitter_pivot'] = {
			"-": "S",
			None: "M",
			"+": "E",
		}[match.group(1)]
	for f in Task._meta.fields:
		if not isinstance(f, models.DurationField): continue
		# BUGFIX: compare by field *name* — Field objects are never kwargs keys,
		# so string durations were silently left unparsed before. Non-string
		# values (e.g. the timedelta built above) are passed through untouched.
		value = kwargs.get( f.name )
		if not value or not isinstance( value, str ): continue
		kwargs[f.name] = parse_duration( value )
	return Task( *args, **kwargs ).register
def run_on_model_change( *models ):
	"""
	Decorator: tag a task function with the models whose change signals should
	trigger a quick execution. String entries are resolved via apps.get_model;
	model classes pass through unchanged.
	"""
	resolved = []
	for m in models:
		resolved.append( apps.get_model(m) if isinstance(m, str) else m )
	def decorator( f ):
		f.watching_models = resolved
		return f
	return decorator
#

3
asyncron/tests.py Normal file
View File

@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.

10
asyncron/utils.py Normal file
View File

@ -0,0 +1,10 @@
import functools
def rsetattr(obj, attr, val):
	"""setattr following a dotted path: 'a.b.c' sets attribute c on obj.a.b."""
	head, _, leaf = attr.rpartition('.')
	target = rgetattr(obj, head) if head else obj
	return setattr(target, leaf, val)
def rgetattr(obj, attr, *args):
	"""getattr following a dotted path; *args may carry a default per step."""
	value = obj
	for part in attr.split('.'):
		value = getattr(value, part, *args)
	return value

3
asyncron/views.py Normal file
View File

@ -0,0 +1,3 @@
from django.shortcuts import render
# Create your views here.

300
asyncron/workers.py Normal file
View File

@ -0,0 +1,300 @@
from django.db import IntegrityError, models, close_old_connections
from django.utils import timezone
from asgiref.sync import sync_to_async
import os, signal
import time
import threading
import logging, traceback
import asyncio
import collections, functools
import random
class AsyncronWorker:
	"""
	In-process background worker: runs an asyncio loop on its own thread,
	elects a db-arbitrated master, schedules periodic tasks and executes
	due Trace rows.
	"""
	INSTANCES = [] #AsyncronWorker instance
	MAX_COUNT = 0
	EXIST_SIGNALS = [ #(sic: "exit" signals) — signals that should gracefully stop the workers
		signal.SIGABRT,
		signal.SIGHUP,
		signal.SIGQUIT,
		signal.SIGINT,
		signal.SIGTERM
	]
	@classmethod
	def override_exit_signals( cls ):
		"""Install (or wrap the existing) handler on every exit signal so workers stop gracefully."""
		for sig in cls.EXIST_SIGNALS:
			to_override = signal.getsignal(sig)
			if getattr(to_override, "already_wrapped", False):
				cls.log.warning(
					f"An attempt was made to wrap around the {signal.strsignal(sig)} signal again!"
					" Make sure you only call asyncron.AsyncronWorker.override_exit_signals once per process."
				)
				continue
			if to_override and callable(to_override):
				#BUGFIX: bind the wrapped handler through a default argument — the
				#previous plain closure late-bound `to_override`, so every wrapped
				#signal chained to the handler from the *last* loop iteration.
				def wrapped( signum, frame, to_override = to_override ):
					cls.sigcatch( signum, frame )
					return to_override( signum, frame )
				wrapped.already_wrapped = True
				cls.log.debug(f"Wrapped {to_override} inside sigcatch for {signal.strsignal(sig)}")
				signal.signal(sig, wrapped)
			else:
				cls.log.debug(f"Direct sigcatch for {signal.strsignal(sig)}")
				signal.signal(sig, cls.sigcatch)
	@classmethod
	def sigcatch( cls, signum, frame ):
		#Signal handler: stop all workers, naming the signal as the reason.
		cls.stop(f"Signal {signal.strsignal(signum)}")
	@classmethod
	def stop( cls, reason = None ):
		"""Stop every worker's event loop, then wait for their threads to finish."""
		cls.log.info(f"Stopping AsyncronWorker(s): {reason}")
		for worker in cls.INSTANCES:
			if worker.is_stopping: continue
			worker.is_stopping = True
			worker.loop.call_soon_threadsafe(worker.loop.stop)
		for worker in cls.INSTANCES:
			#BUGFIX: non-daemon workers (management command) run on the caller's
			#thread and have no `.thread` attribute — guard before joining.
			thread = getattr( worker, 'thread', None )
			if thread and thread.is_alive():
				thread.join()
	@classmethod
	def init( cls ):
		#Spawn a daemon worker unless this process already has MAX_COUNT of them.
		if len(cls.INSTANCES) < cls.MAX_COUNT: cls()
		#TODO: Use this to skip the 1 second delay in the self.start method on higher traffic servers.
		#from django.db.backends.signals import connection_created
		#from django.db.backends.postgresql.base import DatabaseWrapper
		#from django.dispatch import receiver
		#@receiver(connection_created, sender=DatabaseWrapper)
		#def initial_connection_to_db(sender, **kwargs):
		#	if len(cls.INSTANCES) < cls.MAX_COUNT: cls()
	##
	## Start of instance methods
	##
	def __init__( self, daemon = True ):
		self.INSTANCES.append(self)
		self.is_stopping = False
		self.clearing_dead_workers = False
		self.watching_models = collections.defaultdict( set ) # Model -> Set of key name of the tasks
		if daemon:
			self.thread = threading.Thread( target = self.start )
			self.thread.start()
	def start( self, is_robust = False ):
		"""Register the worker row, spin up the event loop, and run until stopped."""
		assert not hasattr(self, "loop"), "This worker is already running!"
		from .models import Worker, Task, Trace
		self.model = Worker( pid = os.getpid(), thread_id = threading.get_ident(), is_robust = is_robust )
		self.loop = asyncio.new_event_loop()
		asyncio.set_event_loop( self.loop )
		#Fight over who's gonna be the master, prove your health in the process!
		self.loop.create_task( self.master_loop() )
		self.loop.create_task( self.work_loop() )
		time.sleep(0.3) #To avoid the django initialization warning!
		self.model.save()
		self.model.refresh_from_db()
		#Fill in the ID fields of the tasks we didn't dare to check with db until now
		from .models import Task
		for func in Task.registered_tasks.values():
			task = func.task
			if not task.pk: task.pk = Task.objects.get( name = task.name ).pk
		self.attach_django_signals()
		try:
			self.loop.run_forever() #This is the lifetime of this worker
		except KeyboardInterrupt:
			print("Received exit, exiting")
		#Abandon anything we were still responsible for before removing our row.
		count = Trace.objects.filter( status__in = "SWRP", worker_lock = self.model ).update(
			status_reason = "Worker died during execution",
			status = "A", worker_lock = None
		)
		#DONT print anything in here!
		#if count: print(f"Had to cancel {count} task(s).") #cls.log.warning
		self.model.delete()
		#self.loop.call_soon(self.started.set)
	def attach_django_signals( self ):
		#Hook post_save/post_delete so model changes can trigger watching tasks.
		django_signals = {
			name : attr
			for name in ["post_save", "post_delete"] #TO Expand: dir(models.signals)
			if not name.startswith("_") #Dont get private stuff
			and ( attr := getattr(models.signals, name) ) #Just an assignment
			and isinstance( attr, models.signals.ModelSignal ) #Is a signal related to models!
		}
		for name, signal in django_signals.items():
			signal.connect( functools.partial( self.model_changed, name ) )
		from .models import Task
		for name, task in Task.registered_tasks.items():
			if not hasattr(task, 'watching_models'): continue
			for model in getattr(task, 'watching_models'):
				self.watching_models[ model ].add( name )
	def model_changed( self, signal_name, sender, signal, instance, **kwargs ):
		#Runs on the sender's thread: hand the quick execution to our loop.
		from .models import Task
		for name in self.watching_models[instance.__class__]:
			asyncio.run_coroutine_threadsafe(
				Task.registered_tasks[name].task.ensure_quick_execution( reason = f"Change ({signal_name}) on {instance}" ),
				self.loop
			)
	async def master_loop( self ):
		"""
		Heartbeat + master election loop. The db's partial unique constraint on
		is_master arbitrates; robust workers may depose dynamic masters, stale
		masters (no heartbeat for 5 minutes) are deposed by anyone.
		"""
		from .models import Worker, Task, Trace
		#Delete dead masters every now and then!
		last_overtake_attempt = 0
		current_master = False
		while True:
			try:
				await Worker.objects.filter( is_master = False ).aupdate( is_master = models.Q(id = self.model.id) )
			except IntegrityError: # I'm not master!
				loop_wait = 5 + random.random() * 15
				if current_master: self.log.info(f"[Asyncron][W{self.model.id}] No longer master.")
				current_master = False
				if last_overtake_attempt + 60 < time.time():
					last_overtake_attempt = time.time()
					took_master = False
					if self.model.is_robust:
						took_master = await Worker.objects.filter( is_master = True, is_robust = False ).aupdate( is_master = False )
						loop_wait = 0
					else:
						await Worker.objects.filter(
							is_master = True,
							last_crowning_attempt__lte = timezone.now() - timezone.timedelta( minutes = 5 )
						).aupdate( is_master = False )
			else: #I am Master!
				loop_wait = 2 + random.random() * 3
				if not current_master: self.log.info(f"[Asyncron][W{self.model.id}] Running as master.")
				current_master = True
				if not self.clearing_dead_workers:
					self.loop.create_task( self.clear_dead_workers() )
				await self.sync_tasks()
				await self.clear_orphaned_traces()
			finally:
				await Worker.objects.filter( id = self.model.id ).aupdate( last_crowning_attempt = timezone.now() )
				await asyncio.sleep( loop_wait )
	async def clear_orphaned_traces( self ):
		#Drop in-flight traces whose worker vanished (lock was SET_NULL'ed away).
		from .models import Worker, Task, Trace
		await Trace.objects.filter( worker_lock = None, status__in = "RPW" ).adelete()
	async def clear_dead_workers( self ):
		"""Grace-flag stale workers, delete provably dead ones, then cull the graced."""
		self.clearing_dead_workers = True
		from .models import Worker, Task, Trace
		await Worker.objects.filter(
			last_crowning_attempt__lte = timezone.now() - timezone.timedelta( seconds = 30 ),
			in_grace = False
		).aupdate( in_grace = True )
		async for worker in Worker.objects.filter( in_grace = False, last_crowning_attempt = None ):
			if not await sync_to_async( worker.is_proc_alive )():
				await worker.adelete()
		await asyncio.sleep( 30 )
		await Worker.objects.filter( in_grace = True ).adelete()
		self.clearing_dead_workers = False
	async def sync_tasks( self ):
		#Reconcile in-memory registered tasks with their db rows (create or update).
		from .models import Task
		for name, func in Task.registered_tasks.items():
			init_task = func.task
			try:
				func.task = await Task.objects.aget( name = name )
			except Task.DoesNotExist:
				await func.task.asave()
				await func.task.arefresh_from_db()
			else: #For now, to commit changes to db
				init_task.id = func.task.id
				await init_task.asave()
				await func.task.arefresh_from_db()
	async def work_loop( self ):
		"""Poll for due traces every check_interval seconds, backing off on errors."""
		self.check_interval = 0
		while True:
			await asyncio.sleep( self.check_interval )
			self.check_interval = 10
			try:
				await self.check_scheduled()
				await sync_to_async( close_old_connections )()
			except Exception as e:
				self.log.warning(f"[Asyncron] check_scheduled failed: {e}")
				print( traceback.format_exc() )
				self.check_interval = 20
	async def check_scheduled( self ):
		"""Schedule periodic tasks lacking a trace, then claim and launch due traces."""
		from .models import Task, Trace
		Ts = Task.objects.exclude( interval = None ).exclude(
			trace__status = "S"
		).exclude( worker_type = "D" if self.model.is_robust else "R" )
		async for task in Ts:
			trace = task.new_trace()
			await trace.reschedule( reason = "Auto Scheduled" )
			#The task-level lock makes sure only one worker persists the new trace.
			locked = await Task.objects.filter( id = task.id, worker_lock = None ).aupdate( worker_lock = self.model )
			if locked:
				await trace.asave()
				await Task.objects.filter( id = task.id, worker_lock = self.model ).aupdate( worker_lock = None )
		#Fetch slightly into the future so traces can be started exactly on time.
		early_seconds = 5 + self.check_interval * ( 1 + random.random() )
		async for trace in Trace.objects.filter( status = "S", worker_lock = None, scheduled_datetime__lte = timezone.now() + timezone.timedelta( seconds = early_seconds ) ):
			await trace.eval_related()
			#print(f"Checking {trace} to do now: {trace.scheduled_datetime - timezone.now()}")
			count = await Trace.objects.filter( id = trace.id, status = "S" ).aupdate( status = "W", worker_lock = self.model )
			if not count: continue #Lost the race condition to another worker.
			self.loop.create_task( self.start_trace_on_time( trace ) )
	async def start_trace_on_time( self, trace ):
		"""Sleep until the trace's scheduled time, run it, then prune finished traces."""
		from .models import Trace
		#BUGFIX: the delay is scheduled - now; the original had the operands
		#reversed, which skipped the wait for future traces (negative sleep) and
		#slept extra for overdue ones. Negative deltas are clamped to run now.
		delay = ( trace.scheduled_datetime - timezone.now() ).total_seconds()
		await asyncio.sleep( max( 0, delay ) )
		await trace.arefresh_from_db()
		await trace.start()
		trace.worker_lock = None
		await trace.asave( update_fields = ['worker_lock'] )
		#Traces for the same task that we are done with (Completed, Aborted, Errored)
		QuerySet = Trace.objects.filter(
			task_id = trace.task_id, status__in = "CAE", protected = False, worker_lock = None
		).order_by('-register_datetime')
		#Should be deleted beyond the threshold
		max_count = trace.task.max_completed_traces if trace.status == "C" else trace.task.max_failed_traces
		await QuerySet.exclude(
			id__in = QuerySet[:max_count].values_list( 'id', flat = True )
		).adelete()