import os
import traceback

# Point the Hugging Face cache at a writable location (/tmp) before transformers is imported
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["HF_HOME"] = "/tmp/hf"

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = FastAPI(title="Code Summarizer API")

# Load model and tokenizer once at startup
model_name = "Amitabhdas/code-summarizer-python"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


class CodeRequest(BaseModel):
    code: str


@app.get("/")
def root():
    return {"message": "Welcome to Code Summarizer API. Use /predict endpoint to get code summaries."}


def tokens_to_words(tokens):
    """
    Convert subword tokens to full words and create a mapping from token index to word index.
    """
    words = []
    current_word = ""
    mapping = []
    for token in tokens:
        if token.startswith("▁"):
            # A new word starts: flush the word built so far
            if current_word:
                words.append(current_word)
            current_word = token.lstrip("▁")
            mapping.append(len(words))
        else:
            # Continuation subword: it belongs to the word currently being built,
            # which will be appended at index len(words)
            current_word += token
            mapping.append(len(words))
    if current_word:
        words.append(current_word)
    return words, mapping
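
# Illustrative example (assumes a SentencePiece-style tokenizer that marks word starts
# with "▁"): for tokens ["▁def", "▁calc", "ulate"] this returns words == ["def", "calculate"]
# and mapping == [0, 1, 1], i.e. each token index maps to the index of its full word.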


def compute_word_importance(attention_weights, inputs):
    """
    Compute the importance of each word based on attention weights.
    Returns the top 10 most important words with their scores.
    """
    # Get tokens and build a mask of the non-padding positions
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    attention_mask = inputs.attention_mask[0].bool()

    # Average attention over heads for each layer, handling both possible shapes:
    # [batch_size, num_heads, seq_len, seq_len] or [num_heads, seq_len, seq_len]
    all_layers_mean = []
    for layer_attn in attention_weights:
        if len(layer_attn.shape) == 4:
            layer_mean = layer_attn[0].mean(dim=0)  # average over heads -> [seq_len, seq_len]
        else:
            layer_mean = layer_attn.mean(dim=0)  # average over heads -> [seq_len, seq_len]
        all_layers_mean.append(layer_mean)

    # Average across all layers
    mean_attention = torch.stack(all_layers_mean).mean(dim=0)  # [seq_len, seq_len]

    # Keep only the non-padding positions
    valid_indices = torch.where(attention_mask)[0]
    filtered_attention = mean_attention[valid_indices][:, valid_indices]
    filtered_tokens = [tokens[i] for i in valid_indices.tolist()]

    # Map tokens to words
    words, token_to_word_map = tokens_to_words(filtered_tokens)

    # Token importance = total attention each token receives
    token_importance = filtered_attention.sum(dim=0).cpu().numpy()

    # Aggregate token importance by word
    word_scores, word_counts = {}, {}
    for idx, word_idx in enumerate(token_to_word_map):
        if word_idx < len(words):  # safety check
            word = words[word_idx]
            if word and len(word.strip()) > 0:  # skip empty words
                word_scores[word] = word_scores.get(word, 0) + token_importance[idx]
                word_counts[word] = word_counts.get(word, 0) + 1

    # Average importance per word
    word_importance = []
    for word in word_scores:
        avg_score = word_scores[word] / word_counts[word]
        word_importance.append((word, float(avg_score)))  # float for JSON serialization

    # Sort by importance (descending) and return the top 10
    word_importance.sort(key=lambda x: x[1], reverse=True)
    return word_importance[:10]
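
# The return value is a list of up to ten (word, average_attention_score) pairs, sorted
# from most to least important, e.g. [("return", 1.83), ("values", 1.42), ...]
# (the scores shown here are purely illustrative).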


@app.post("/predict")
def predict_code_summary(request: CodeRequest):
    try:
        # Tokenize the input code
        inputs = tokenizer(
            request.code,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        # Step 1: Generate the summary
        with torch.no_grad():
            summary_outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=150,
                num_beams=4,
                early_stopping=True
            )
        summary = tokenizer.decode(summary_outputs[0], skip_special_tokens=True)

        # Step 2: Run an encoder forward pass to get the encoder attentions
        # (get_encoder() is the architecture-agnostic accessor for the encoder)
        with torch.no_grad():
            forward_outputs = model.get_encoder()(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_attentions=True,
                return_dict=True
            )

        # If no attentions came back, return the summary without word importance
        if not hasattr(forward_outputs, 'attentions') or not forward_outputs.attentions:
            return {
                'summary': summary,
                'topWords': []
            }

        # Step 3: Compute word importance
        try:
            word_importance = compute_word_importance(forward_outputs.attentions, inputs)
        except Exception:
            # Keep the response shape consistent even when word importance fails
            return {
                'summary': summary,
                'topWords': []
            }

        # Format the result according to the expected structure
        result = {
            'summary': summary,
            'topWords': [{'word': word, 'score': float(score)} for word, score in word_importance]
        }
        return result

    except Exception as e:
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        raise HTTPException(status_code=500, detail=error_details)
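

# Example client call (a minimal sketch, not part of the app itself; it assumes the
# API is served locally on port 7860, the default port used by Hugging Face Spaces):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={"code": "def add(a, b):\n    return a + b"},
#   )
#   print(resp.json())
#   # -> {"summary": "...", "topWords": [{"word": "...", "score": ...}, ...]}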