Shreyash Indurkar
Final out changes
4b45566
import os
import traceback
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["HF_HOME"] = "/tmp/hf"
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
app = FastAPI(title="Code Summarizer API")

# Hugging Face checkpoint served by this API. Tokenizer and model are loaded a
# single time when the module is imported and reused by the request handlers.
model_name = "Amitabhdas/code-summarizer-python"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)
class CodeRequest(BaseModel):
    """Request body accepted by the ``POST /predict`` endpoint."""

    # Source code to be summarized.
    code: str
@app.get("/")
def root():
    """Landing endpoint: point clients at /predict for summaries."""
    greeting = (
        "Welcome to Code Summarizer API. "
        "Use /predict endpoint to get code summaries."
    )
    return dict(message=greeting)
def tokens_to_words(tokens):
    """
    Merge sub-word tokens into whole words and map each token to its word.

    A token that starts with "▁" begins a new word (leading "▁" characters
    are stripped); any other token is appended to the word currently being
    built. This assumes the SentencePiece "▁" word-boundary convention --
    NOTE(review): confirm it matches the loaded tokenizer (byte-level BPE
    tokenizers mark word starts with "Ġ" instead).

    Args:
        tokens: list of token strings, in input order.

    Returns:
        ``(words, mapping)`` where ``words`` is the list of non-empty merged
        words and ``mapping[i]`` is the index in ``words`` of the word that
        ``tokens[i]`` belongs to. A bare "▁" token (an empty word) is mapped
        to the next non-empty word, or to ``len(words)`` when none follows,
        so callers must bounds-check the index.
    """
    words = []
    mapping = []
    current_word = ""
    for token in tokens:
        if token.startswith("▁"):
            # Word boundary: flush the finished word (skip empty ones).
            if current_word:
                words.append(current_word)
            current_word = token.lstrip("▁")
        else:
            current_word += token
        # Whether this token starts a word or continues one, the word being
        # built will be stored at index len(words) once it is flushed.
        mapping.append(len(words))
    if current_word:
        words.append(current_word)
    return words, mapping
def compute_word_importance(attention_weights, inputs, top_k=10):
    """
    Rank the input's words by how much encoder attention they receive.

    Attention is averaged over heads within each layer, then over layers,
    and restricted to the non-padding positions given by
    ``inputs.attention_mask``. A token's importance is the total attention it
    receives from all non-padding tokens (column sum). Special tokens (e.g. a
    trailing EOS, or UNK) still count as attending queries but are not
    reported as words -- otherwise they would be glued onto a neighbouring
    word by ``tokens_to_words``. Sub-word tokens are merged into words and a
    word's score is the mean importance of its tokens over all occurrences.

    Args:
        attention_weights: sequence of per-layer attention tensors, each
            either ``[batch, heads, seq, seq]`` (only batch item 0 is used)
            or ``[heads, seq, seq]``.
        inputs: tokenizer output exposing ``input_ids`` and
            ``attention_mask``; only batch item 0 is used.
        top_k: maximum number of words to return (default 10; ``None``
            returns all words).

    Returns:
        List of ``(word, score)`` tuples sorted by descending score, with
        ``score`` as a plain ``float`` so it serializes to JSON.
    """
    input_ids = inputs.input_ids[0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    attention_mask = inputs.attention_mask[0].bool()

    # Average over heads within each layer, then over layers -> [seq, seq].
    layer_means = []
    for layer_attn in attention_weights:
        if layer_attn.dim() == 4:
            layer_attn = layer_attn[0]  # drop the batch dimension
        layer_means.append(layer_attn.mean(dim=0))
    mean_attention = torch.stack(layer_means).mean(dim=0)

    # Keep only real (non-padding) positions, as both queries and keys.
    valid_indices = torch.where(attention_mask)[0]
    filtered_attention = mean_attention[valid_indices][:, valid_indices]
    # Column sum: total attention each token receives.
    token_importance = filtered_attention.sum(dim=0).cpu().numpy()

    # Exclude special tokens before grouping sub-words into words.
    special_ids = set(tokenizer.all_special_ids)
    id_list = input_ids.tolist()
    word_tokens, word_token_scores = [], []
    for pos, score in zip(valid_indices.tolist(), token_importance):
        if id_list[pos] not in special_ids:
            word_tokens.append(tokens[pos])
            word_token_scores.append(float(score))

    words, token_to_word_map = tokens_to_words(word_tokens)

    totals, counts = {}, {}
    for score, word_idx in zip(word_token_scores, token_to_word_map):
        # Guard against token indices that have no corresponding word.
        if word_idx >= len(words):
            continue
        word = words[word_idx]
        if not word.strip():
            continue
        totals[word] = totals.get(word, 0.0) + score
        counts[word] = counts.get(word, 0) + 1

    ranked = sorted(
        ((word, totals[word] / counts[word]) for word in totals),
        key=lambda item: item[1],
        reverse=True,
    )
    return ranked[:top_k]
@app.post("/predict")
def predict_code_summary(request: CodeRequest):
    """
    Summarize ``request.code`` and report the most attended input words.

    The code is tokenized (truncated to 512 tokens) and summarized with beam
    search (4 beams, at most 150 tokens). A separate encoder pass with
    ``output_attentions=True`` feeds ``compute_word_importance``.

    Returns:
        ``{"summary": str, "topWords": [{"word": str, "score": float}, ...]}``.
        ``topWords`` is empty when the encoder returns no attentions or the
        word-importance step fails; that failure is logged, not raised.

    Raises:
        HTTPException: status 500 for any other failure, with the error
            message and traceback in ``detail``.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        # A single sequence needs no padding; padding to max_length only made
        # the encoder and beam search process 512 positions on every request.
        inputs = tokenizer(
            request.code,
            max_length=512,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            # Step 1: generate the summary. Pass the mask explicitly rather
            # than relying on generate() to infer it.
            summary_outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=150,
                num_beams=4,
                early_stopping=True,
            )
            # Step 2: encoder-only pass to obtain attention weights.
            forward_outputs = model.encoder(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_attentions=True,
                return_dict=True,
            )
        summary = tokenizer.decode(summary_outputs[0], skip_special_tokens=True)

        # Step 3: word importance is best effort; the summary is always returned.
        top_words = []
        attentions = getattr(forward_outputs, "attentions", None)
        if attentions:
            try:
                word_importance = compute_word_importance(attentions, inputs)
            except Exception:
                logger.exception("Word-importance computation failed")
            else:
                top_words = [
                    {'word': word, 'score': float(score)}
                    for word, score in word_importance
                ]
        return {'summary': summary, 'topWords': top_words}
    except Exception as e:
        logger.exception("Code summarization failed")
        # NOTE(review): returning the full traceback to API clients exposes
        # server internals; consider keeping it in the server log only.
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        raise HTTPException(status_code=500, detail=error_details)