import os
import traceback

# Point the Hugging Face caches at a writable location (e.g. inside a container)
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["HF_HOME"] = "/tmp/hf"

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = FastAPI(title="Code Summarizer API")

# Load model and tokenizer once at startup
model_name = "Amitabhdas/code-summarizer-python"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.eval()


class CodeRequest(BaseModel):
    code: str


@app.get("/")
def root():
    return {"message": "Welcome to Code Summarizer API. Use /predict endpoint to get code summaries."}


def tokens_to_words(tokens):
    """
    Convert subword tokens to full words and create a mapping from each
    token index to the index of the word it belongs to.
    """
    words = []
    current_word = ""
    mapping = []
    for token in tokens:
        if token.startswith("▁"):
            # "▁" marks the start of a new word in SentencePiece tokenization
            if current_word:
                words.append(current_word)
            current_word = token.lstrip("▁")
        else:
            # Continuation subword: extend the word currently being built
            current_word += token
        # Both branches map this token to the in-progress word, which will be
        # appended at index len(words)
        mapping.append(len(words))
    if current_word:
        words.append(current_word)
    return words, mapping


def compute_word_importance(attention_weights, inputs):
    """
    Compute the importance of each word based on encoder attention weights.
    Returns the top 10 most important words with their scores.
    """
    # Get tokens and build a mask of non-padding positions
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    attention_mask = inputs.attention_mask[0].bool()

    # Attention weights may arrive as [batch, heads, seq, seq] or [heads, seq, seq];
    # average over heads for each layer either way
    all_layers_mean = []
    for layer_attn in attention_weights:
        if layer_attn.dim() == 4:
            layer_mean = layer_attn[0].mean(dim=0)  # [seq_len, seq_len]
        else:
            layer_mean = layer_attn.mean(dim=0)  # [seq_len, seq_len]
        all_layers_mean.append(layer_mean)

    # Average across all layers
    mean_attention = torch.stack(all_layers_mean).mean(dim=0)  # [seq_len, seq_len]

    # Keep only non-padding positions
    valid_indices = torch.where(attention_mask)[0]
    filtered_attention = mean_attention[valid_indices][:, valid_indices]
    filtered_tokens = [tokens[i] for i in valid_indices.tolist()]

    # Map subword tokens back to words
    words, token_to_word_map = tokens_to_words(filtered_tokens)

    # Token importance = total attention each token receives (column sums)
    token_importance = filtered_attention.sum(dim=0).cpu().numpy()

    # Aggregate token importance by word
    word_scores, word_counts = {}, {}
    for idx, word_idx in enumerate(token_to_word_map):
        if word_idx < len(words):  # Safety check
            word = words[word_idx]
            if word and word.strip():  # Skip empty words
                word_scores[word] = word_scores.get(word, 0.0) + float(token_importance[idx])
                word_counts[word] = word_counts.get(word, 0) + 1

    # Average importance per word (plain floats for JSON serialization)
    word_importance = [
        (word, word_scores[word] / word_counts[word]) for word in word_scores
    ]

    # Sort by importance (descending) and return the top 10
    word_importance.sort(key=lambda x: x[1], reverse=True)
    return word_importance[:10]
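# Illustrative only (assumes SentencePiece-style "▁"-prefixed tokens; this
# tokenizer's real output may differ): tokens_to_words groups subwords back
# into words, e.g.
#   tokens_to_words(["▁def", "▁fib", "(", "n", ")", ":"])
#   -> (["def", "fib(n):"], [0, 1, 1, 1, 1, 1])
# compute_word_importance then averages the attention mass per word and
# returns at most 10 (word, score) pairs sorted by descending score.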
@app.post("/predict")
def predict_code_summary(request: CodeRequest):
    try:
        # Tokenize the input code
        inputs = tokenizer(
            request.code,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        # Step 1: Generate the summary (pass the attention mask so padding is ignored)
        with torch.no_grad():
            summary_outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=150,
                num_beams=4,
                early_stopping=True
            )
        summary = tokenizer.decode(summary_outputs[0], skip_special_tokens=True)

        # Step 2: Run an encoder forward pass to get the encoder attentions
        with torch.no_grad():
            forward_outputs = model.get_encoder()(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_attentions=True,
                return_dict=True
            )

        # If no attentions are available, still return a consistent response
        if not getattr(forward_outputs, "attentions", None):
            return {"summary": summary, "topWords": []}

        # Step 3: Compute word importance; fall back to an empty list on failure
        try:
            word_importance = compute_word_importance(forward_outputs.attentions, inputs)
        except Exception:
            return {"summary": summary, "topWords": []}

        # Format the result according to the specified structure
        return {
            "summary": summary,
            "topWords": [{"word": word, "score": float(score)} for word, score in word_importance]
        }

    except Exception as e:
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        raise HTTPException(status_code=500, detail=error_details)
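# A minimal sketch for local testing only (assumes uvicorn is installed and
# this file is run directly; the actual deployment may start the app differently):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against a locally running server (illustrative output only):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"code": "def add(a, b):\n    return a + b"}'
#   -> {"summary": "...", "topWords": [{"word": "...", "score": ...}, ...]}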