Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not mark unchanged computed var deps as dirty #4488

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
22 changes: 19 additions & 3 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# A special event handler for setting base vars.
setvar: ClassVar[EventHandler]

# Track if computed vars have changed since last serialization
_changed_computed_vars: Set[str] = set()

# Track which computed vars have already been computed
_ready_computed_vars: Set[str] = set()

def __init__(
self,
parent_state: BaseState | None = None,
Expand Down Expand Up @@ -1850,11 +1856,12 @@ def _mark_dirty_computed_vars(self) -> None:
while dirty_vars:
calc_vars, dirty_vars = dirty_vars, set()
for cvar in self._dirty_computed_vars(from_vars=calc_vars):
self.dirty_vars.add(cvar)
dirty_vars.add(cvar)
actual_var = self.computed_vars.get(cvar)
if actual_var is not None:
assert actual_var is not None
if actual_var.has_changed(instance=self):
actual_var.mark_dirty(instance=self)
self.dirty_vars.add(cvar)
dirty_vars.add(cvar)

def _expired_computed_vars(self) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the expiration time.
Expand Down Expand Up @@ -2134,6 +2141,10 @@ def __getstate__(self):
state["__dict__"].pop("parent_state", None)
state["__dict__"].pop("substates", None)
state["__dict__"].pop("_was_touched", None)
state["__dict__"].pop("_changed_computed_vars", None)
state["__dict__"].pop("_ready_computed_vars", None)
state["__fields_set__"].discard("_changed_computed_vars")
state["__fields_set__"].discard("_ready_computed_vars")
# Remove all inherited vars.
for inherited_var_name in self.inherited_vars:
state["__dict__"].pop(inherited_var_name, None)
Expand All @@ -2150,6 +2161,9 @@ def __setstate__(self, state: dict[str, Any]):
state["__dict__"]["parent_state"] = None
state["__dict__"]["substates"] = {}
super().__setstate__(state)
self._was_touched = False
self._changed_computed_vars = set()
self._ready_computed_vars = set()

def _check_state_size(
self,
Expand Down Expand Up @@ -3131,6 +3145,8 @@ async def get_state(
root_state = self.states.get(client_token)
if root_state is not None:
# Retrieved state from memory.
root_state._changed_computed_vars = set()
root_state._ready_computed_vars = set()
return root_state

# Deserialize root state from disk.
Expand Down
2 changes: 2 additions & 0 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def override(func: Callable) -> Callable:
"_abc_impl",
"_backend_vars",
"_was_touched",
"_changed_computed_vars",
"_ready_computed_vars",
}

if sys.version_info >= (3, 11):
Expand Down
76 changes: 61 additions & 15 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2022,18 +2022,7 @@ def __get__(self, instance: BaseState | None, owner):
existing_var=self,
)

if not self._cache:
value = self.fget(instance)
else:
# handle caching
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, self.fget(instance))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
value = getattr(instance, self._cache_attr)
value = self.get_value(instance)

if not _isinstance(value, self._var_type):
console.deprecate(
Expand Down Expand Up @@ -2158,14 +2147,71 @@ def _deps(
self_is_top_of_stack = False
return d

def mark_dirty(self, instance) -> None:
def mark_dirty(self, instance: BaseState) -> None:
"""Mark this ComputedVar as dirty.

Args:
instance: the state instance that needs to recompute the value.
"""
with contextlib.suppress(AttributeError):
delattr(instance, self._cache_attr)
instance._ready_computed_vars.discard(self._js_expr)

def already_computed(self, instance: BaseState) -> bool:
"""Check if the ComputedVar has already been computed.

Args:
instance: the state instance that needs to recompute the value.

Returns:
True if the ComputedVar has already been computed, False otherwise.
"""
if self.needs_update(instance):
return False
return self._js_expr in instance._ready_computed_vars

def get_value(self, instance: BaseState) -> RETURN_TYPE:
"""Get the value of the ComputedVar.

Args:
instance: the state instance that needs to recompute the value.

Returns:
The value of the ComputedVar.
"""
if not self._cache:
instance._was_touched = True
new = self.fget(instance)
return new

has_cache = hasattr(instance, self._cache_attr)

if self.already_computed(instance) and has_cache:
return getattr(instance, self._cache_attr)

cache_value = getattr(instance, self._cache_attr, None)
instance._ready_computed_vars.add(self._js_expr)
setattr(instance, self._last_updated_attr, datetime.datetime.now())
new_value = self.fget(instance)
if cache_value != new_value:
instance._changed_computed_vars.add(self._js_expr)
instance._was_touched = True
setattr(instance, self._cache_attr, new_value)
return new_value

def has_changed(self, instance: BaseState) -> bool:
"""Check if the ComputedVar value has changed.

Args:
instance: the state instance that needs to recompute the value.

Returns:
True if the value has changed, False otherwise.
"""
if not self._cache:
return True
if self._js_expr in instance._changed_computed_vars:
return True
# TODO: prime the cache if it's not already? creates side effects and breaks order of computed var execution
return self._js_expr in instance._changed_computed_vars

def _determine_var_type(self) -> Type:
"""Get the type of the var.
Expand Down
27 changes: 27 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3563,6 +3563,33 @@ class DillState(BaseState):
_ = state3._serialize()


def test_pickle():
class PickleState(BaseState):
pass

state = PickleState(_reflex_internal_init=True) # type: ignore

# test computed var cache is persisted
setattr(state, "__cvcached", 1)
state = PickleState._deserialize(state._serialize())
assert getattr(state, "__cvcached", None) == 1

# test ready computed vars set is not persisted
state._ready_computed_vars = {"foo"}
state = PickleState._deserialize(state._serialize())
assert not state._ready_computed_vars

# test that changed computed vars set is not persisted
state._changed_computed_vars = {"foo"}
state = PickleState._deserialize(state._serialize())
assert not state._changed_computed_vars

# test was_touched is not persisted
state._was_touched = True
state = PickleState._deserialize(state._serialize())
assert not state._was_touched


def test_typed_state() -> None:
class TypedState(rx.State):
field: rx.Field[str] = rx.field("")
Expand Down
Loading