| """ |
| RunPod Serverless Handler - Wrapper for AI-Toolkit |
| Does NOT modify ai-toolkit code, only wraps it |
| |
| Supports RunPod model caching via HuggingFace integration. |
| """ |
|
|
| import os |
| import sys |
| import subprocess |
| import traceback |
| import logging |
| import uuid |
| from pathlib import Path |
|
|
| |
| |
| |
|
|
| |
import importlib.util  # used only to probe for the optional hf_transfer package

# Hugging Face cache locations on the RunPod network volume.
RUNPOD_CACHE_BASE = "/runpod-volume/huggingface-cache"
RUNPOD_HF_CACHE = "/runpod-volume/huggingface-cache/hub"

# True when the RunPod volume is mounted; all HF caches are then redirected onto it.
IS_RUNPOD_CACHE = os.path.exists("/runpod-volume")

if IS_RUNPOD_CACHE:
    # Point every Hugging Face library at the volume so downloads land there.
    os.environ["HF_HOME"] = RUNPOD_CACHE_BASE
    os.environ["HUGGINGFACE_HUB_CACHE"] = RUNPOD_HF_CACHE
    os.environ["TRANSFORMERS_CACHE"] = RUNPOD_HF_CACHE
    os.environ["HF_DATASETS_CACHE"] = f"{RUNPOD_CACHE_BASE}/datasets"

# huggingface_hub refuses to download when HF_HUB_ENABLE_HF_TRANSFER=1 but the
# optional `hf_transfer` package is not installed, so only opt in when it is.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
else:
    os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
os.environ["DISABLE_TELEMETRY"] = "YES"

# Mirror HF_TOKEN into the legacy variable name as well.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN

# ai-toolkit checkout expected next to this file.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
AI_TOOLKIT_DIR = os.path.join(SCRIPT_DIR, "ai-toolkit")
|
|
| import runpod |
| import torch |
| import yaml |
| import gc |
| import shutil |
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Model key of the most recent training run in this worker process
# (None before the first run, and reset to None by the "cleanup" action).
CURRENT_MODEL = None

# Supported model keys -> ai-toolkit example config (in ai-toolkit/config/examples).
# MODEL_CACHE_PATHS and MODEL_HF_REPOS below must define the same keys.
MODEL_PRESETS = {
    "wan21_1b": "train_lora_wan21_1b_24gb.yaml",
    "wan21_14b": "train_lora_wan21_14b_24gb.yaml",
    "wan22_14b": "train_lora_wan22_14b_24gb.yaml",
    "qwen_image": "train_lora_qwen_image_24gb.yaml",
    "qwen_image_edit": "train_lora_qwen_image_edit_32gb.yaml",
    "qwen_image_edit_2509": "train_lora_qwen_image_edit_2509_32gb.yaml",
    "flux_dev": "train_lora_flux_24gb.yaml",
    "flux_schnell": "train_lora_flux_schnell_24gb.yaml",
}

# Hugging Face repo whose snapshot (one subfolder per model) may be present in
# the RunPod hub cache.
CACHE_REPO = "Aloukik21/trainer"

# Model key -> subfolder of CACHE_REPO holding that model's weights.
# Every key must point at its own model: sharing a folder (e.g. the 1.3B Wan
# preset reading the 14B weights, or the Qwen edit presets reading the base
# Qwen-Image weights) trains a different checkpoint than the preset expects.
# NOTE(review): confirm these subfolder names exist in CACHE_REPO; when a folder
# is missing, the model lookup falls back to MODEL_HF_REPOS.
MODEL_CACHE_PATHS = {
    "wan21_1b": "wan21-1b",
    "wan21_14b": "wan21-14b",
    "wan22_14b": "wan22-14b",
    "qwen_image": "qwen-image",
    "qwen_image_edit": "qwen-image-edit",
    "qwen_image_edit_2509": "qwen-image-edit-2509",
    "flux_dev": "flux-dev",
    "flux_schnell": "flux-schnell",
}

