from __future__ import annotations

import difflib
import functools
import logging
import re
from typing import Dict, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)

COCO_CLASSES: Tuple[str, ...] = (
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "dining table",
    "toilet",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
)


def coco_class_catalog() -> str:
    """Return the COCO classes in a comma-separated catalog for prompts."""
    return ", ".join(COCO_CLASSES)


def _normalize(label: str) -> str:
    return re.sub(r"[^a-z0-9]+", " ", label.lower()).strip()


_CANONICAL_LOOKUP: Dict[str, str] = {_normalize(name): name for name in COCO_CLASSES}

_COCO_SYNONYMS: Dict[str, str] = {
    "people": "person",
    "man": "person",
    "woman": "person",
    "men": "person",
    "women": "person",
    "pedestrian": "person",
    "soldier": "person",
    "infantry": "person",
    "civilian": "person",
    "motorbike": "motorcycle",
    "motor bike": "motorcycle",
    "bike": "bicycle",
    "aircraft": "airplane",
    "plane": "airplane",
    "jet": "airplane",
    "aeroplane": "airplane",
    "drone": "airplane",
    "uav": "airplane",
    "helicopter": "airplane",
    "pickup": "truck",
    "pickup truck": "truck",
    "semi": "truck",
    "lorry": "truck",
    "tractor trailer": "truck",
    "vehicle": "car",
    "sedan": "car",
    "suv": "car",
    "van": "car",
    "vessel": "boat",
    "ship": "boat",
    "warship": "boat",
    "speedboat": "boat",
    "cargo ship": "boat",
    "fishing boat": "boat",
    "yacht": "boat",
    "kayak": "boat",
    "canoe": "boat",
    "watercraft": "boat",
    "coach": "bus",
    "television": "tv",
    "tv monitor": "tv",
    "mobile phone": "cell phone",
    "smartphone": "cell phone",
    "cellphone": "cell phone",
    "dinner table": "dining table",
    "sofa": "couch",
    "cooker": "oven",
}

_ALIAS_LOOKUP: Dict[str, str] = {
    _normalize(alias): canonical for alias, canonical in _COCO_SYNONYMS.items()
}

# ---------------------------------------------------------------------------
# Semantic similarity fallback (lazy-loaded)
# ---------------------------------------------------------------------------
_SEMANTIC_MODEL = None
_COCO_EMBEDDINGS: Optional[np.ndarray] = None
_SEMANTIC_THRESHOLD = 0.65  # Minimum cosine similarity to accept a match


def _get_semantic_model():
    """Lazy-load a lightweight sentence-transformer for semantic matching."""
    global _SEMANTIC_MODEL, _COCO_EMBEDDINGS
    if _SEMANTIC_MODEL is not None:
        return _SEMANTIC_MODEL, _COCO_EMBEDDINGS
    try:
        from sentence_transformers import SentenceTransformer

        _SEMANTIC_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
        # Prefix with "a photo of a" to anchor embeddings in visual/object space
        coco_phrases = [f"a photo of a {cls}" for cls in COCO_CLASSES]
        _COCO_EMBEDDINGS = _SEMANTIC_MODEL.encode(
            coco_phrases, normalize_embeddings=True
        )
        logger.info("Loaded semantic similarity model for COCO class mapping")
    except Exception:
        logger.warning(
            "sentence-transformers unavailable; semantic COCO mapping disabled",
            exc_info=True,
        )
        _SEMANTIC_MODEL = False  # Sentinel: tried and failed
        _COCO_EMBEDDINGS = None
    return _SEMANTIC_MODEL, _COCO_EMBEDDINGS


def _semantic_coco_match(value: str) -> Optional[str]:
    """Find the closest COCO class by embedding cosine similarity.

    Returns the COCO class name if similarity >= threshold, else None.
    """
    model, coco_embs = _get_semantic_model()
    if model is False or coco_embs is None:
        return None
    query_emb = model.encode(
        [f"a photo of a {value}"], normalize_embeddings=True
    )
    similarities = query_emb @ coco_embs.T  # (1, 80)
    best_idx = int(np.argmax(similarities))
    best_score = float(similarities[0, best_idx])
    if best_score >= _SEMANTIC_THRESHOLD:
        matched = COCO_CLASSES[best_idx]
        logger.info(
            "Semantic COCO match: '%s' -> '%s' (score=%.3f)",
            value, matched, best_score,
        )
        return matched
    logger.debug(
        "Semantic COCO match failed: '%s' best='%s' (score=%.3f < %.2f)",
        value, COCO_CLASSES[best_idx], best_score, _SEMANTIC_THRESHOLD,
    )
    return None


def canonicalize_coco_name(value: str | None) -> str | None:
    """Map an arbitrary string to the closest COCO class name if possible.

    Matching cascade:
      1. Exact normalized match
      2. Synonym lookup
      3. Substring match (alias then canonical)
      4. Token-level match
      5. Fuzzy string match (difflib)
      6. Semantic embedding similarity (sentence-transformers)
    """
    if not value:
        return None
    normalized = _normalize(value)
    if not normalized:
        return None
    if normalized in _CANONICAL_LOOKUP:
        return _CANONICAL_LOOKUP[normalized]
    if normalized in _ALIAS_LOOKUP:
        return _ALIAS_LOOKUP[normalized]
    for alias_norm, canonical in _ALIAS_LOOKUP.items():
        if alias_norm and alias_norm in normalized:
            return canonical
    for canonical_norm, canonical in _CANONICAL_LOOKUP.items():
        if canonical_norm and canonical_norm in normalized:
            return canonical
    tokens = normalized.split()
    for token in tokens:
        if token in _CANONICAL_LOOKUP:
            return _CANONICAL_LOOKUP[token]
        if token in _ALIAS_LOOKUP:
            return _ALIAS_LOOKUP[token]
    close = difflib.get_close_matches(
        normalized, list(_CANONICAL_LOOKUP.keys()), n=1, cutoff=0.82
    )
    if close:
        return _CANONICAL_LOOKUP[close[0]]
    # Last resort: semantic embedding similarity
    return _semantic_coco_match(value)
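

# ---------------------------------------------------------------------------
# Usage sketch (illustrative): a quick demo of the matching cascade in
# canonicalize_coco_name. The expected results in the comments assume the
# synonym table and thresholds defined above; the last example relies on the
# optional sentence-transformers fallback and may return None when that
# package is unavailable or the similarity stays below _SEMANTIC_THRESHOLD.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_labels = [
        "Person",      # exact normalized match -> "person"
        "motorbike",   # synonym lookup -> "motorcycle"
        "red pickup",  # alias substring match -> "truck"
        "giraf",       # difflib fuzzy match -> "giraffe" (ratio ~0.83 >= 0.82)
        "puppy",       # semantic fallback; may map to "dog" or return None
    ]
    for label in demo_labels:
        print(f"{label!r} -> {canonicalize_coco_name(label)!r}")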