Most of the way through the stator refactor
This commit is contained in:
		
							parent
							
								
									61c324508e
								
							
						
					
					
						commit
						7746abbbb7
					
				@ -1,8 +1,17 @@
 | 
			
		||||
from django.contrib import admin
 | 
			
		||||
 | 
			
		||||
from stator.models import StatorTask
 | 
			
		||||
from stator.models import StatorError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@admin.register(StatorTask)
 | 
			
		||||
@admin.register(StatorError)
 | 
			
		||||
class DomainAdmin(admin.ModelAdmin):
 | 
			
		||||
    list_display = ["id", "model_label", "instance_pk", "locked_until"]
 | 
			
		||||
    list_display = [
 | 
			
		||||
        "id",
 | 
			
		||||
        "date",
 | 
			
		||||
        "model_label",
 | 
			
		||||
        "instance_pk",
 | 
			
		||||
        "from_state",
 | 
			
		||||
        "to_state",
 | 
			
		||||
        "error",
 | 
			
		||||
    ]
 | 
			
		||||
    ordering = ["-date"]
 | 
			
		||||
 | 
			
		||||
@ -1,9 +1,16 @@
 | 
			
		||||
import datetime
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from typing import Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
 | 
			
		||||
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.utils import timezone
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    Callable,
 | 
			
		||||
    ClassVar,
 | 
			
		||||
    Dict,
 | 
			
		||||
    List,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Set,
 | 
			
		||||
    Tuple,
 | 
			
		||||
    Type,
 | 
			
		||||
    Union,
 | 
			
		||||
    cast,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StateGraph:
 | 
			
		||||
@ -13,7 +20,7 @@ class StateGraph:
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    states: ClassVar[Dict[str, "State"]]
 | 
			
		||||
    choices: ClassVar[List[Tuple[str, str]]]
 | 
			
		||||
    choices: ClassVar[List[Tuple[object, str]]]
 | 
			
		||||
    initial_state: ClassVar["State"]
 | 
			
		||||
    terminal_states: ClassVar[Set["State"]]
 | 
			
		||||
 | 
			
		||||
@ -50,7 +57,7 @@ class StateGraph:
 | 
			
		||||
        cls.initial_state = initial_state
 | 
			
		||||
        cls.terminal_states = terminal_states
 | 
			
		||||
        # Generate choices
 | 
			
		||||
        cls.choices = [(name, name) for name in cls.states.keys()]
 | 
			
		||||
        cls.choices = [(state, name) for name, state in cls.states.items()]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class State:
 | 
			
		||||
@ -63,7 +70,7 @@ class State:
 | 
			
		||||
        self.parents: Set["State"] = set()
 | 
			
		||||
        self.children: Dict["State", "Transition"] = {}
 | 
			
		||||
 | 
			
		||||
    def _add_to_graph(self, graph: StateGraph, name: str):
 | 
			
		||||
    def _add_to_graph(self, graph: Type[StateGraph], name: str):
 | 
			
		||||
        self.graph = graph
 | 
			
		||||
        self.name = name
 | 
			
		||||
        self.graph.states[name] = self
 | 
			
		||||
@ -71,13 +78,19 @@ class State:
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"<State {self.name}>"
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return self.name
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self.name)
 | 
			
		||||
 | 
			
		||||
    def add_transition(
 | 
			
		||||
        self,
 | 
			
		||||
        other: "State",
 | 
			
		||||
        handler: Optional[Union[str, Callable]] = None,
 | 
			
		||||
        handler: Optional[Callable] = None,
 | 
			
		||||
        priority: int = 0,
 | 
			
		||||
    ) -> Callable:
 | 
			
		||||
        def decorator(handler: Union[str, Callable]):
 | 
			
		||||
        def decorator(handler: Callable[[Any], bool]):
 | 
			
		||||
            self.children[other] = Transition(
 | 
			
		||||
                self,
 | 
			
		||||
                other,
 | 
			
		||||
@ -85,9 +98,7 @@ class State:
 | 
			
		||||
                priority=priority,
 | 
			
		||||
            )
 | 
			
		||||
            other.parents.add(self)
 | 
			
		||||
            # All handlers should be class methods, so do that automatically.
 | 
			
		||||
            if callable(handler):
 | 
			
		||||
                return classmethod(handler)
 | 
			
		||||
            return handler
 | 
			
		||||
 | 
			
		||||
        # If we're not being called as a decorator, invoke it immediately
 | 
			
		||||
        if handler is not None:
 | 
			
		||||