# Model key -> upstream Hugging Face repo, used when the model is not cached.
MODEL_HF_REPOS = {
    "wan21_1b": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "wan21_14b": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "wan22_14b": "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16",
    "qwen_image": "Qwen/Qwen-Image",
    "qwen_image_edit": "Qwen/Qwen-Image-Edit",
    # The 2509 edit model is published as its own repo, not as Qwen-Image-Edit.
    "qwen_image_edit_2509": "Qwen/Qwen-Image-Edit-2509",
    "flux_dev": "black-forest-labs/FLUX.1-dev",
    "flux_schnell": "black-forest-labs/FLUX.1-schnell",
}

# Subfolder of CACHE_REPO holding accuracy recovery adapter files.
ARA_CACHE_PATH = "accuracy_recovery_adapters"
|
|
|
|
| |
| |
| |
|
|
def cleanup_gpu_memory():
    """Return cached CUDA memory to the driver and run the garbage collector.

    When CUDA is present, the allocator cache is emptied and the device is
    synchronized first. ``gc.collect()`` then runs, and the cache is emptied a
    second time so blocks released by the collector are returned too. The GPU
    state from ``get_gpu_info()`` is logged at the end.
    """
    logger.info("Cleaning up GPU memory...")
    cuda_present = torch.cuda.is_available()

    if cuda_present:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    gc.collect()

    if cuda_present:
        torch.cuda.empty_cache()

    logger.info(f"GPU memory after cleanup: {get_gpu_info()}")
|
|
|
|
def cleanup_temp_files():
    """Best-effort removal of temporary training artifacts.

    Removes:
    - ``*.yaml`` files directly inside ``<AI_TOOLKIT_DIR>/config`` whose names
      start with ``lora_``, ``test_`` or ``my_`` (generated job configs);
    - entries whose names start with ``_latent_cache``, ``_t_e_cache`` or
      ``.aitk`` inside ``/workspace/dataset`` and ``/workspace/output``.

    Failures are logged as warnings and never raised. A missing directory is
    skipped, so one problem does not stop the rest of the cleanup.
    """
    logger.info("Cleaning up temporary files...")

    config_dir = os.path.join(AI_TOOLKIT_DIR, "config")
    # os.listdir raises on a missing folder, which used to abort the whole
    # cleanup (including the workspace part below).
    if os.path.isdir(config_dir):
        for f in os.listdir(config_dir):
            if f.endswith('.yaml') and f.startswith(('lora_', 'test_', 'my_')):
                try:
                    os.remove(os.path.join(config_dir, f))
                    logger.info(f"Removed temp config: {f}")
                except OSError as e:
                    logger.warning(f"Failed to remove {f}: {e}")
    else:
        logger.warning(f"Config directory not found, skipping: {config_dir}")

    workspace_dirs = ["/workspace/dataset", "/workspace/output"]
    for ws_dir in workspace_dirs:
        # isdir rather than exists: os.listdir also raises on a regular file.
        if not os.path.isdir(ws_dir):
            continue
        for item in os.listdir(ws_dir):
            if not item.startswith(('_latent_cache', '_t_e_cache', '.aitk')):
                continue
            item_path = os.path.join(ws_dir, item)
            try:
                if os.path.isdir(item_path):
                    shutil.rmtree(item_path)
                else:
                    os.remove(item_path)
                logger.info(f"Removed cache: {item_path}")
            except OSError as e:
                logger.warning(f"Failed to remove {item_path}: {e}")
