Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| RhythmEnv Client. | |
| Provides the WebSocket client for connecting to a RhythmEnv Life Simulator server. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| try: | |
| from .models import RhythmAction, RhythmObservation, RhythmState, StepRecord | |
| except ImportError: | |
| from models import RhythmAction, RhythmObservation, RhythmState, StepRecord | |
class RhythmEnv(EnvClient[RhythmAction, RhythmObservation, RhythmState]):
    """
    Client for the RhythmEnv Life Simulator.

    Converts typed RhythmAction objects into JSON step payloads, and JSON
    server responses back into RhythmObservation / RhythmState objects.

    Example:
        >>> async with RhythmEnv(base_url="https://InosLihka-rhythm-env.hf.space") as client:
        ...     result = await client.reset()
        ...     result = await client.step(RhythmAction(action_type=ActionType.DEEP_WORK))
    """

    # Fallback level for each meter when the server omits it. This is the
    # single source of truth for both _parse_result and _parse_state, so the
    # two parsers cannot drift apart. Key order also fixes the order in which
    # the ``<meter>_delta`` / ``<meter>_anomaly`` field names are generated.
    _METER_DEFAULTS = {
        "vitality": 0.8,
        "cognition": 0.7,
        "progress": 0.0,
        "serenity": 0.7,
        "connection": 0.5,
    }

    @classmethod
    def _meter_levels(cls, src: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``{meter: level}`` read from *src*, using per-meter defaults."""
        return {
            name: src.get(name, default)
            for name, default in cls._METER_DEFAULTS.items()
        }

    @classmethod
    def _meter_fields(cls, src: Dict[str, Any], suffix: str) -> Dict[str, Any]:
        """Return ``{"<meter>_<suffix>": value}`` read from *src*, defaulting to 0.0."""
        return {
            f"{name}_{suffix}": src.get(f"{name}_{suffix}", 0.0)
            for name in cls._METER_DEFAULTS
        }

    @classmethod
    def _parse_step_record(cls, raw: Dict[str, Any]) -> StepRecord:
        """Rebuild one StepRecord, including per-meter deltas and anomalies."""
        return StepRecord(
            step=raw.get("step", 0),
            action=raw.get("action", ""),
            reward=raw.get("reward", 0.0),
            **cls._meter_fields(raw, "delta"),
            **cls._meter_fields(raw, "anomaly"),
        )

    def _step_payload(self, action: RhythmAction) -> Dict[str, Any]:
        """Serialize RhythmAction to JSON payload.

        Only the action type is sent, as its enum ``.value``.
        """
        return {"action_type": action.action_type.value}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[RhythmObservation]:
        """Parse server response into StepResult[RhythmObservation].

        Surfaces ALL observation fields the server returns, including the
        per-meter deltas, anomalies (in step_history), last_action, and the
        full step history. Without these, an external agent connecting to the
        server can't see the meta-RL signals it needs to infer the profile.

        Args:
            payload: Decoded JSON response. ``reward`` and ``done`` are read
                from the top level; everything else from ``payload["observation"]``.

        Returns:
            A StepResult wrapping the reconstructed RhythmObservation.
        """
        # ``or {}`` / ``or []`` treat an explicit JSON null the same as a
        # missing key; a bare ``.get(key, default)`` would let None through.
        obs_data = payload.get("observation") or {}
        step_history = [
            self._parse_step_record(h)
            for h in obs_data.get("step_history") or []
        ]

        # Read once; both the observation and the StepResult carry them.
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        observation = RhythmObservation(
            timestep=obs_data.get("timestep", 0),
            day=obs_data.get("day", 0),
            slot=obs_data.get("slot", 0),
            **self._meter_levels(obs_data),
            active_event=obs_data.get("active_event"),
            # Default used only when the server omits the field.
            remaining_steps=obs_data.get("remaining_steps", 28),
            reward_breakdown=obs_data.get("reward_breakdown") or {},
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
            # Per-meter deltas from THIS step.
            **self._meter_fields(obs_data, "delta"),
            last_action=obs_data.get("last_action"),
            # Rolling history with anomalies (the meta-RL signal).
            step_history=step_history,
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> RhythmState:
        """Parse server response into RhythmState.

        Missing fields fall back to the same meter defaults as _parse_result.
        """
        return RhythmState(
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            timestep=payload.get("timestep", 0),
            day=payload.get("day", 0),
            slot=payload.get("slot", 0),
            profile_name=payload.get("profile_name", ""),
            **self._meter_levels(payload),
            active_event=payload.get("active_event"),
        )