@ -113,7 +124,7 @@ class State:
 | 
			
		||||
        if automatic_only:
 | 
			
		||||
            transitions = [t for t in self.children.values() if t.automatic]
 | 
			
		||||
        else:
 | 
			
		||||
            transitions = self.children.values()
 | 
			
		||||
            transitions = list(self.children.values())
 | 
			
		||||
        return sorted(transitions, key=lambda t: t.priority, reverse=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -141,7 +152,10 @@ class Transition:
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(self.handler, str):
 | 
			
		||||
            self.handler = getattr(self.from_state.graph, self.handler)
 | 
			
		||||
        return self.handler
 | 
			
		||||
        return cast(Callable, self.handler)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f"<Transition {self.from_state} -> {self.to_state}>"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ManualTransition(Transition):
 | 
			
		||||
@ -157,6 +171,5 @@ class ManualTransition(Transition):
 | 
			
		||||
    ):
 | 
			
		||||
        self.from_state = from_state
 | 
			
		||||
        self.to_state = to_state
 | 
			
		||||
        self.handler = None
 | 
			
		||||
        self.priority = 0
 | 
			
		||||
        self.automatic = False
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										0
									
								
								stator/management/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								stator/management/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								stator/management/commands/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								stator/management/commands/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										28
									
								
								stator/management/commands/runstator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								stator/management/commands/runstator.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,28 @@
 | 
			
		||||
from typing import List, Type, cast
 | 
			
		||||
 | 
			
		||||
from asgiref.sync import async_to_sync
 | 
			
		||||
from django.apps import apps
 | 
			
		||||
from django.core.management.base import BaseCommand
 | 
			
		||||
 | 
			
		||||
from stator.models import StatorModel
 | 
			
		||||
from stator.runner import StatorRunner
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Command(BaseCommand):
 | 
			
		||||
    help = "Runs a Stator runner for a short period"
 | 
			
		||||
 | 
			
		||||
    def add_arguments(self, parser):
 | 
			
		||||
        parser.add_argument("model_labels", nargs="*", type=str)
 | 
			
		||||
 | 
			
		||||
    def handle(self, model_labels: List[str], *args, **options):
 | 
			
		||||
        # Resolve the models list into names
 | 
			
		||||
        models = cast(
 | 
			
		||||
            List[Type[StatorModel]],
 | 
			
		||||
            [apps.get_model(label) for label in model_labels],
 | 
			
		||||
        )
 | 
			
		||||
        if not models:
 | 
			
		||||
            models = StatorModel.subclasses
 | 
			
		||||
        print("Running for models: " + " ".join(m._meta.label_lower for m in models))
 | 
			
		||||
        # Run a runner
 | 
			
		||||
        runner = StatorRunner(models)
 | 
			
		||||
        async_to_sync(runner.run)()
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
# Generated by Django 4.1.3 on 2022-11-09 05:46
 | 
			
		||||