|
|
|
|
def cleanup_before_training(new_model: str):
    """Prepare the worker for a training run of ``new_model``.

    - Switching from a different model: free GPU memory and delete temporary
      configs and caches.
    - Same model as the previous run: only free GPU memory.
    - First run in this process: nothing to clean.

    Records ``new_model`` in the module-level ``CURRENT_MODEL`` and logs the GPU
    available for training.

    Args:
        new_model: Model key about to be trained.
    """
    global CURRENT_MODEL

    if CURRENT_MODEL and CURRENT_MODEL != new_model:
        logger.info(f"Switching from {CURRENT_MODEL} to {new_model} - performing full cleanup")
        cleanup_gpu_memory()
        cleanup_temp_files()
    elif CURRENT_MODEL == new_model:
        logger.info(f"Same model {new_model} - light cleanup only")
        cleanup_gpu_memory()
    else:
        logger.info(f"First training run with {new_model}")

    CURRENT_MODEL = new_model

    gpu_info = get_gpu_info()
    if gpu_info.get("available"):
        logger.info(f"Ready for training. GPU: {gpu_info['name']}, Free: {gpu_info['free_gb']}GB")
    else:
        # Without CUDA, get_gpu_info() returns only {"available": False};
        # indexing 'name' would raise KeyError and abort the job here.
        logger.warning("Ready for training, but no CUDA GPU is available")
|
|
|
|
| |
| |
| |
|
|
def get_gpu_info():
    """Describe CUDA device 0.

    Returns:
        ``{"available": False}`` when CUDA is unavailable. Otherwise a dict with
        ``available`` (True), the device ``name``, and ``total_gb`` / ``free_gb``
        memory in GiB (as reported by ``torch.cuda.mem_get_info``), rounded to
        two decimals.
    """
    if torch.cuda.is_available():
        device = 0
        properties = torch.cuda.get_device_properties(device)
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        bytes_per_gib = 1024 ** 3
        return {
            "available": True,
            "name": properties.name,
            "total_gb": round(total_bytes / bytes_per_gib, 2),
            "free_gb": round(free_bytes / bytes_per_gib, 2),
        }
    return {"available": False}
|
|
|
|
def get_environment_info():
    """Summarise the runtime environment for debugging.

    Used by the ``status`` action and the startup log. Reports:
    - whether the RunPod volume was detected;
    - the effective ``HF_HOME``;
    - whether an HF token is configured (never the token itself);
    - GPU details;
    - the ai-toolkit checkout path;
    - whether the hub cache folder exists (always False off RunPod).
    """
    hub_cache_present = IS_RUNPOD_CACHE and os.path.exists(RUNPOD_HF_CACHE)
    return dict(
        is_runpod_cache=IS_RUNPOD_CACHE,
        hf_home=os.environ.get("HF_HOME", "not set"),
        hf_token_set=bool(HF_TOKEN),
        gpu=get_gpu_info(),
        ai_toolkit_dir=AI_TOOLKIT_DIR,
        cache_exists=hub_cache_present,
    )
|
|
|
|
def resolve_snapshot_dir(repo_dir):
    """Pick the snapshot folder to read from a Hugging Face hub-cache repo folder.

    ``repo_dir`` is a ``models--<org>--<name>`` folder of the hub cache. The
    commit recorded in ``refs/main`` is preferred. If that file is missing or
    names an absent snapshot, the most recently modified snapshot is used.
    Taking the first ``iterdir()`` entry instead is order-dependent and can
    return a stale revision once several snapshots are cached.

    Args:
        repo_dir: Path (or str) of the cache repo folder.

    Returns:
        Path of the chosen snapshot directory, or None when there is none.
    """
    repo_dir = Path(repo_dir)
    snapshots_dir = repo_dir / "snapshots"
    if not snapshots_dir.is_dir():
        return None

    try:
        commit = (repo_dir / "refs" / "main").read_text(encoding="utf-8").strip()
    except OSError:
        commit = ""
    if commit:
        pinned = snapshots_dir / commit
        if pinned.is_dir():
            return pinned

    snapshots = [p for p in snapshots_dir.iterdir() if p.is_dir()]
    if not snapshots:
        return None
    return max(snapshots, key=lambda p: p.stat().st_mtime)


def _cache_repo_dir():
    """Hub-cache folder of CACHE_REPO inside the RunPod HF cache."""
    return Path(RUNPOD_HF_CACHE) / f"models--{CACHE_REPO.replace('/', '--')}"


