import logging
import re
from typing import Any, Dict, List

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the RECCON emotional trigger extraction model using native transformers.

        Args:
            path: Path to the model directory (provided by HuggingFace Inference Endpoints).
        """
        logger.info("Initializing RECCON Trigger Extraction endpoint...")

        # Detect device (CUDA/CPU)
        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            logger.warning("GPU not detected. Running on CPU. Inference will be slower.")

        # In 'pipeline', device is an integer (-1 for CPU, 0+ for GPU)
        self.device_id = 0 if cuda_available else -1

        # Determine model path
        model_path = path if path and path != "." else "."
        logger.info(f"Loading model from {model_path}...")

        try:
            # Load tokenizer and model explicitly to ensure correct loading
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
                model_path,
                output_loading_info=True
            )
            logger.warning("RECCON load info - missing_keys: %s", loading_info.get("missing_keys"))
            logger.warning("RECCON load info - unexpected_keys: %s", loading_info.get("unexpected_keys"))
            logger.warning("RECCON load info - error_msgs: %s", loading_info.get("error_msgs"))
            logger.warning("Loaded model class: %s", model.__class__.__name__)
            logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))

            # Initialize the pipeline.
            # top_k=20 matches the previous 'n_best_size=20' logic.
            self.pipe = pipeline(
                "question-answering",
                model=model,
                tokenizer=tokenizer,
                device=self.device_id,
                top_k=20,
                handle_impossible_answer=False
            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

        # Question template (must match training)
        self.question_template = (
            "Extract the exact short phrase (<= 8 words) from the target "
            "utterance that most strongly signals the emotion {emotion}. "
            "Return only a substring of the target utterance."
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data: Request payload. Expects an "inputs" key holding a dict or a list of
                dicts, each with an "utterance" and an "emotion".

        Returns:
            A list of result dicts with "utterance", "emotion", and "triggers"
            (plus an "error" key for invalid inputs or failed predictions).
        """
""" # Extract inputs inputs = data.pop("inputs", data) # Normalize to list format if isinstance(inputs, dict): inputs = [inputs] if not inputs: return [{"error": "No inputs provided", "triggers": []}] # Validate and format inputs for the pipeline pipeline_inputs = [] valid_indices = [] for i, item in enumerate(inputs): utterance = item.get("utterance", "").strip() emotion = item.get("emotion", "") if not utterance: logger.warning(f"Empty utterance at index {i}") continue # Format as QA task question = self.question_template.format(emotion=emotion) # The pipeline expects a list of dicts with 'question' and 'context' pipeline_inputs.append({ 'question': question, 'context': utterance }) valid_indices.append(i) # Run prediction results = [] if not pipeline_inputs: # All inputs were invalid for item in inputs: results.append({ "utterance": item.get("utterance", ""), "emotion": item.get("emotion", ""), "error": "Missing or empty utterance", "triggers": [] }) return results try: # Run inference (batch_size helps with multiple inputs) predictions = self.pipe(pipeline_inputs, batch_size=8) # If batch_size=1 or single input, pipeline might return a single list/dict # We ensure it's a list of lists (since top_k > 1) if isinstance(predictions, dict): # Single input result predictions = [predictions] # Wrap in list elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict): # This happens if we have multiple inputs but top_k=1 (which is not the case here), # OR if we have a single input and top_k > 1. # If we have multiple inputs and top_k > 1, it returns a list of lists. if len(pipeline_inputs) == 1: predictions = [predictions] # If multiple inputs and list of dicts, that implies top_k=1. # But we set top_k=20. So it should be list of lists. logger.debug(f"Raw predictions: {predictions}") # Post-process results pred_idx = 0 for i, item in enumerate(inputs): utterance = item.get("utterance", "").strip() emotion = item.get("emotion", "") if i not in valid_indices: results.append({ "utterance": utterance, "emotion": emotion, "error": "Missing or empty utterance", "triggers": [] }) else: # Get prediction for this item # Because top_k=20, 'current_preds' is a list of dicts: [{'answer': '...', 'score': ...}, ...] current_preds = predictions[pred_idx] # Ensure it is a list if isinstance(current_preds, dict): current_preds = [current_preds] logger.info( "RECCON raw spans (answer, score): %s", [(p.get("answer"), p.get("score", 0.0), 3) for p in current_preds[:5]] ) def is_good_span(ans: str) -> bool: if not ans: return False a = ans.strip() if len(a) < 3: return False # reject pure punctuation if all(ch in ".,!?;:-—'\"()[]{}" for ch in a): return False # require at least one letter if not any(ch.isalpha() for ch in a): return False return True raw_answers = [p.get("answer", "") for p in current_preds] raw_answers = [a for a in raw_answers if is_good_span(a)] triggers = self._clean_spans(raw_answers, utterance) results.append({ "utterance": utterance, "emotion": emotion, "triggers": triggers }) pred_idx += 1 logger.debug(f"Cleaned results: {results}") return results except Exception as e: logger.error(f"Model prediction failed: {e}") return [{ "utterance": item.get("utterance", ""), "emotion": item.get("emotion", ""), "error": str(e), "triggers": [] } for item in inputs] def _clean_spans(self, spans: List[str], target_text: str) -> List[str]: """ Clean and filter extracted trigger spans. 
        (Logic preserved exactly as provided.)
        """
        target_text = target_text or ""
        target_lower = target_text.lower()

        def _norm(s: str) -> str:
            s = (s or "").strip().lower()
            s = re.sub(r"\s+", " ", s)
            s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
            return s

        def _extract_from_target(target: str, phrase_lower: str) -> str:
            idx = target.lower().find(phrase_lower)
            if idx >= 0:
                return target[idx:idx + len(phrase_lower)]
            return phrase_lower

        STOP = {
            "a", "an", "the", "and", "or", "but", "so", "to", "of", "in",
            "on", "at", "with", "for", "from", "is", "am", "are", "was",
            "were", "be", "been", "being", "i", "you", "he", "she", "it",
            "we", "they", "my", "your", "his", "her", "their", "our",
            "me", "him", "them", "this", "that", "these", "those"
        }

        candidates = []
        for s in spans:
            s = (s or "").strip()
            if not s:
                continue

            s_norm = _norm(s)
            if not s_norm:
                continue

            if target_text and s_norm not in target_lower:
                continue

            tokens = s_norm.split()
            if len(tokens) > 8 or len(s_norm) > 80:
                continue

            if len(tokens) == 1 and (tokens[0] in STOP or len(tokens[0]) <= 2):
                continue

            candidates.append({
                "norm": s_norm,
                "tokens": tokens,
                "tok_len": len(tokens),
                "char_len": len(s_norm)
            })

        # Prioritize short, focused emotional keywords (1-3 words)
        short_candidates = [c for c in candidates if 1 <= c["tok_len"] <= 3]
        if short_candidates:
            candidates = short_candidates

        # Sort by SHORTEST spans first (most focused keywords)
        candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]), reverse=False)

        # Drop spans that duplicate, contain, or are contained in an already-kept span
        kept_norms = []
        for c in list(candidates):
            n = c["norm"]
            if any(n in kn or kn in n for kn in kept_norms):
                continue
            kept_norms.append(n)

        cleaned = [_extract_from_target(target_text, n) for n in kept_norms]

        # Fallback: if nothing survived, return the longest sub-phrase of any raw span
        # that appears verbatim in the target utterance.
        if not cleaned and spans:
            tt_tokens = target_lower.split()
            best = None
            for s in spans:
                words = [w for w in (s or '').lower().strip().split() if w]
                for L in range(min(8, len(words)), 0, -1):
                    for i in range(len(words) - L + 1):
                        phrase = words[i:i + L]
                        for j in range(len(tt_tokens) - L + 1):
                            if tt_tokens[j:j + L] == phrase:
                                cand = " ".join(phrase)
                                best = cand
                                break
                        if best:
                            break
                    if best:
                        break
                if best:
                    return [_extract_from_target(target_text, best)]

        return cleaned[:3]
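

# ---------------------------------------------------------------------------
# Minimal local smoke test (illustrative sketch only, not part of the endpoint).
# It assumes the fine-tuned model files are present in the current directory and
# shows the request shape the handler expects:
#     {"inputs": [{"utterance": ..., "emotion": ...}]}
# The sample utterance and emotion below are made-up values for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    sample_request = {
        "inputs": [
            {
                "utterance": "I can't believe you forgot my birthday again!",
                "emotion": "anger"
            }
        ]
    }
    for result in handler(sample_request):
        print(result["emotion"], "->", result["triggers"])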