import os
import traceback

# Redirect the Hugging Face cache to a writable directory; these must be
# set before transformers is imported.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["HF_HOME"] = "/tmp/hf"

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = FastAPI(title="Code Summarizer API")

# Load model and tokenizer once
model_name = "Amitabhdas/code-summarizer-python"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

class CodeRequest(BaseModel):
    code: str

@app.get("/")
def root():
    return {"message": "Welcome to Code Summarizer API. Use /predict endpoint to get code summaries."}

def tokens_to_words(tokens):
    """
    Merge SentencePiece subword tokens into full words and return a
    mapping from each token index to its word index.
    """
    words = []
    current_word = ""
    mapping = []
    for token in tokens:
        if token.startswith("▁"):
            # "▁" marks the start of a new word; flush the previous one
            if current_word:
                words.append(current_word)
            current_word = token.lstrip("▁")
        else:
            # Continuation token: extend the word currently being built
            current_word += token
        # Either way, the token belongs to the word that will be appended
        # at index len(words)
        mapping.append(len(words))
    if current_word:
        words.append(current_word)
    return words, mapping
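
# Illustrative example, assuming SentencePiece-style tokens (T5-family
# tokenizers mark word starts with "▁"):
#   tokens_to_words(["▁def", "▁add", "("])
#   -> (["def", "add("], [0, 1, 1])
# The "(" continuation token is attributed to the word "add(".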

def compute_word_importance(attention_weights, inputs):
    """
    Compute the importance of each word based on attention weights.
    Returns the top 10 most important words with their scores.
    """
    # Get tokens and create mask from attention_mask
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    attention_mask = inputs.attention_mask[0].bool()
    
    # Average attention over heads within each layer, handling both
    # [batch, heads, seq, seq] and [heads, seq, seq] attention shapes
    all_layers_mean = []
    
    for layer_attn in attention_weights:
        if len(layer_attn.shape) == 4:
            # Batch dimension present: take the first (only) example
            layer_mean = layer_attn[0].mean(dim=0)  # Average over heads -> [seq_len, seq_len]
        else:
            layer_mean = layer_attn.mean(dim=0)  # Average over heads -> [seq_len, seq_len]
        all_layers_mean.append(layer_mean)
    
    # Average across all layers
    mean_attention = torch.stack(all_layers_mean).mean(dim=0)  # [seq_len, seq_len]
    
    # Get valid token indices
    valid_indices = torch.where(attention_mask)[0]
    
    # Get filtered attention and tokens
    filtered_attention = mean_attention[valid_indices][:, valid_indices]
    filtered_tokens = [tokens[i] for i in valid_indices.tolist()]
    
    # Map tokens to words
    words, token_to_word_map = tokens_to_words(filtered_tokens)
    
    # Calculate token importance as sum of attention
    token_importance = filtered_attention.sum(dim=0).cpu().numpy()
    
    # Aggregate token importance by word
    word_scores, word_counts = {}, {}
    for idx, word_idx in enumerate(token_to_word_map):
        if word_idx < len(words):  # Safety check
            word = words[word_idx]
            if word and len(word.strip()) > 0:  # Skip empty words
                word_scores[word] = word_scores.get(word, 0) + token_importance[idx]
                word_counts[word] = word_counts.get(word, 0) + 1
    
    # Calculate average importance per word
    word_importance = []
    for word in word_scores:
        avg_score = word_scores[word] / word_counts[word]
        word_importance.append((word, float(avg_score)))  # Convert to float for JSON serialization
    
    # Sort by importance (descending) and return top 10
    word_importance.sort(key=lambda x: x[1], reverse=True)
    return word_importance[:10]
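
# The return value is a list of (word, score) pairs sorted by descending
# score, e.g. (scores below are made up purely for illustration):
#   [("return", 12.3), ("def", 9.8), ("sum", 7.5)]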

@app.post("/predict")
def predict_code_summary(request: CodeRequest):
    try:
        # Tokenize input code
        inputs = tokenizer(
            request.code,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        
        # Step 1: Generate summary (pass the attention mask so beam search
        # ignores padding tokens)
        with torch.no_grad():
            summary_outputs = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=150,
                num_beams=4,
                early_stopping=True
            )
        summary = tokenizer.decode(summary_outputs[0], skip_special_tokens=True)
        
        # Step 2: Run the encoder alone to collect its self-attention weights
        # (get_encoder() works across seq2seq architectures, e.g. T5 and BART)
        with torch.no_grad():
            forward_outputs = model.get_encoder()(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_attentions=True,
                return_dict=True
            )
        
        # Fall back to a summary-only response if no attentions were returned
        if not hasattr(forward_outputs, 'attentions') or not forward_outputs.attentions:
            return {
                'summary': summary,
                'topWords': []
            }
        
        # Step 3: Compute word importance
        try:
            word_importance = compute_word_importance(forward_outputs.attentions, inputs)
        except Exception:
            # Fall back to a summary-only response if word importance fails
            return {
                'summary': summary,
                'topWords': []
            }
        
        # Format the result according to the specified structure
        result = {
            'summary': summary,
            'topWords': [{'word': word, 'score': float(score)} for word, score in word_importance]
        }
        
        # Return formatted JSON response
        return result
    except Exception as e:
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        raise HTTPException(status_code=500, detail=error_details)
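
# Minimal sketch of a local entry point, assuming uvicorn is installed
# (e.g. `pip install uvicorn`); in deployment the app would typically be
# started with `uvicorn main:app` instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request (illustrative):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"code": "def add(a, b):\n    return a + b"}'
# Expected response shape:
#   {"summary": "...", "topWords": [{"word": "...", "score": 0.0}, ...]}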