def find_cached_model(model_key: str) -> str:
    """
    Find cached model path on RunPod from the Aloukik21/trainer cache repo.

    Args:
        model_key: Model key (e.g., 'flux_dev', 'wan22_14b')

    Returns:
        Path to the model's subfolder inside the current cache snapshot when it
        exists. Otherwise the original HF repo id ("" for an unknown key).
        Off RunPod the original repo id is always returned.
    """
    if not IS_RUNPOD_CACHE:
        return MODEL_HF_REPOS.get(model_key, "")

    subfolder = MODEL_CACHE_PATHS.get(model_key)
    snapshot = resolve_snapshot_dir(_cache_repo_dir())
    if snapshot is not None and subfolder:
        cached_path = snapshot / subfolder
        if cached_path.exists():
            logger.info(f"Using cached model: {model_key} -> {cached_path}")
            return str(cached_path)

    original_repo = MODEL_HF_REPOS.get(model_key, "")
    logger.info(f"Model not in cache, using original: {original_repo}")
    return original_repo


def find_cached_ara(adapter_name: str) -> str:
    """
    Find cached accuracy recovery adapter.

    Args:
        adapter_name: Adapter filename (e.g., 'wan22_14b_t2i_torchao_uint4.safetensors')

    Returns:
        Path to the adapter inside the current cache snapshot when it exists,
        otherwise the ``ostris/accuracy_recovery_adapters/<name>`` HF path.
    """
    original = f"ostris/accuracy_recovery_adapters/{adapter_name}"
    if not IS_RUNPOD_CACHE:
        return original

    snapshot = resolve_snapshot_dir(_cache_repo_dir())
    if snapshot is not None:
        cached_path = snapshot / ARA_CACHE_PATH / adapter_name
        if cached_path.exists():
            logger.info(f"Using cached ARA: {adapter_name} -> {cached_path}")
            return str(cached_path)

    return original


def check_model_cache_status(model_key: str) -> dict:
    """Report whether a model's files are cached from the Aloukik21/trainer repo.

    Returns:
        ``{"cached": False, "reason": "unknown model"}`` for unknown keys.
        Otherwise a dict with ``model``, ``cache_repo``, ``subfolder`` and
        ``cached``, plus:
        - ``reason`` when the cache repo folder is absent;
        - ``path`` when a snapshot was found (None if the subfolder is missing).
    """
    if model_key not in MODEL_CACHE_PATHS:
        return {"cached": False, "reason": "unknown model"}

    subfolder = MODEL_CACHE_PATHS[model_key]
    status = {
        "model": model_key,
        "cache_repo": CACHE_REPO,
        "subfolder": subfolder,
    }

    repo_dir = _cache_repo_dir()
    if not (repo_dir / "snapshots").exists():
        status["cached"] = False
        status["reason"] = "cache repo not found"
        return status

    snapshot = resolve_snapshot_dir(repo_dir)
    if snapshot is None:
        status["cached"] = False
        return status

    model_path = snapshot / subfolder
    cached = model_path.exists()
    status["cached"] = cached
    status["path"] = str(model_path) if cached else None
    return status
|
|
|
|
| |
| |
| |
|
|
def load_example_config(model_key):
    """Parse the ai-toolkit example YAML config for ``model_key``.

    The file is read from ``<AI_TOOLKIT_DIR>/config/examples/`` using the name
    listed in ``MODEL_PRESETS``.

    Raises:
        ValueError: if ``model_key`` is not a key of ``MODEL_PRESETS``.
    """
    try:
        preset_file = MODEL_PRESETS[model_key]
    except KeyError:
        raise ValueError(f"Unknown model: {model_key}. Available: {list(MODEL_PRESETS.keys())}") from None

    examples_dir = os.path.join(AI_TOOLKIT_DIR, "config", "examples")
    with open(os.path.join(examples_dir, preset_file), 'r') as handle:
        parsed = yaml.safe_load(handle)
    return parsed
