import logging
import re
from typing import Any, Dict, List

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the RECCON emotional trigger extraction model using native transformers.

        Args:
            path: Path to model directory (provided by HuggingFace Inference Endpoints)
        """
        logger.info("Initializing RECCON Trigger Extraction endpoint...")

        # Use the first GPU when available; otherwise fall back to CPU.
        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            logger.warning("GPU not detected. Running on CPU. Inference will be slower.")
        self.device_id = 0 if cuda_available else -1

        # Inference Endpoints passes the model directory as `path`; default to
        # the current directory when it is empty.
        model_path = path or "."
        logger.info(f"Loading model from {model_path}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model, loading_info = AutoModelForQuestionAnswering.from_pretrained(
                model_path,
                output_loading_info=True,
            )

            # Surface checkpoint/architecture mismatches (e.g. a freshly
            # initialized QA head) at WARNING level so they show up in
            # endpoint logs.
            logger.warning("RECCON load info - missing_keys: %s", loading_info.get("missing_keys"))
            logger.warning("RECCON load info - unexpected_keys: %s", loading_info.get("unexpected_keys"))
            logger.warning("RECCON load info - error_msgs: %s", loading_info.get("error_msgs"))
            logger.warning("Loaded model class: %s", model.__class__.__name__)
            logger.warning("Loaded model name_or_path: %s", getattr(model.config, "_name_or_path", None))

            # top_k=20 makes the QA pipeline return up to 20 candidate spans
            # per input; _clean_spans later filters and deduplicates them.
            self.pipe = pipeline(
                "question-answering",
                model=model,
                tokenizer=tokenizer,
                device=self.device_id,
                top_k=20,
                handle_impossible_answer=False,
            )
            logger.info("Model loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

        self.question_template = (
            "Extract the exact short phrase (<= 8 words) from the target "
            "utterance that most strongly signals the emotion {emotion}. "
            "Return only a substring of the target utterance."
        )
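        # For illustration: format(emotion="sadness") renders the QA question
        # used in __call__, with the utterance supplied as the context.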

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Args:
            data: Request payload; "inputs" may be a dict or a list of dicts,
                each with "utterance" and "emotion" keys.

        Returns:
            One result dict per input with the extracted "triggers" (and an
            "error" key when the input was invalid or inference failed).
        """
        inputs = data.pop("inputs", data)

        if isinstance(inputs, dict):
            inputs = [inputs]

        if not inputs:
            return [{"error": "No inputs provided", "triggers": []}]

        pipeline_inputs = []
        valid_indices = []

        for i, item in enumerate(inputs):
            utterance = item.get("utterance", "").strip()
            emotion = item.get("emotion", "")

            if not utterance:
                logger.warning(f"Empty utterance at index {i}")
                continue

            question = self.question_template.format(emotion=emotion)

            pipeline_inputs.append({
                "question": question,
                "context": utterance,
            })
            valid_indices.append(i)

        results = []

        if not pipeline_inputs:
            for item in inputs:
                results.append({
                    "utterance": item.get("utterance", ""),
                    "emotion": item.get("emotion", ""),
                    "error": "Missing or empty utterance",
                    "triggers": [],
                })
            return results

        try:
            predictions = self.pipe(pipeline_inputs, batch_size=8)

            # Normalize the QA pipeline's output shape to one entry per input.
            # With top_k > 1, a single input yields a flat list of span dicts,
            # so wrap it; a bare dict means a single top-1 answer.
            if isinstance(predictions, dict):
                predictions = [predictions]
            elif isinstance(predictions, list) and len(predictions) > 0 and isinstance(predictions[0], dict):
                if len(pipeline_inputs) == 1:
                    predictions = [predictions]

            logger.debug(f"Raw predictions: {predictions}")

            def is_good_span(ans: str) -> bool:
                """Reject empty, too-short, punctuation-only, or letterless spans."""
                if not ans:
                    return False
                a = ans.strip()
                if len(a) < 3:
                    return False
                if all(ch in ".,!?;:-—'\"()[]{}" for ch in a):
                    return False
                if not any(ch.isalpha() for ch in a):
                    return False
                return True

            # Reassemble results in the original input order, emitting error
            # entries for inputs that were skipped as invalid.
            pred_idx = 0
            for i, item in enumerate(inputs):
                utterance = item.get("utterance", "").strip()
                emotion = item.get("emotion", "")

                if i not in valid_indices:
                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "error": "Missing or empty utterance",
                        "triggers": [],
                    })
                else:
                    current_preds = predictions[pred_idx]
                    if isinstance(current_preds, dict):
                        current_preds = [current_preds]

                    logger.info(
                        "RECCON raw spans (answer, score): %s",
                        [(p.get("answer"), round(p.get("score", 0.0), 3)) for p in current_preds[:5]],
                    )

                    raw_answers = [p.get("answer", "") for p in current_preds]
                    raw_answers = [a for a in raw_answers if is_good_span(a)]
                    triggers = self._clean_spans(raw_answers, utterance)

                    results.append({
                        "utterance": utterance,
                        "emotion": emotion,
                        "triggers": triggers,
                    })
                    pred_idx += 1

            logger.debug(f"Cleaned results: {results}")
            return results

        except Exception as e:
            logger.error(f"Model prediction failed: {e}")
            return [{
                "utterance": item.get("utterance", ""),
                "emotion": item.get("emotion", ""),
                "error": str(e),
                "triggers": [],
            } for item in inputs]

    def _clean_spans(self, spans: List[str], target_text: str) -> List[str]:
        """
        Clean and filter extracted trigger spans: keep only spans that occur
        in the target utterance, drop overlong or stopword-only candidates,
        deduplicate substrings, and return at most three triggers.
        """
        target_text = target_text or ""
        target_lower = target_text.lower()

        def _norm(s: str) -> str:
            # Lowercase, collapse whitespace, and strip leading/trailing
            # punctuation.
            s = (s or "").strip().lower()
            s = re.sub(r"\s+", " ", s)
            s = re.sub(r"^[^\w]+|[^\w]+$", "", s)
            return s
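        # e.g. _norm("  Hello,  world! ") -> "hello, world"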

        def _extract_from_target(target: str, phrase_lower: str) -> str:
            # Recover the original casing by locating the normalized phrase
            # in the target text.
            idx = target.lower().find(phrase_lower)
            if idx >= 0:
                return target[idx:idx + len(phrase_lower)]
            return phrase_lower

        STOP = {
            "a", "an", "the", "and", "or", "but", "so", "to", "of", "in", "on", "at",
            "with", "for", "from", "is", "am", "are", "was", "were", "be", "been",
            "being", "i", "you", "he", "she", "it", "we", "they", "my", "your", "his",
            "her", "their", "our", "me", "him", "them", "this", "that", "these",
            "those",
        }

        candidates = []
        for s in spans:
            s = (s or "").strip()
            if not s:
                continue
            s_norm = _norm(s)
            if not s_norm:
                continue
            if target_text and s_norm not in target_lower:
                continue
            tokens = s_norm.split()
            if len(tokens) > 8 or len(s_norm) > 80:
                continue
            if len(tokens) == 1 and (tokens[0] in STOP or len(tokens[0]) <= 2):
                continue
            candidates.append({
                "norm": s_norm,
                "tokens": tokens,
                "tok_len": len(tokens),
                "char_len": len(s_norm),
            })

        # Prefer concise spans (1-3 tokens) whenever any exist.
        short_candidates = [c for c in candidates if 1 <= c["tok_len"] <= 3]
        if short_candidates:
            candidates = short_candidates

        # Sort shortest-first, then drop any span that is a substring of (or
        # contains) an already-kept span.
        candidates.sort(key=lambda x: (x["tok_len"], x["char_len"]))
        kept_norms = []
        for c in candidates:
            n = c["norm"]
            if any(n in kn or kn in n for kn in kept_norms):
                continue
            kept_norms.append(n)

        cleaned = [_extract_from_target(target_text, n) for n in kept_norms]

        # Fallback: if nothing survived filtering, look for the longest word
        # n-gram (up to 8 words) shared by any raw span and the target text.
        if not cleaned and spans:
            tt_tokens = target_lower.split()
            best = None
            for s in spans:
                words = [w for w in (s or "").lower().strip().split() if w]
                for L in range(min(8, len(words)), 0, -1):
                    for i in range(len(words) - L + 1):
                        phrase = words[i:i + L]
                        for j in range(len(tt_tokens) - L + 1):
                            if tt_tokens[j:j + L] == phrase:
                                best = " ".join(phrase)
                                break
                        if best:
                            break
                    if best:
                        break
                if best:
                    return [_extract_from_target(target_text, best)]

        return cleaned[:3]
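

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch for exercising the handler outside of
# Inference Endpoints. It assumes a compatible QA checkpoint (config, weights,
# tokenizer) is present in the current directory; the utterance/emotion
# values below are made-up illustrations.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    payload = {
        "inputs": [
            {"utterance": "I can't believe I finally passed the exam!", "emotion": "happiness"},
            {"utterance": "", "emotion": "sadness"},  # exercises the error path
        ]
    }
    for result in handler(payload):
        print(result.get("emotion"), "->", result.get("triggers"), result.get("error", ""))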