# Generated by Django 4.1.3 on 2022-11-10 03:24
 | 
			
		||||
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.CreateModel(
 | 
			
		||||
            name="StatorTask",
 | 
			
		||||
            name="StatorError",
 | 
			
		||||
            fields=[
 | 
			
		||||
                (
 | 
			
		||||
                    "id",
 | 
			
		||||
@ -24,8 +24,11 @@ class Migration(migrations.Migration):
 | 
			
		||||
                ),
 | 
			
		||||
                ("model_label", models.CharField(max_length=200)),
 | 
			
		||||
                ("instance_pk", models.CharField(max_length=200)),
 | 
			
		||||
                ("locked_until", models.DateTimeField(blank=True, null=True)),
 | 
			
		||||
                ("priority", models.IntegerField(default=0)),
 | 
			
		||||
                ("from_state", models.CharField(max_length=200)),
 | 
			
		||||
                ("to_state", models.CharField(max_length=200)),
 | 
			
		||||
                ("date", models.DateTimeField(auto_now_add=True)),
 | 
			
		||||
                ("error", models.TextField()),
 | 
			
		||||
                ("error_details", models.TextField(blank=True, null=True)),
 | 
			
		||||
            ],
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										203
									
								
								stator/models.py
									
									
									
									
									
								
							
							
						
						
									
										203
									
								
								stator/models.py
									
									
									
									
									
								
							@ -1,14 +1,13 @@
 | 
			
		||||
import datetime
 | 
			
		||||
from functools import reduce
 | 
			
		||||
from typing import Type, cast
 | 
			
		||||
import traceback
 | 
			
		||||
from typing import ClassVar, List, Optional, Type, cast
 | 
			
		||||
 | 
			
		||||
from asgiref.sync import sync_to_async
 | 
			
		||||
from django.apps import apps
 | 
			
		||||
from django.db import models, transaction
 | 
			
		||||
from django.utils import timezone
 | 
			
		||||
from django.utils.functional import classproperty
 | 
			
		||||
 | 
			
		||||
from stator.graph import State, StateGraph
 | 
			
		||||
from stator.graph import State, StateGraph, Transition
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StateField(models.CharField):
 | 
			
		||||
@ -55,6 +54,9 @@ class StatorModel(models.Model):
 | 
			
		||||
    concrete model yourself.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # If this row is up for transition attempts
 | 
			
		||||
    state_ready = models.BooleanField(default=False)
 | 
			
		||||
 | 
			
		||||
    # When the state last actually changed, or the date of instance creation
 | 
			
		||||
    state_changed = models.DateTimeField(auto_now_add=True)
 | 
			
		||||
 | 
			
		||||
@ -62,68 +64,128 @@ class StatorModel(models.Model):
 | 
			
		||||
    # (and not successful, as this is cleared on transition)
 | 
			
		||||
    state_attempted = models.DateTimeField(blank=True, null=True)
 | 
			
		||||
 | 
			
		||||
    # If a lock is out on this row, when it is locked until
 | 
			
		||||
    # (we don't identify the lock owner, as there's no heartbeats)
 | 
			
		||||
    state_locked_until = models.DateTimeField(null=True, blank=True)
 | 
			
		||||
 | 
			
		||||
    # Collection of subclasses of us
 | 
			
		||||
    subclasses: ClassVar[List[Type["StatorModel"]]] = []
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        abstract = True
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def schedule_overdue(cls, now=None) -> models.QuerySet:
 | 
			
		||||
        """
 | 
			
		||||
        Finds instances of this model that need to run and schedule them.
 | 
			
		||||
        """
 | 
			
		||||
        q = models.Q()
 | 
			
		||||
        for transition in cls.state_graph.transitions(automatic_only=True):
 | 
			
		||||
            q = q | transition.get_query(now=now)
 | 
			
		||||
        return cls.objects.filter(q)
 | 
			
		||||
    def __init_subclass__(cls) -> None:
 | 
			
		||||
        if cls is not StatorModel:
 | 
			
		||||
            cls.subclasses.append(cls)
 | 
			
		||||
 | 
			
		||||
    @classproperty
 | 
			
		||||
    def state_graph(cls) -> Type[StateGraph]:
 | 
			
		||||
        return cls._meta.get_field("state").graph
 | 
			
		||||
 | 
			
		||||
    def schedule_transition(self, priority: int = 0):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def atransition_schedule_due(cls, now=None) -> models.QuerySet:
 | 
			
		||||
        """
 | 
			
		||||
        Finds instances of this model that need to run and schedule them.
 | 
			
		||||
        """
 | 
			
		||||
        q = models.Q()
 | 
			
		||||
        for state in cls.state_graph.states.values():
 | 
			
		||||
            state = cast(State, state)
 | 
			
		||||
            if not state.terminal:
 | 
			
		||||
                q = q | models.Q(
 | 
			
		||||
                    (
 | 
			
		||||
                        models.Q(
 | 
			
		||||
                            state_attempted__lte=timezone.now()
 | 
			
		||||
                            - datetime.timedelta(seconds=state.try_interval)
 | 
			
		||||
                        )
 | 
			
		||||
                        | models.Q(state_attempted__isnull=True)
 | 
			
		||||
                    ),
 | 
			
		||||
                    state=state.name,
 | 
			
		||||
                )
 | 
			
		||||
        await cls.objects.filter(q).aupdate(state_ready=True)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def transition_get_with_lock(
 | 
			
		||||
        cls, number: int, lock_expiry: datetime.datetime
 | 
			
		||||
    ) -> List["StatorModel"]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns up to `number` tasks for execution, having locked them.
 | 
			
		||||
        """
 | 
			
		||||
        with transaction.atomic():
 | 
			
		||||
            selected = list(
 | 
			
		||||
                cls.objects.filter(state_locked_until__isnull=True, state_ready=True)[
 | 
			
		||||
                    :number
 | 
			
		||||
                ].select_for_update()
 | 
			
		||||
            )
 | 
			
		||||
            cls.objects.filter(pk__in=[i.pk for i in selected]).update(
 | 
			
		||||
                state_locked_until=timezone.now()
 | 
			
		||||
            )
 | 
			
		||||
        return selected
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def atransition_get_with_lock(
 | 
			
		||||
        cls, number: int, lock_expiry: datetime.datetime
 | 
			
		||||
    ) -> List["StatorModel"]:
 | 
			
		||||
        return await sync_to_async(cls.transition_get_with_lock)(number, lock_expiry)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def atransition_clean_locks(cls):
 | 
			
		||||
        await cls.objects.filter(state_locked_until__lte=timezone.now()).aupdate(
 | 
			
		||||
            state_locked_until=None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def transition_schedule(self):
 | 
			
		||||
        """
 | 
			
		||||
        Adds this instance to the queue to get its state transition attempted.
 | 
			
		||||
 | 
			
		||||
        The scheduler will call this, but you can also call it directly if you
 | 
			
		||||
        know it'll be ready and want to lower latency.
 | 
			
		||||
        """
 | 
			
		||||
        StatorTask.schedule_for_execution(self, priority=priority)
 | 
			
		||||
        self.state_ready = True
 | 
			
		||||
        self.save()
 | 
			
		||||
 | 
			
		||||
    async def attempt_transition(self):
 | 
			
		||||
    async def atransition_attempt(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Attempts to transition the current state by running its handler(s).
 | 
			
		||||
        """
 | 
			
		||||
        # Try each transition in priority order
 | 
			
		||||
        for transition in self.state_graph.states[self.state].transitions(
 | 
			
		||||
            automatic_only=True
 | 
			
		||||
        ):
 | 
			
		||||
            success = await transition.get_handler()(self)
 | 
			
		||||
        for transition in self.state.transitions(automatic_only=True):
 | 
			
		||||
            try:
 | 
			
		||||
                success = await transition.get_handler()(self)
 | 
			
		||||
            except BaseException as e:
 | 
			
		||||
                await StatorError.acreate_from_instance(self, transition, e)
 | 
			
		||||
                traceback.print_exc()
 | 
			
		||||
                continue
 | 
			
		||||
            if success:
 | 
			
		||||
                await self.perform_transition(transition.to_state.name)
 | 
			
		||||
                return
 | 
			
		||||
                await self.atransition_perform(transition.to_state.name)
 | 
			
		||||
                return True
 | 
			
		||||
        await self.__class__.objects.filter(pk=self.pk).aupdate(
 | 
			
		||||
            state_attempted=timezone.now()
 | 
			
		||||
            state_attempted=timezone.now(),
 | 
			
		||||
            state_locked_until=None,
 | 
			
		||||
            state_ready=False,
 | 
			
		||||
        )
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    async def perform_transition(self, state_name):
 | 
			
		||||
    def transition_perform(self, state_name):
 | 
			
		||||
        """
 | 
			
		||||
        Transitions the instance to the given state name
 | 
			
		||||
        Transitions the instance to the given state name, forcibly.
 | 
			
		||||
        """
 | 
			
		||||
        if state_name not in self.state_graph.states:
 | 
			
		||||
            raise ValueError(f"Invalid state {state_name}")
 | 
			
		||||
        await self.__class__.objects.filter(pk=self.pk).aupdate(
 | 
			
		||||
        self.__class__.objects.filter(pk=self.pk).update(
 | 
			
		||||
            state=state_name,
 | 
			
		||||
            state_changed=timezone.now(),
 | 
			
		||||
            state_attempted=None,
 | 
			
		||||
            state_locked_until=None,
 | 
			
		||||
            state_ready=False,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    atransition_perform = sync_to_async(transition_perform)
 | 
			
		||||
 | 
			
		||||
class StatorTask(models.Model):
 | 
			
		||||
 | 
			
		||||
class StatorError(models.Model):
 | 
			
		||||
    """
 | 
			
		||||
    The model that we use for an internal scheduling queue.
 | 
			
		||||
 | 
			
		||||
    Entries in this queue are up for checking and execution - it also performs
 | 
			
		||||
    locking to ensure we get closer to exactly-once execution (but we err on
 | 
			
		||||
    the side of at-least-once)
 | 
			
		||||
    Tracks any errors running the transitions.
 | 
			
		||||
    Meant to be cleaned out regularly. Should probably be a log.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # appname.modelname (lowercased) label for the model this represents
 | 
			
		||||
@ -132,60 +194,33 @@ class StatorTask(models.Model):
 | 
			
		||||
    # The primary key of that model (probably int or str)
 | 
			
		||||
    instance_pk = models.CharField(max_length=200)
 | 
			
		||||
 | 
			
		||||
    # Locking columns (no runner ID, as we have no heartbeats - all runners
 | 
			
		||||
    # only live for a short amount of time anyway)
 | 
			
		||||
    locked_until = models.DateTimeField(null=True, blank=True)
 | 
			
		||||
    # The state we moved from
 | 
			
		||||
    from_state = models.CharField(max_length=200)
 | 
			
		||||
 | 
			
		||||
    # Basic total ordering priority - higher is more important
 | 
			
		||||
    priority = models.IntegerField(default=0)
 | 
			
		||||
    # The state we moved to (or tried to)
 | 
			
		||||
    to_state = models.CharField(max_length=200)
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"#{self.pk}: {self.model_label}.{self.instance_pk}"
 | 
			
		||||
    # When it happened
 | 
			
		||||
    date = models.DateTimeField(auto_now_add=True)
 | 
			
		||||
 | 
			
		||||
    # Error name
 | 
			
		||||
    error = models.TextField()
 | 
			
		||||
 | 
			
		||||
    # Error details
 | 
			
		||||
    error_details = models.TextField(blank=True, null=True)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def schedule_for_execution(cls, model_instance: StatorModel, priority: int = 0):
 | 
			
		||||
        # We don't do a transaction here as it's fine to occasionally double up
 | 
			
		||||
        model_label = model_instance._meta.label_lower
 | 
			
		||||
        pk = model_instance.pk
 | 
			
		||||
        # TODO: Increase priority of existing if present
 | 
			
		||||
        if not cls.objects.filter(
 | 
			
		||||
            model_label=model_label, instance_pk=pk, locked__isnull=True
 | 
			
		||||
        ).exists():
 | 
			
		||||
            StatorTask.objects.create(
 | 
			
		||||
                model_label=model_label,
 | 
			
		||||
                instance_pk=pk,
 | 
			
		||||
                priority=priority,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_for_execution(cls, number: int, lock_expiry: datetime.datetime):
 | 
			
		||||
        """
 | 
			
		||||
        Returns up to `number` tasks for execution, having locked them.
 | 
			
		||||
        """
 | 
			
		||||
        with transaction.atomic():
 | 
			
		||||
            selected = list(
 | 
			
		||||
                cls.objects.filter(locked_until__isnull=True)[
 | 
			
		||||
                    :number
 | 
			
		||||
                ].select_for_update()
 | 
			
		||||
            )
 | 
			
		||||
            cls.objects.filter(pk__in=[i.pk for i in selected]).update(
 | 
			
		||||
                locked_until=timezone.now()
 | 
			
		||||
            )
 | 
			
		||||
        return selected
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def aget_for_execution(cls, number: int, lock_expiry: datetime.datetime):
 | 
			
		||||
        return await sync_to_async(cls.get_for_execution)(number, lock_expiry)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    async def aclean_old_locks(cls):
 | 
			
		||||
        await cls.objects.filter(locked_until__lte=timezone.now()).aupdate(
 | 
			
		||||
            locked_until=None
 | 
			
		||||
    async def acreate_from_instance(
 | 
			
		||||
        cls,
 | 
			
		||||
        instance: StatorModel,
 | 
			
		||||
        transition: Transition,
 | 
			
		||||
        exception: Optional[BaseException] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        return await cls.objects.acreate(
 | 
			
		||||
            model_label=instance._meta.label_lower,
 | 
			
		||||
            instance_pk=str(instance.pk),
 | 
			
		||||
            from_state=transition.from_state,
 | 
			
		||||
            to_state=transition.to_state,
 | 
			
		||||
            error=str(exception),
 | 
			
		||||
            error_details=traceback.format_exc(),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def aget_model_instance(self) -> StatorModel:
 | 
			
		||||
        model = apps.get_model(self.model_label)
 | 
			
		||||
        return cast(StatorModel, await model.objects.aget(pk=self.pk))
 | 
			
		||||
 | 
			
		||||
    async def adelete(self):
 | 
			
		||||
        self.__class__.objects.adelete(pk=self.pk)
 | 
			
		||||
 | 
			
		||||
@ -4,11 +4,9 @@ import time
 | 
			
		||||
import uuid
 | 
			
		||||
from typing import List, Type
 | 
			
		||||
 | 
			
		||||
from asgiref.sync import sync_to_async
 | 
			
		||||
from django.db import transaction
 | 
			
		||||
from django.utils import timezone
 | 
			
		||||
 | 
			
		||||
from stator.models import StatorModel, StatorTask
 | 
			
		||||
from stator.models import StatorModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class StatorRunner:
 | 
			
		||||
@ -22,6 +20,7 @@ class StatorRunner:
 | 
			
		||||
    LOCK_TIMEOUT = 120
 | 
			
		||||
 | 
			
		||||
    MAX_TASKS = 30
 | 
			
		||||
    MAX_TASKS_PER_MODEL = 5
 | 
			
		||||
 | 
			
		||||
    def __init__(self, models: List[Type[StatorModel]]):
 | 
			
		||||
        self.models = models
 | 
			
		||||
@ -32,38 +31,44 @@ class StatorRunner:
 | 
			
		||||
        self.handled = 0
 | 
			
		||||
        self.tasks = []
 | 
			
		||||
        # Clean up old locks
 | 
			
		||||
        await StatorTask.aclean_old_locks()
 | 
			
		||||
        # Examine what needs scheduling
 | 
			
		||||
 | 
			
		||||
        print("Running initial cleaning and scheduling")
 | 
			
		||||
        initial_tasks = []
 | 
			
		||||
        for model in self.models:
 | 
			
		||||
            initial_tasks.append(model.atransition_clean_locks())
 | 
			
		||||
            initial_tasks.append(model.atransition_schedule_due())
 | 
			
		||||
        await asyncio.gather(*initial_tasks)
 | 
			
		||||
        # For the first time period, launch tasks
 | 
			
		||||
        print("Running main task loop")
 | 
			
		||||
        while (time.monotonic() - start_time) < self.START_TIMEOUT:
 | 
			
		||||
            self.remove_completed_tasks()
 | 
			
		||||
            space_remaining = self.MAX_TASKS - len(self.tasks)
 | 
			
		||||
            # Fetch new tasks
 | 
			
		||||
            if space_remaining > 0:
 | 
			
		||||
                for new_task in await StatorTask.aget_for_execution(
 | 
			
		||||
                    space_remaining,
 | 
			
		||||
                    timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
 | 
			
		||||
                ):
 | 
			
		||||
                    self.tasks.append(asyncio.create_task(self.run_task(new_task)))
 | 
			
		||||
                    self.handled += 1
 | 
			
		||||
            for model in self.models:
 | 
			
		||||
                if space_remaining > 0:
 | 
			
		||||
                    for instance in await model.atransition_get_with_lock(
 | 
			
		||||
                        min(space_remaining, self.MAX_TASKS_PER_MODEL),
 | 
			
		||||
                        timezone.now() + datetime.timedelta(seconds=self.LOCK_TIMEOUT),
 | 
			
		||||
                    ):
 | 
			
		||||
                        print(
 | 
			
		||||
                            f"Attempting transition on {instance._meta.label_lower}#{instance.pk}"
 | 
			
		||||
                        )
 | 
			
		||||
                        self.tasks.append(
 | 
			
		||||
                            asyncio.create_task(instance.atransition_attempt())
 | 
			
		||||
                        )
 | 
			
		||||
                        self.handled += 1
 | 
			
		||||
                        space_remaining -= 1
 | 
			
		||||
            # Prevent busylooping
 | 
			
		||||
            await asyncio.sleep(0.01)
 | 
			
		||||
            await asyncio.sleep(0.1)
 | 
			
		||||
        # Then wait for tasks to finish
 | 
			
		||||
        print("Waiting for tasks to complete")
 | 
			
		||||
        while (time.monotonic() - start_time) < self.TOTAL_TIMEOUT:
 | 
			
		||||
            self.remove_completed_tasks()
 | 
			
		||||
            if not self.tasks:
 | 
			
		||||
                break
 | 
			
		||||
            # Prevent busylooping
 | 
			
		||||
            await asyncio.sleep(1)
 | 
			
		||||
        print("Complete")
 | 
			
		||||
        return self.handled
 | 
			
		||||
 | 
			
		||||
    async def run_task(self, task: StatorTask):
 | 
			
		||||
        # Resolve the model instance
 | 
			
		||||
        model_instance = await task.aget_model_instance()
 | 
			
		||||
        await model_instance.attempt_transition()
 | 
			
		||||
        # Remove ourselves from the database as complete
 | 
			
		||||
        await task.adelete()
 | 
			
		||||
 | 
			
		||||
    def remove_completed_tasks(self):
 | 
			
		||||
        self.tasks = [t for t in self.tasks if not t.done()]
 | 
			
		||||
 | 
			
		||||
@ -51,14 +51,14 @@ def test_bad_declarations():
 | 
			
		||||
    # More than one initial state
 | 
			
		||||
    with pytest.raises(ValueError):
 | 
			
		||||
 | 
			
		||||
        class TestGraph(StateGraph):
 | 
			
		||||
        class TestGraph2(StateGraph):
 | 
			
		||||
            initial = State()
 | 
			
		||||
            initial2 = State()
 | 
			
		||||
 | 
			
		||||
    # No initial states
 | 
			
		||||
    with pytest.raises(ValueError):
 | 
			
		||||
 | 
			
		||||
        class TestGraph(StateGraph):
 | 
			
		||||
        class TestGraph3(StateGraph):
 | 
			
		||||
            loop = State()
 | 
			
		||||
            loop2 = State()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,23 @@
 | 
			
		||||
# Generated by Django 4.1.3 on 2022-11-10 03:24
 | 
			
		||||
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ("users", "0004_remove_follow_state_locked_and_more"),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name="follow",
 | 
			
		||||
            name="state_locked_until",
 | 
			
		||||
            field=models.DateTimeField(blank=True, null=True),
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name="follow",
 | 
			
		||||
            name="state_ready",
 | 
			
		||||
            field=models.BooleanField(default=False),
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
from .block import Block  # noqa
 | 
			
		||||
from .domain import Domain  # noqa
 | 
			
		||||
from .follow import Follow  # noqa
 | 
			
		||||
from .identity import Identity  # noqa
 | 
			
		||||
from .follow import Follow, FollowStates  # noqa
 | 
			
		||||
from .identity import Identity, IdentityStates  # noqa
 | 
			
		||||
from .user import User  # noqa
 | 
			
		||||
from .user_event import UserEvent  # noqa
 | 
			
		||||
 | 
			
		||||
@ -55,7 +55,7 @@ class Domain(models.Model):
 | 
			
		||||
            return cls.objects.create(domain=domain, local=False)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_local_domain(cls, domain: str) -> Optional["Domain"]:
 | 
			
		||||
    def get_domain(cls, domain: str) -> Optional["Domain"]:
 | 
			
		||||
        try:
 | 
			
		||||
            return cls.objects.get(
 | 
			
		||||
                models.Q(domain=domain) | models.Q(service_domain=domain)
 | 
			
		||||
 | 
			
		||||
@ -6,13 +6,13 @@ from stator.models import State, StateField, StateGraph, StatorModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FollowStates(StateGraph):
 | 
			
		||||
    pending = State(try_interval=3600)
 | 
			
		||||
    pending = State(try_interval=30)
 | 
			
		||||
    requested = State()
 | 
			
		||||
    accepted = State()
 | 
			
		||||
 | 
			
		||||
    @pending.add_transition(requested)
 | 
			
		||||
    async def try_request(cls, instance):
 | 
			
		||||
        print("Would have tried to follow")
 | 
			
		||||
    async def try_request(instance: "Follow"):  # type:ignore
 | 
			
		||||
        print("Would have tried to follow on", instance)
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    requested.add_manual_transition(accepted)
 | 
			
		||||
@ -73,11 +73,3 @@ class Follow(StatorModel):
 | 
			
		||||
                follow.state = FollowStates.accepted
 | 
			
		||||
            follow.save()
 | 
			
		||||
        return follow
 | 
			
		||||
 | 
			
		||||
    def undo(self):
 | 
			
		||||
        """
 | 
			
		||||
        Undoes this follow
 | 
			
		||||
        """
 | 
			
		||||
        if not self.target.local:
 | 
			
		||||
            Task.submit("follow_undo", str(self.pk))
 | 
			
		||||
        self.delete()
 | 
			
		||||
 | 
			
		||||
@ -14,9 +14,21 @@ from django.utils import timezone
 | 
			
		||||
from OpenSSL import crypto
 | 
			
		||||
 | 
			
		||||
from core.ld import canonicalise
 | 
			
		||||
from stator.models import State, StateField, StateGraph, StatorModel
 | 
			
		||||
from users.models.domain import Domain
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IdentityStates(StateGraph):
 | 
			
		||||
    outdated = State(try_interval=3600)
 | 
			
		||||
    updated = State()
 | 
			
		||||
 | 
			
		||||
    @outdated.add_transition(updated)
 | 
			
		||||
    async def fetch_identity(identity: "Identity"):  # type:ignore
 | 
			
		||||
        if identity.local:
 | 
			
		||||
            return True
 | 
			
		||||
        return await identity.fetch_actor()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upload_namer(prefix, instance, filename):
 | 
			
		||||
    """
 | 
			
		||||
    Names uploaded images etc.
 | 
			
		||||
@ -26,7 +38,7 @@ def upload_namer(prefix, instance, filename):
 | 
			
		||||
    return f"{prefix}/{now.year}/{now.month}/{now.day}/{filename}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Identity(models.Model):
 | 
			
		||||
class Identity(StatorModel):
 | 
			
		||||
    """
 | 
			
		||||
    Represents both local and remote Fediverse identities (actors)
 | 
			
		||||
    """
 | 
			
		||||
@ -35,6 +47,8 @@ class Identity(models.Model):
 | 
			
		||||
    # one around as well for making nice URLs etc.
 | 
			
		||||
    actor_uri = models.CharField(max_length=500, unique=True)
 | 
			
		||||
 | 
			
		||||
    state = StateField(IdentityStates)
 | 
			
		||||
 | 
			
		||||
    local = models.BooleanField()
 | 
			
		||||
    users = models.ManyToManyField("users.User", related_name="identities")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,7 @@ from django.http import Http404
 | 
			
		||||
from users.models import Domain, Identity
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def by_handle_or_404(request, handle, local=True, fetch=False):
 | 
			
		||||
def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity:
 | 
			
		||||
    """
 | 
			
		||||
    Retrieves an Identity by its long or short handle.
 | 
			
		||||
    Domain-sensitive, so it will understand short handles on alternate domains.
 | 
			
		||||
@ -12,14 +12,17 @@ def by_handle_or_404(request, handle, local=True, fetch=False):
 | 
			
		||||
        if "HTTP_HOST" not in request.META:
 | 
			
		||||
            raise Http404("No hostname available")
 | 
			
		||||
        username = handle
 | 
			
		||||
        domain_instance = Domain.get_local_domain(request.META["HTTP_HOST"])
 | 
			
		||||
        domain_instance = Domain.get_domain(request.META["HTTP_HOST"])
 | 
			
		||||
        if domain_instance is None:
 | 
			
		||||
            raise Http404("No matching domains found")
 | 
			
		||||
        domain = domain_instance.domain
 | 
			
		||||
    else:
 | 
			
		||||
        username, domain = handle.split("@", 1)
 | 
			
		||||
        # Resolve the domain to the display domain
 | 
			
		||||
        domain = Domain.get_local_domain(request.META["HTTP_HOST"]).domain
 | 
			
		||||
        domain_instance = Domain.get_domain(domain)
 | 
			
		||||
        if domain_instance is None:
 | 
			
		||||
            raise Http404("No matching domains found")
 | 
			
		||||
        domain = domain_instance.domain
 | 
			
		||||
    identity = Identity.by_username_and_domain(
 | 
			
		||||
        username,
 | 
			
		||||
        domain,
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,7 @@ from core.forms import FormHelper
 | 
			
		||||
from core.ld import canonicalise
 | 
			
		||||
from core.signatures import HttpSignature
 | 
			
		||||
from users.decorators import identity_required
 | 
			
		||||
from users.models import Domain, Follow, Identity
 | 
			
		||||
from users.models import Domain, Follow, Identity, IdentityStates
 | 
			
		||||
from users.shortcuts import by_handle_or_404
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -34,7 +34,7 @@ class ViewIdentity(TemplateView):
 | 
			
		||||
        )
 | 
			
		||||
        statuses = identity.statuses.all()[:100]
 | 
			
		||||
        if identity.data_age > settings.IDENTITY_MAX_AGE:
 | 
			
		||||
            Task.submit("identity_fetch", identity.handle)
 | 
			
		||||
            identity.transition_perform(IdentityStates.outdated)
 | 
			
		||||
        return {
 | 
			
		||||
            "identity": identity,
 | 
			
		||||
            "statuses": statuses,
 | 
			
		||||
@ -129,7 +129,7 @@ class CreateIdentity(FormView):
 | 
			
		||||
    def form_valid(self, form):
 | 
			
		||||
        username = form.cleaned_data["username"]
 | 
			
		||||
        domain = form.cleaned_data["domain"]
 | 
			
		||||
        domain_instance = Domain.get_local_domain(domain)
 | 
			
		||||
        domain_instance = Domain.get_domain(domain)
 | 
			
		||||
        new_identity = Identity.objects.create(
 | 
			
		||||
            actor_uri=f"https://{domain_instance.uri_domain}/@{username}@{domain}/actor/",
 | 
			
		||||
            username=username,
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user