|
|
|
|
def _resolve_job_name(params, model_key):
    """Return the job name from ``params`` (or a random default); reject path-like names."""
    job_name = str(params.get("name") or f"lora_{model_key}_{uuid.uuid4().hex[:6]}")
    # The name becomes a file name under ai-toolkit/config and comes straight
    # from the job request, so refuse anything that could escape that folder.
    if job_name in (".", "..") or os.path.basename(job_name) != job_name:
        raise ValueError(
            f"Invalid job name {job_name!r}: must be a plain name without path separators"
        )
    return job_name


def _apply_training_overrides(process, params):
    """Write the optional training parameters from ``params`` into ``process`` (in place)."""
    dataset = process["datasets"][0]
    dataset["folder_path"] = params.get("dataset_path", "/workspace/dataset")
    process["training_folder"] = params.get("output_path", "/workspace/output")

    if "steps" in params:
        process["train"]["steps"] = params["steps"]
    if "batch_size" in params:
        process["train"]["batch_size"] = params["batch_size"]
    if "learning_rate" in params:
        process["train"]["lr"] = params["learning_rate"]
    if "lora_rank" in params:
        process["network"]["linear"] = params["lora_rank"]
        process["network"]["linear_alpha"] = params.get("lora_alpha", params["lora_rank"])
    elif "lora_alpha" in params:
        # Alpha alone used to be dropped silently; apply it to the preset's rank.
        process["network"]["linear_alpha"] = params["lora_alpha"]
    if "save_every" in params:
        process["save"]["save_every"] = params["save_every"]
    if "sample_every" in params:
        process["sample"]["sample_every"] = params["sample_every"]
    if "resolution" in params:
        dataset["resolution"] = params["resolution"]
    if "num_frames" in params:
        dataset["num_frames"] = params["num_frames"]
    if "sample_prompts" in params:
        process["sample"]["prompts"] = params["sample_prompts"]
    if "trigger_word" in params:
        process["trigger_word"] = params["trigger_word"]


def _run_ai_toolkit(config_path, tail_size=20):
    """Run ai-toolkit's ``run.py`` on ``config_path``, streaming its output to the log.

    GPU memory is cleaned up afterwards, even if streaming fails.

    Returns:
        Tuple of (exit code, list of the last ``tail_size`` output lines).
    """
    from collections import deque

    cmd = [sys.executable, os.path.join(AI_TOOLKIT_DIR, "run.py"), config_path]
    logger.info(f"Command: {' '.join(cmd)}")

    tail = deque(maxlen=tail_size)
    proc = subprocess.Popen(
        cmd,
        cwd=AI_TOOLKIT_DIR,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors="replace",  # undecodable bytes in trainer output must not abort the job
        bufsize=1,
    )
    try:
        for line in proc.stdout:
            line = line.rstrip()
            logger.info(line)
            tail.append(line)
        proc.wait()
    finally:
        # If streaming was interrupted, don't leave an orphaned trainer holding the GPU.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        proc.stdout.close()
        cleanup_gpu_memory()
    return proc.returncode, list(tail)


