# RECCON / handler.py
import torch
import logging
import re
from typing import Dict, List, Any
from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
def __init__(self, path=""):
"""
Initialize the RECCON emotional trigger extraction model using native transformers.
Args:
path: Path to model directory (provided by HuggingFace Inference Endpoints)
"""
logger.info("Initializing RECCON Trigger Extraction endpoint...")
# Detect device (CUDA/CPU)
cuda_available = torch.cuda.is_available()
if not cuda_available:
logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
# In 'pipeline', device is an integer (-1 for CPU, 0+ for GPU)
self.device_id = 0 if cuda_available else -1
# Determine model path
        model_path = path if path else "."
logger.info(f"Loading model from {model_path}...")
try:
# Load tokenizer and model explicitly to ensure correct loading
tokenizer = AutoTokenizer.from_pretrained(model_path)
model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
model_path,
output_loading_info=True
)
logger.warning("RECCON load info - missing_keys: %s", loading_info.get("missing_keys"))
logger.warning("RECCON load info - unexpected_keys: %s", loading_info.get("unexpected_keys"))
logger.warning("RECCON load info - error_msgs: %s", loading_info.get("error_msgs"))
logger.warning("Loaded model class: %s", model.__class__.__name__)
logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))
# Initialize the pipeline
        # top_k=20 mirrors the original n_best_size=20 setting
self.pipe = pipeline(
"question-answering",
model=model,
tokenizer=tokenizer,
device=self.device_id,
top_k=20,
handle_impossible_answer=False
)
logger.info("Model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
# Question template (must match training)
self.question_template = (
"Extract the exact short phrase (<= 8 words) from the target "
"utterance that most strongly signals the emotion {emotion}. "
"Return only a substring of the target utterance."
)
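        # Illustrative example (assumed, not from training data) of the question
        # passed to the QA pipeline when emotion="anger":
        #   "Extract the exact short phrase (<= 8 words) from the target utterance
        #    that most strongly signals the emotion anger. Return only a substring
        #    of the target utterance."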
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Process inference request.
"""
# Extract inputs
inputs = data.pop("inputs", data)
# Normalize to list format
if isinstance(inputs, dict):
inputs = [inputs]
if not inputs:
return [{"error": "No inputs provided", "triggers": []}]
# Validate and format inputs for the pipeline
pipeline_inputs = []
valid_indices = []
for i, item in enumerate(inputs):
utterance = item.get("utterance", "").strip()
emotion = item.get("emotion", "")
if not utterance:
logger.warning(f"Empty utterance at index {i}")
continue
# Format as QA task
question = self.question_template.format(emotion=emotion)
# The pipeline expects a list of dicts with 'question' and 'context'
pipeline_inputs.append({
'question': question,
'context': utterance
})
valid_indices.append(i)
# Run prediction
results = []
if not pipeline_inputs:
# All inputs were invalid
for item in inputs:
results.append({
"utterance": item.get("utterance", ""),
"emotion": item.get("emotion", ""),
"error": "Missing or empty utterance",
"triggers": []
})
return results
try:
# Run inference (batch_size helps with multiple inputs)
predictions = self.pipe(pipeline_inputs, batch_size=8)
            # Normalise the pipeline output to one entry per input. With top_k > 1
            # and several inputs the pipeline returns a list of lists, but a single
            # input can come back as a bare dict or as a flat list of dicts.
            if isinstance(predictions, dict):
                # Single input, single answer
                predictions = [predictions]
            elif isinstance(predictions, list) and predictions and isinstance(predictions[0], dict):
                # Flat list of dicts: either one input with top_k answers (our case,
                # since top_k=20) or several inputs with top_k=1. Wrap the single
                # input so we always index a list of per-input answer lists.
                if len(pipeline_inputs) == 1:
                    predictions = [predictions]
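            # Illustrative shape (assumed): with top_k=20 and two valid inputs,
            # `predictions` is a list of two lists, each holding entries like
            #   {"score": 0.91, "start": 10, "end": 21, "answer": "so excited"}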
logger.debug(f"Raw predictions: {predictions}")
# Post-process results
pred_idx = 0
for i, item in enumerate(inputs):
utterance = item.get("utterance", "").strip()
emotion = item.get("emotion", "")
if i not in valid_indices:
results.append({
"utterance": utterance,
"emotion": emotion,
"error": "Missing or empty utterance",
"triggers": []
})
else:
# Get prediction for this item
# Because top_k=20, 'current_preds' is a list of dicts: [{'answer': '...', 'score': ...}, ...]
current_preds = predictions[pred_idx]
# Ensure it is a list
if isinstance(current_preds, dict):
current_preds = [current_preds]
logger.info(
"RECCON raw spans (answer, score): %s",
[(p.get("answer"), p.get("score", 0.0), 3) for p in current_preds[:5]]
)
def is_good_span(ans: str) -> bool:
if not ans:
return False
a = ans.strip()
if len(a) < 3:
return False
# reject pure punctuation
if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
return False
# require at least one letter
if not any(ch.isalpha() for ch in a):
return False
return True
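                    # Illustrative behaviour of the filter above (assumed examples):
                    #   is_good_span("so excited") -> True
                    #   is_good_span("!!")         -> False  (too short / punctuation only)
                    #   is_good_span("123")        -> False  (no letters)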
raw_answers = [p.get("answer", "") for p in current_preds]
raw_answers = [a for a in raw_answers if is_good_span(a)]
triggers = self._clean_spans(raw_answers, utterance)
results.append({
"utterance": utterance,
"emotion": emotion,
"triggers": triggers
})
pred_idx += 1
logger.debug(f"Cleaned results: {results}")
return results
except Exception as e:
logger.error(f"Model prediction failed: {e}")
return [{
"utterance": item.get("utterance", ""),
"emotion": item.get("emotion", ""),
"error": str(e),
"triggers": []
} for item in inputs]
def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
"""
Clean and filter extracted trigger spans.
(Logic preserved exactly as provided)
"""
target_text = target_text or ""
target_lower = target_text.lower()
def _norm(s: str) -> str:
s = (s or "").strip().lower()
s = re.sub(r"\s+", " ", s)
s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
return s
def _extract_from_target(target: str, phrase_lower: str) -> str:
idx = target.lower().find(phrase_lower)
if idx >= 0:
return target[idx:idx+len(phrase_lower)]
return phrase_lower
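        # Illustrative (assumed) example: _extract_from_target("I am SO Excited", "so excited")
        # returns "SO Excited", i.e. the span with its original casing from the target text.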
STOP = {
"a", "an", "the", "and", "or", "but", "so", "to", "of", "in", "on", "at",
"with", "for", "from", "is", "am", "are", "was", "were", "be", "been",
"being", "i", "you", "he", "she", "it", "we", "they", "my", "your", "his",
"her", "their", "our", "me", "him", "her", "them", "this", "that", "these",
"those"
}
candidates = []
for s in spans:
s = (s or "").strip()
if not s:
continue
s_norm = _norm(s)
if not s_norm:
continue
if target_text and s_norm not in target_lower:
continue
tokens = s_norm.split()
if len(tokens) > 8 or len(s_norm) > 80:
continue
if len(tokens) == 1 and (tokens[0] in STOP or len(tokens[0]) <= 2):
continue
candidates.append({
"norm": s_norm,
"tokens": tokens,
"tok_len": len(tokens),
"char_len": len(s_norm)
})
# Prioritize short, focused emotional keywords (1-3 words)
short_candidates = [c for c in candidates if 1 <= c["tok_len"] <= 3]
if short_candidates:
candidates = short_candidates
# Sort by SHORTEST spans first (most focused keywords)
        candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]))
kept_norms = []
for c in list(candidates):
n = c["norm"]
if any(n in kn or kn in n for kn in kept_norms):
continue
kept_norms.append(n)
cleaned = [_extract_from_target(target_text, n) for n in kept_norms]
        # Fallback: nothing survived cleaning, so look for the longest word
        # sequence from any raw span that appears verbatim in the target text.
        if not cleaned and spans:
tt_tokens = target_lower.split()
best = None
for s in spans:
words = [w for w in (s or '').lower().strip().split() if w]
for L in range(min(8, len(words)), 0, -1):
for i in range(len(words) - L + 1):
phrase = words[i:i+L]
for j in range(len(tt_tokens) - L + 1):
if tt_tokens[j:j+L] == phrase:
cand = " ".join(phrase)
best = cand
break
if best:
break
if best:
break
if best:
return [_extract_from_target(target_text, best)]
return cleaned[:3]
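
# ---------------------------------------------------------------------------
# Minimal local smoke test (illustrative sketch, not part of the Inference
# Endpoints contract). The model path "." and the sample payload below are
# assumptions for local experimentation with a RECCON-style QA checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    sample = {
        "inputs": [
            {
                "utterance": "I can't believe I finally got the job, I'm so excited!",
                "emotion": "happiness",
            }
        ]
    }
    for result in handler(sample):
        print(result)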