import os
import traceback
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["HF_HOME"] = "/tmp/hf"
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# FastAPI application instance; endpoints are registered below via decorators.
app = FastAPI(title="Code Summarizer API")
# Load model and tokenizer once at import time so every request reuses them
# (avoids re-downloading/re-loading the checkpoint per call).
model_name = "Amitabhdas/code-summarizer-python"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
class CodeRequest(BaseModel):
    """Request body for /predict: the source code to summarize."""

    # Raw source code string; the tokenizer truncates it to 512 tokens downstream.
    code: str
@app.get("/")
def root():
    """Landing endpoint: points API consumers at the /predict route."""
    welcome = {
        "message": "Welcome to Code Summarizer API. Use /predict endpoint to get code summaries."
    }
    return welcome
def tokens_to_words(tokens):
    """
    Merge SentencePiece subword tokens back into whole words.

    Tokens beginning with the "▁" marker start a new word; any other token
    extends the word currently being built.

    Args:
        tokens: list of subword token strings (SentencePiece-style, where
            "▁" marks a word boundary).

    Returns:
        (words, mapping): `words` is the list of reconstructed words, and
        `mapping[i]` is the index in `words` of the word token i belongs to.
    """
    words = []
    current_word = ""
    mapping = []
    for token in tokens:
        if token.startswith("▁"):
            # Word boundary: flush the word in progress, start a new one.
            if current_word:
                words.append(current_word)
            current_word = token.lstrip("▁")
            # The word being built will be appended at index len(words).
            mapping.append(len(words))
        else:
            # Continuation subword: it extends current_word, which will also
            # land at index len(words) when flushed. (Bug fix: the previous
            # code used len(words) - 1 here, mis-attributing continuation
            # tokens to the *previous* completed word.)
            current_word += token
            mapping.append(len(words))
    if current_word:
        words.append(current_word)
    return words, mapping
def compute_word_importance(attention_weights, inputs):
    """
    Rank input words by the encoder attention they receive.

    Head-averaged attention is pooled across all layers, padding positions
    are dropped using the attention mask, per-token scores (column sums of
    the pooled matrix) are aggregated into per-word averages, and the ten
    highest-scoring words are returned as (word, score) pairs.
    """
    token_strings = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    keep_mask = inputs.attention_mask[0].bool()
    # Average each layer's attention over its heads; a layer may arrive as
    # [batch, heads, seq, seq] or [heads, seq, seq] depending on the model.
    per_layer = [
        (layer[0] if layer.dim() == 4 else layer).mean(dim=0)
        for layer in attention_weights
    ]
    # Pool across layers -> single [seq_len, seq_len] matrix.
    pooled = torch.stack(per_layer).mean(dim=0)
    # Restrict rows, columns, and tokens to non-padding positions.
    keep = torch.where(keep_mask)[0]
    pooled = pooled[keep][:, keep]
    kept_tokens = [token_strings[i] for i in keep.tolist()]
    # Reconstruct whole words and the token -> word index mapping.
    words, token_to_word = tokens_to_words(kept_tokens)
    # Attention received by each token = column sum of the pooled matrix.
    per_token = pooled.sum(dim=0).cpu().numpy()
    # Accumulate total score and occurrence count per word.
    totals = {}
    counts = {}
    for pos, w_idx in enumerate(token_to_word):
        if w_idx >= len(words):  # guard against stray mapping indices
            continue
        word = words[w_idx]
        if not word.strip():  # skip empty/whitespace-only words
            continue
        totals[word] = totals.get(word, 0) + per_token[pos]
        counts[word] = counts.get(word, 0) + 1
    # Average per word, convert to plain float for JSON, sort descending.
    ranked = sorted(
        ((word, float(totals[word] / counts[word])) for word in totals),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:10]
@app.post("/predict")
def predict_code_summary(request: CodeRequest):
    """
    Summarize the submitted code and report its most attended words.

    Returns {"summary": str, "topWords": [{"word": str, "score": float}, ...]};
    `topWords` is empty when attention extraction is unavailable or fails.
    Unexpected errors surface as HTTP 500 with the traceback in the detail.
    """
    try:
        # Tokenize input code, padding/truncating to the 512-token window.
        inputs = tokenizer(
            request.code,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        # Step 1: Generate summary. Passing attention_mask is the fix here:
        # the input is padded to max_length, and without the mask generate()
        # may attend to padding tokens, degrading the summary.
        with torch.no_grad():
            summary_outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=150,
                num_beams=4,
                early_stopping=True
            )
        summary = tokenizer.decode(summary_outputs[0], skip_special_tokens=True)
        # Step 2: Forward pass through the encoder only, collecting attentions.
        with torch.no_grad():
            forward_outputs = model.encoder(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_attentions=True,
                return_dict=True
            )
        attentions = getattr(forward_outputs, 'attentions', None)
        if not attentions:
            # No attentions available: still return the summary so the
            # response shape stays consistent for clients.
            return {
                'summary': summary,
                'topWords': []
            }
        # Step 3: Compute word importance; degrade gracefully rather than
        # failing the whole request when only the explanation step breaks.
        try:
            word_importance = compute_word_importance(attentions, inputs)
        except Exception:
            return {
                'summary': summary,
                'topWords': []
            }
        # Format the result according to the specified structure.
        return {
            'summary': summary,
            'topWords': [{'word': word, 'score': float(score)} for word, score in word_importance]
        }
    except Exception as e:
        # Surface unexpected failures with the full traceback for debugging.
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        raise HTTPException(status_code=500, detail=error_details)