|
|
""" |
|
|
Skybox generator: text → 2:1 equirectangular image (Stable Diffusion, local). |
|
|
Uses FP16 on CUDA (FP32 on CPU) to reduce VRAM. Typical output: 1024x512 or 2048x1024.
|
|
""" |
|
|
|
|
|
import os |
|
|
import time |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
# Hub id used when no local weights are found (overridable via SD_MODEL_ID).
DEFAULT_MODEL_ID = "runwayml/stable-diffusion-v1-5"


# Retried once when the primary Hub load fails.
# NOTE(review): currently identical to DEFAULT_MODEL_ID, so the fallback retry
# re-attempts the same model — confirm whether a different id was intended.
FALLBACK_MODEL_ID = "runwayml/stable-diffusion-v1-5"
|
|
|
|
|
|
|
|
def get_device() -> str:
    """Return the torch device string: "cuda" when a GPU is available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
|
|
|
|
def _is_complete_sd_dir(path: Path) -> bool: |
|
|
"""True if path looks like a complete Stable Diffusion pipeline (has unet weights).""" |
|
|
if not path.is_dir(): |
|
|
return False |
|
|
unet = path / "unet" |
|
|
if not unet.is_dir(): |
|
|
return False |
|
|
return any( |
|
|
(unet / f).exists() |
|
|
for f in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin") |
|
|
) |
|
|
|
|
|
|
|
|
def _default_local_weights_dir() -> str | None: |
|
|
"""First complete SD folder under weights/ (sd-v1-5 or stable-diffusion-2-1-base).""" |
|
|
try: |
|
|
root = Path(__file__).resolve().parent.parent |
|
|
for name in ("sd-v1-5", "stable-diffusion-2-1-base"): |
|
|
local = root / "weights" / name |
|
|
if _is_complete_sd_dir(local): |
|
|
return str(local) |
|
|
return None |
|
|
except Exception: |
|
|
return None |
|
|
|
|
|
|
|
|
def _resolve_model_path_and_token(): |
|
|
"""Use local path if set or default weights/ folder exists, else Hub id. Token from HF_TOKEN or huggingface-cli login.""" |
|
|
local = os.environ.get("SD_MODEL_PATH", "").strip() |
|
|
if local and os.path.isdir(local): |
|
|
return local, None |
|
|
default_local = _default_local_weights_dir() |
|
|
if default_local: |
|
|
return default_local, None |
|
|
model_id = os.environ.get("SD_MODEL_ID", DEFAULT_MODEL_ID) |
|
|
token = os.environ.get("HF_TOKEN") or True |
|
|
return model_id, token |
|
|
|
|
|
|
|
|
def generate_skybox(
    prompt: str,
    output_dir: str = "outputs",
    width: int = 1024,
    height: int = 512,
    seed: int | None = None,
    model_id: str | None = None,
    num_inference_steps: int = 50,
) -> tuple[str, float, float]:
    """
    Generate a 2:1 equirectangular skybox image from a text prompt.

    Args:
        prompt: Text description of the scene.
        output_dir: Directory for the PNG; created if missing.
        width, height: Output resolution — pass a 2:1 pair (e.g. 1024x512).
        seed: Optional seed for reproducible generation.
        model_id: Optional override for the model path / Hub id.
        num_inference_steps: Diffusion denoising steps (previously fixed at 50).

    Returns:
        (path_to_image, inference_time_sec, peak_vram_mb); peak_vram_mb is 0.0 on CPU.

    Raises:
        RuntimeError: if no pipeline could be loaded (no local weights and no
            network/token access to the Hub), chained from the last load error.
    """
    # Imported lazily so merely importing this module doesn't require diffusers.
    from diffusers import StableDiffusionPipeline

    device = get_device()
    # FP16 halves VRAM on CUDA; CPU inference needs FP32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    pretrained, token = _resolve_model_path_and_token()
    load_id = model_id or pretrained
    local_only = os.path.isdir(load_id)
    pipe = None
    last_error = None

    def _load(pid: str, local: bool) -> bool:
        # Attempt one pipeline source; record the failure so the final
        # RuntimeError can chain the underlying cause.
        nonlocal pipe, last_error
        try:
            pipe = StableDiffusionPipeline.from_pretrained(
                pid,
                torch_dtype=dtype,
                safety_checker=None,
                token=None if local else (token or True),
                local_files_only=local,
            )
            return True
        except Exception as err:
            last_error = err
            return False

    # Primary attempt; for Hub loads only, retry once against the fallback id.
    # (A failed local load gets no Hub retry — the user explicitly chose local.)
    if not _load(load_id, local_only) and not local_only:
        _load(FALLBACK_MODEL_ID, False)
    if pipe is None:
        raise RuntimeError(
            "Could not load Stable Diffusion. Need internet to download the model (first run).\n"
            " - Set HF_TOKEN=your_token if behind firewall (huggingface.co/settings/tokens)\n"
            " - Or download once: huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir ./weights/sd-v1-5"
        ) from last_error

    pipe = pipe.to(device)

    if device == "cuda":
        # Reset stats so the peak-VRAM figure reflects this generation only,
        # then synchronize so timing starts from an idle GPU.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    t0 = time.perf_counter()
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    if device == "cuda":
        # Kernels launch asynchronously; wait for completion before stopping the clock.
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - t0

    peak_vram_mb = (
        torch.cuda.max_memory_allocated() / 1024 / 1024 if device == "cuda" else 0.0
    )

    # Build a filesystem-safe filename from the prompt; never emit an empty stem
    # (a prompt of only punctuation previously produced "skybox_.png").
    safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in prompt)[:60]
    safe_name = safe_name.strip() or "untitled"
    out_path = os.path.join(output_dir, f"skybox_{safe_name}.png")
    image.save(out_path)

    return out_path, inference_time, peak_vram_mb
|
|
|