def run_training(params):
    """Train a LoRA with ai-toolkit from one of the bundled example configs.

    Steps:
    1. Load the example YAML for ``params["model"]`` (default ``"wan22_14b"``).
    2. Patch it with the optional overrides in ``params``.
    3. Point the model at the cached copy when one is found.
    4. Write it to ``<AI_TOOLKIT_DIR>/config/<job_name>.yaml``.
    5. Run ``ai-toolkit/run.py`` on it in a subprocess.

    Args:
        params: Job parameters. Recognised keys:
            - ``model``, ``name``, ``dataset_path``, ``output_path``;
            - ``steps``, ``batch_size``, ``learning_rate``;
            - ``lora_rank``, ``lora_alpha``;
            - ``save_every``, ``sample_every``, ``sample_prompts``;
            - ``resolution``, ``num_frames``, ``trigger_word``.

    Returns:
        dict with ``status``, ``job_name``, ``output_path`` and ``model``.

    Raises:
        ValueError: unknown model, or a job name that is not a plain file name.
        RuntimeError: the training subprocess exited with a non-zero code. The
            message includes the last lines of its output.
    """
    model_key = params.get("model", "wan22_14b")

    # Validate the request (preset exists, job name is safe) before touching
    # worker state, so a bad request neither wipes caches nor changes CURRENT_MODEL.
    config = load_example_config(model_key)
    job_name = _resolve_job_name(params, model_key)

    cleanup_before_training(model_key)

    config["config"]["name"] = job_name
    process = config["config"]["process"][0]
    _apply_training_overrides(process, params)

    if "model" in process:
        cached_path = find_cached_model(model_key)
        if cached_path:
            process["model"]["name_or_path"] = cached_path
            logger.info(f"Model path set to: {cached_path}")

    config_path = os.path.join(AI_TOOLKIT_DIR, "config", f"{job_name}.yaml")
    with open(config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    logger.info(f"Config saved: {config_path}")
    logger.info(f"Starting: {job_name}")

    returncode, tail = _run_ai_toolkit(config_path)
    if returncode != 0:
        message = f"Training failed with code {returncode}"
        if tail:
            message += "\nLast output:\n" + "\n".join(tail)
        raise RuntimeError(message)

    return {
        "status": "success",
        "job_name": job_name,
        "output_path": process["training_folder"],
        "model": model_key,
    }
|
|
|
|
| |
| |
| |
|
|
def handler(job):
    """RunPod handler: dispatch on ``job["input"]["action"]`` (default ``"train"``).

    Actions:
    - ``list_models``: keys of ``MODEL_PRESETS``.
    - ``status``: environment and GPU information.
    - ``check_cache``: cache status for ``input["model"]``, or for every preset.
    - ``cleanup``: free GPU memory, delete temporary files, reset ``CURRENT_MODEL``.
    - ``train``: ``run_training(input["params"])``. ``input["model"]`` overrides
      ``params["model"]``; the default is ``"wan22_14b"``.

    Returns:
        The action's result dict. An unknown action, or any exception (logged
        with its traceback), yields ``{"status": "error", "error": <message>}``.
    """
    global CURRENT_MODEL

    try:
        # Everything, including input parsing and the GPU query, runs inside the
        # try so a malformed request yields an error result instead of a crash.
        job_input = job.get("input") or {}
        action = job_input.get("action", "train")

        logger.info(f"Action: {action}, GPU: {get_gpu_info()}")

        if action == "list_models":
            return {"status": "success", "models": list(MODEL_PRESETS.keys())}

        elif action == "status":
            return {
                "status": "success",
                "environment": get_environment_info(),
            }

        elif action == "check_cache":
            model_key = job_input.get("model")
            if model_key:
                cache_status = check_model_cache_status(model_key)
            else:
                cache_status = {m: check_model_cache_status(m) for m in MODEL_PRESETS}
            return {"status": "success", "cache": cache_status}

        elif action == "cleanup":
            cleanup_gpu_memory()
            cleanup_temp_files()
            CURRENT_MODEL = None
            return {
                "status": "success",
                "message": "Cleanup complete",
                "gpu": get_gpu_info(),
            }

        elif action == "train":
            # Treat an explicit `"params": null` like a missing params object.
            params = job_input.get("params") or {}
            params["model"] = job_input.get("model", params.get("model", "wan22_14b"))
            return run_training(params)

        else:
            return {"status": "error", "error": f"Unknown action: {action}"}

    except Exception as e:
        logger.error(traceback.format_exc())
        return {"status": "error", "error": str(e)}
|
|
|
|
if __name__ == "__main__":
    # Script entry point: log the detected environment once, then hand control
    # to the RunPod serverless loop with `handler` as the job callback.
    logger.info("Starting AI-Toolkit RunPod Handler")
    startup_environment = get_environment_info()
    logger.info(f"Environment: {startup_environment}")
    runpod.serverless.start({"handler": handler})
|
|