# evoneuralIn3D/scripts/skybox_generator.py
"""
Skybox generator: text → 2:1 equirectangular image (Stable Diffusion, local).
Uses FP16 to reduce VRAM. Output 1024x512 or 2048x1024.
"""
import os
import time
from pathlib import Path
import torch
# Default: v1.5 works without license acceptance. Use SD_MODEL_ID to prefer SD 2.1.
DEFAULT_MODEL_ID = "runwayml/stable-diffusion-v1-5"
FALLBACK_MODEL_ID = "runwayml/stable-diffusion-v1-5" # Same; alternate if primary fails
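
# Configuration via environment variables (summary of the code below):
#   SD_MODEL_PATH - path to a local diffusers folder; used as-is if the directory exists.
#   SD_MODEL_ID   - Hugging Face model id to download when no local weights are found.
#   HF_TOKEN      - access token for downloads; if unset, cached `huggingface-cli login`
#                   credentials are used.
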
def get_device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"

def _is_complete_sd_dir(path: Path) -> bool:
    """True if path looks like a complete Stable Diffusion pipeline (has unet weights)."""
    if not path.is_dir():
        return False
    unet = path / "unet"
    if not unet.is_dir():
        return False
    return any(
        (unet / f).exists()
        for f in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin")
    )

def _default_local_weights_dir() -> str | None:
    """First complete SD folder under weights/ (sd-v1-5 or stable-diffusion-2-1-base)."""
    try:
        root = Path(__file__).resolve().parent.parent
        for name in ("sd-v1-5", "stable-diffusion-2-1-base"):
            local = root / "weights" / name
            if _is_complete_sd_dir(local):
                return str(local)
        return None
    except Exception:
        return None

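# For reference, the local layout these helpers look for (only the unet weights file
# is actually checked; the other pipeline subfolders are assumed to be present):
#   weights/sd-v1-5/unet/diffusion_pytorch_model.safetensors   (or .bin)
#   weights/stable-diffusion-2-1-base/unet/diffusion_pytorch_model.safetensors
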
def _resolve_model_path_and_token():
    """Use local path if set or default weights/ folder exists, else Hub id. Token from HF_TOKEN or huggingface-cli login."""
    local = os.environ.get("SD_MODEL_PATH", "").strip()
    if local and os.path.isdir(local):
        return local, None
    default_local = _default_local_weights_dir()
    if default_local:
        return default_local, None
    model_id = os.environ.get("SD_MODEL_ID", DEFAULT_MODEL_ID)
    token = os.environ.get("HF_TOKEN") or True  # True = use cached login
    return model_id, token

def generate_skybox(
    prompt: str,
    output_dir: str = "outputs",
    width: int = 1024,
    height: int = 512,
    seed: int | None = None,
    model_id: str | None = None,
) -> tuple[str, float, float]:
    """
    Generate a 2:1 equirectangular skybox image from a text prompt.
    Returns (path_to_image, inference_time_sec, peak_vram_mb).
    """
    from diffusers import StableDiffusionPipeline

    device = get_device()
    dtype = torch.float16 if device == "cuda" else torch.float32
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    pretrained, token = _resolve_model_path_and_token()
    load_id = model_id or pretrained
    local_only = os.path.isdir(load_id)

    pipe = None
    last_error = None

    def _load(pid: str, local: bool) -> bool:
        nonlocal pipe, last_error
        try:
            pipe = StableDiffusionPipeline.from_pretrained(
                pid,
                torch_dtype=dtype,
                safety_checker=None,
                token=None if local else (token or True),
                local_files_only=local,
            )
            return True
        except Exception as err:
            last_error = err
            return False

    # Try the resolved model first; if it was a Hub id, retry once with the fallback id.
    if not _load(load_id, local_only) and not local_only:
        _load(FALLBACK_MODEL_ID, False)

    if pipe is None:
        raise RuntimeError(
            "Could not load Stable Diffusion. Internet access is needed to download the model on the first run.\n"
            " - Set HF_TOKEN=your_token if behind a firewall (huggingface.co/settings/tokens)\n"
            " - Or download once: huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir ./weights/sd-v1-5"
        ) from last_error

    pipe = pipe.to(device)

    # Optional: enable xformers for lower VRAM (uncomment if installed)
    # if device == "cuda":
    #     pipe.enable_xformers_memory_efficient_attention()
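    # Alternative VRAM saver (suggestion, not part of the original script): attention
    # slicing is built into diffusers and needs no extra install, at some speed cost.
    # if device == "cuda":
    #     pipe.enable_attention_slicing()
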
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    t0 = time.perf_counter()
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=50,
        generator=generator,
    ).images[0]
    if device == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    inference_time = t1 - t0
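
    # Note (added for clarity): max_memory_allocated() reports the peak memory PyTorch
    # allocated for tensors on the current device, not the total GPU memory in use.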
    peak_vram_mb = (
        torch.cuda.max_memory_allocated() / 1024 / 1024
        if device == "cuda"
        else 0.0
    )

    # Save with safe filename
    safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in prompt)[:60]
    out_path = os.path.join(output_dir, f"skybox_{safe_name.strip()}.png")
    image.save(out_path)
    return out_path, inference_time, peak_vram_mb
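

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The prompt and seed
    # below are illustrative assumptions; width/height match the function defaults.
    path, seconds, vram_mb = generate_skybox(
        "a mountain valley at sunset, clear sky, photorealistic",
        output_dir="outputs",
        width=1024,
        height=512,
        seed=42,
    )
    print(f"Saved {path} in {seconds:.1f}s (peak VRAM {vram_mb:.0f} MB)")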