Text Summarization: Extractive & Abstractive¶
BART, T5, and PEGASUS for document summarization. Evaluation with ROUGE scores.
# Install dependencies
# !pip install transformers torch rouge-score nltk numpy
Abstractive Summarization with Transformers¶
# Abstractive summarization with BART (requires transformers)
'''
from transformers import BartForConditionalGeneration, BartTokenizer
# Load BART model
model_name = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
# Sample article
article = """
The stock market experienced significant volatility today as investors reacted
to new inflation data. The Dow Jones Industrial Average fell 1.5%, while the
S&P 500 dropped 2.1%. Technology stocks were particularly hard hit, with the
Nasdaq Composite declining 3.2%. Energy stocks bucked the trend, rising 2.3%
on higher oil prices. Analysts predict continued uncertainty in the coming weeks
as the Federal Reserve considers interest rate adjustments.
"""
# Generate summary
inputs = tokenizer([article], max_length=1024, return_tensors="pt", truncation=True)
summary_ids = model.generate(
inputs["input_ids"],
max_length=150,
min_length=40,
length_penalty=2.0,
num_beams=4,
early_stopping=True
)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(f"Original ({len(article.split())} words):\n{article}\n")
print(f"Summary ({len(summary.split())} words):\n{summary}")
# Alternative: T5 for summarization
from transformers import T5ForConditionalGeneration, T5Tokenizer
model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# T5 requires "summarize: " prefix
input_text = "summarize: " + article
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
summary_ids = model.generate(inputs["input_ids"], max_length=150, min_length=40)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(f"T5 Summary: {summary}")
'''
print("Popular Summarization Models:\n")
print("Abstractive:")
print("  • facebook/bart-large-cnn (CNN/DailyMail news)")
print("  • t5-base, t5-large (general purpose)")
print("  • google/pegasus-xsum (extreme summarization)")
print("  • google/pegasus-cnn_dailymail (news articles)")
print("\nExtractive:")
print("  • TextRank (graph-based, unsupervised)")
print("  • BertSum (BERT-based sentence selection)")
print("  • LexRank (PageRank for sentences)")
Extractive Summarization with TextRank¶
from typing import List, Dict
import re
import numpy as np
class TextRankSummarizer:
"""Extractive summarization using TextRank algorithm"""
def __init__(self, damping: float = 0.85, iterations: int = 100):
self.damping = damping
self.iterations = iterations
def _split_sentences(self, text: str) -> List[str]:
"""Split text into sentences"""
# Simple sentence splitter
sentences = re.split(r'[.!?]+', text)
return [s.strip() for s in sentences if s.strip()]
def _sentence_similarity(self, sent1: str, sent2: str) -> float:
"""Calculate similarity between two sentences"""
# Simple word overlap similarity
words1 = set(sent1.lower().split())
words2 = set(sent2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1 & words2
union = words1 | words2
return len(intersection) / len(union)
def _build_similarity_matrix(self, sentences: List[str]) -> np.ndarray:
"""Build sentence similarity matrix"""
n = len(sentences)
matrix = np.zeros((n, n))
for i in range(n):
for j in range(n):
if i != j:
matrix[i][j] = self._sentence_similarity(sentences[i], sentences[j])
return matrix
def _page_rank(self, matrix: np.ndarray) -> np.ndarray:
"""Apply PageRank algorithm"""
n = matrix.shape[0]
# Normalize matrix
row_sums = matrix.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1 # Avoid division by zero
matrix_norm = matrix / row_sums
# Initialize scores
scores = np.ones(n) / n
# Iterate
for _ in range(self.iterations):
scores = (1 - self.damping) / n + self.damping * matrix_norm.T @ scores
return scores
def summarize(self, text: str, num_sentences: int = 3) -> str:
"""Generate extractive summary"""
# Split into sentences
sentences = self._split_sentences(text)
if len(sentences) <= num_sentences:
return text
# Build similarity matrix
similarity_matrix = self._build_similarity_matrix(sentences)
# Apply PageRank
scores = self._page_rank(similarity_matrix)
# Get top sentences
top_indices = np.argsort(scores)[-num_sentences:][::-1]
top_indices = sorted(top_indices) # Maintain original order
summary_sentences = [sentences[i] for i in top_indices]
return '. '.join(summary_sentences) + '.'
# Test TextRank
article = """
The stock market experienced significant volatility today as investors reacted to new inflation data.
The Dow Jones Industrial Average fell 1.5%, while the S&P 500 dropped 2.1%.
Technology stocks were particularly hard hit, with the Nasdaq Composite declining 3.2%.
Energy stocks bucked the trend, rising 2.3% on higher oil prices.
Analysts predict continued uncertainty in the coming weeks as the Federal Reserve considers interest rate adjustments.
Many investors are moving to safer assets like bonds and gold.
The volatility is expected to continue until there is more clarity on monetary policy.
"""
summarizer = TextRankSummarizer()
summary = summarizer.summarize(article, num_sentences=3)
print(f"Original ({len(article.split())} words):\n{article}")
print(f"\nSummary ({len(summary.split())} words):\n{summary}")
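The Jaccard word overlap used in `_sentence_similarity` ignores how often a term occurs. A common drop-in alternative is cosine similarity over raw term-frequency vectors, which rewards repeated key terms; a minimal sketch (the function name `cosine_tf_similarity` is illustrative, not part of the class above):

```python
import math
from collections import Counter

def cosine_tf_similarity(sent1: str, sent2: str) -> float:
    """Cosine similarity over term-frequency vectors.

    A drop-in alternative to Jaccard overlap that accounts for
    how often each word occurs in a sentence.
    """
    tf1 = Counter(sent1.lower().split())
    tf2 = Counter(sent2.lower().split())
    dot = sum(tf1[w] * tf2[w] for w in tf1)
    norm1 = math.sqrt(sum(c * c for c in tf1.values()))
    norm2 = math.sqrt(sum(c * c for c in tf2.values()))
    return dot / (norm1 * norm2) if norm1 and norm2 else 0.0
```

To try it, assign it in place of the Jaccard method (e.g. `TextRankSummarizer._sentence_similarity = staticmethod(cosine_tf_similarity)` before instantiating) and compare the selected sentences.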
ROUGE Evaluation Metrics¶
from collections import Counter
from typing import Dict
class ROUGEMetrics:
"""ROUGE metrics for summarization evaluation"""
@staticmethod
def _get_ngrams(text: str, n: int) -> Counter:
"""Get n-grams from text"""
words = text.lower().split()
ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
return Counter(ngrams)
@staticmethod
def rouge_n(reference: str, hypothesis: str, n: int = 1) -> Dict[str, float]:
"""ROUGE-N score (unigram or bigram overlap)"""
ref_ngrams = ROUGEMetrics._get_ngrams(reference, n)
hyp_ngrams = ROUGEMetrics._get_ngrams(hypothesis, n)
if not ref_ngrams or not hyp_ngrams:
return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
# Count overlapping n-grams
overlap = sum((ref_ngrams & hyp_ngrams).values())
precision = overlap / sum(hyp_ngrams.values())
recall = overlap / sum(ref_ngrams.values())
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
return {"precision": precision, "recall": recall, "f1": f1}
@staticmethod
def rouge_l(reference: str, hypothesis: str) -> Dict[str, float]:
"""ROUGE-L score (longest common subsequence)"""
ref_words = reference.lower().split()
hyp_words = hypothesis.lower().split()
# Compute LCS length
m, n = len(ref_words), len(hyp_words)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if ref_words[i-1] == hyp_words[j-1]:
dp[i][j] = dp[i-1][j-1] + 1
else:
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
lcs_length = dp[m][n]
precision = lcs_length / n if n > 0 else 0
recall = lcs_length / m if m > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
return {"precision": precision, "recall": recall, "f1": f1}
@staticmethod
def rouge_scores(reference: str, hypothesis: str) -> Dict[str, Dict[str, float]]:
"""Compute all ROUGE scores"""
return {
"rouge-1": ROUGEMetrics.rouge_n(reference, hypothesis, 1),
"rouge-2": ROUGEMetrics.rouge_n(reference, hypothesis, 2),
"rouge-l": ROUGEMetrics.rouge_l(reference, hypothesis)
}
# Test ROUGE metrics
reference = "The stock market fell today. Technology stocks declined sharply. Analysts predict uncertainty."
hypothesis1 = "Stock market fell. Tech stocks dropped. Uncertainty predicted." # Good summary
hypothesis2 = "The weather was nice today. People went outside." # Bad summary
print("ROUGE Evaluation:\n")
print(f"Reference: {reference}\n")
for i, hyp in enumerate([hypothesis1, hypothesis2], 1):
scores = ROUGEMetrics.rouge_scores(reference, hyp)
print(f"Hypothesis {i}: {hyp}")
for metric, values in scores.items():
print(f" {metric.upper()}: P={values['precision']:.3f}, R={values['recall']:.3f}, F1={values['f1']:.3f}")
print()
Production Summarization System¶
from dataclasses import dataclass
from typing import Optional
import hashlib
from collections import defaultdict
@dataclass
class Summary:
"""Summary result"""
original_text: str
summary_text: str
method: str # 'extractive' or 'abstractive'
num_sentences: int
compression_ratio: float
quality_score: Optional[float] = None
class ProductionSummarizer:
"""Production-ready summarization system"""
def __init__(self, method: str = "extractive"):
self.method = method
self.extractive_summarizer = TextRankSummarizer()
self.cache = {}
self.stats = {
"total_summarizations": 0,
"cache_hits": 0,
"avg_compression_ratio": 0.0,
"compression_ratios": [],
"summaries_by_method": defaultdict(int)
}
def _get_cache_key(self, text: str, num_sentences: int) -> str:
"""Generate cache key"""
key_str = f"{text}_{num_sentences}_{self.method}"
return hashlib.md5(key_str.encode()).hexdigest()
def _calculate_quality_score(self, original: str, summary: str) -> float:
"""Estimate summary quality (simplified)"""
# Use word overlap as a simple quality metric
orig_words = set(original.lower().split())
summ_words = set(summary.lower().split())
if not summ_words:
return 0.0
overlap = len(orig_words & summ_words) / len(summ_words)
return overlap
def summarize(
self,
text: str,
num_sentences: int = 3,
use_cache: bool = True
) -> Summary:
"""Generate summary with caching and monitoring"""
# Check cache
cache_key = self._get_cache_key(text, num_sentences)
if use_cache and cache_key in self.cache:
self.stats["cache_hits"] += 1
self.stats["total_summarizations"] += 1
return self.cache[cache_key]
# Generate summary
if self.method == "extractive":
summary_text = self.extractive_summarizer.summarize(text, num_sentences)
else:
# Fallback to extractive if abstractive not available
summary_text = self.extractive_summarizer.summarize(text, num_sentences)
# Calculate metrics
orig_words = len(text.split())
summ_words = len(summary_text.split())
compression_ratio = summ_words / orig_words if orig_words > 0 else 0
quality_score = self._calculate_quality_score(text, summary_text)
# Create summary object
summary = Summary(
original_text=text,
summary_text=summary_text,
method=self.method,
num_sentences=num_sentences,
compression_ratio=compression_ratio,
quality_score=quality_score
)
# Cache result
if use_cache:
self.cache[cache_key] = summary
# Update stats
self.stats["total_summarizations"] += 1
self.stats["compression_ratios"].append(compression_ratio)
self.stats["avg_compression_ratio"] = sum(self.stats["compression_ratios"]) / len(self.stats["compression_ratios"])
self.stats["summaries_by_method"][self.method] += 1
return summary
def get_stats(self) -> Dict:
"""Get summarization statistics"""
total = max(self.stats["total_summarizations"], 1)
return {
"total_summarizations": self.stats["total_summarizations"],
"cache_hits": self.stats["cache_hits"],
"cache_hit_rate": self.stats["cache_hits"] / total,
"avg_compression_ratio": self.stats["avg_compression_ratio"],
"summaries_by_method": dict(self.stats["summaries_by_method"])
}
# Test production summarizer
prod_summarizer = ProductionSummarizer(method="extractive")
article = """
The stock market experienced significant volatility today as investors reacted to new inflation data.
The Dow Jones Industrial Average fell 1.5%, while the S&P 500 dropped 2.1%.
Technology stocks were particularly hard hit, with the Nasdaq Composite declining 3.2%.
Energy stocks bucked the trend, rising 2.3% on higher oil prices.
Analysts predict continued uncertainty in the coming weeks.
"""
# Summarize
summary = prod_summarizer.summarize(article, num_sentences=2)
print(f"Original ({len(article.split())} words):\n{article}")
print(f"\nSummary ({len(summary.summary_text.split())} words):\n{summary.summary_text}")
print(f"\nMetrics:")
print(f" Compression ratio: {summary.compression_ratio:.2%}")
print(f" Quality score: {summary.quality_score:.3f}")
print(f" Method: {summary.method}")
# Test cache
summary2 = prod_summarizer.summarize(article, num_sentences=2) # Cache hit
stats = prod_summarizer.get_stats()
print(f"\nStatistics:")
print(f" Total summarizations: {stats['total_summarizations']}")
print(f" Cache hits: {stats['cache_hits']}")
print(f" Avg compression: {stats['avg_compression_ratio']:.2%}")
Best PracticesΒΆ
1. Model SelectionΒΆ
News: BART-CNN, PEGASUS-CNN/DailyMail
Scientific papers: SciBERT-based models
Extreme (single-sentence) summaries: PEGASUS-XSum, BART-XSum
Multi-document: Use hierarchical models
2. Handling Long Documents¶
Chunking: Split into smaller segments (512-1024 tokens)
Sliding window: Overlap chunks for context
Hierarchical: Summarize chunks, then summarize summaries
Longformer/BigBird: Models with extended context (4096+ tokens)
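The chunking and hierarchical strategies above can be sketched in a few lines. This is a minimal illustration under assumed defaults (100-word chunks, 20-word overlap); `summarize_fn` is a placeholder for any summarizer, e.g. a wrapper around the `TextRankSummarizer` defined earlier:

```python
from typing import Callable, List

def chunk_words(text: str, chunk_size: int = 100, overlap: int = 20) -> List[str]:
    """Split text into word-level chunks with a sliding-window overlap
    so each chunk retains some context from its predecessor."""
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break
    return chunks

def hierarchical_summarize(text: str, summarize_fn: Callable[[str], str],
                           chunk_size: int = 100) -> str:
    """Summarize each chunk, then summarize the concatenated summaries."""
    chunks = chunk_words(text, chunk_size=chunk_size)
    partials = [summarize_fn(c) for c in chunks]
    return summarize_fn(" ".join(partials))
```

For a transformer backend, `chunk_size` should instead be measured in tokenizer tokens (e.g. 512-1024), not whitespace words.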
3. Training TipsΒΆ
Use CNN/DailyMail or XSum datasets
Apply label smoothing (0.1)
Use beam search for generation (beam size 4)
Monitor ROUGE scores during training
Fine-tune on domain-specific data
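Label smoothing, mentioned above, replaces each one-hot token target with a softened distribution: the true token gets probability 1 − ε and the remaining ε is spread over the rest of the vocabulary. A NumPy sketch (function name and shapes are illustrative; training frameworks like Hugging Face `Trainer` expose this as a `label_smoothing_factor` instead):

```python
import numpy as np

def smooth_labels(target_ids: np.ndarray, vocab_size: int,
                  epsilon: float = 0.1) -> np.ndarray:
    """Convert hard token targets into label-smoothed distributions.

    Each true token gets probability (1 - epsilon); the remaining
    epsilon mass is spread uniformly over the other vocab entries.
    """
    n = len(target_ids)
    dist = np.full((n, vocab_size), epsilon / (vocab_size - 1))
    dist[np.arange(n), target_ids] = 1.0 - epsilon
    return dist
```

Each row sums to 1, so it can be used directly as the target in a cross-entropy loss.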
4. Production OptimizationΒΆ
Cache frequent summaries
Use length penalties to control output
Implement quality checks (min/max length, coherence)
Monitor compression ratios
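The length and compression checks above can be a single gate in front of the summarizer's output. A minimal sketch with assumed thresholds (10-150 words, at most 80% of the original length; tune these for your domain):

```python
def passes_quality_checks(original: str, summary: str,
                          min_words: int = 10, max_words: int = 150,
                          max_compression: float = 0.8) -> bool:
    """Reject summaries that are too short, too long,
    or that barely compress the original text."""
    n_orig = len(original.split())
    n_summ = len(summary.split())
    if n_summ < min_words or n_summ > max_words:
        return False
    if n_orig > 0 and n_summ / n_orig > max_compression:
        return False
    return True
```

In a production path, a failed check would trigger a retry with different generation parameters or a fallback to the extractive method.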
Common ChallengesΒΆ
Factual consistency: Models may hallucinate facts
Redundancy: Repeated information in long documents
Coherence: Generated text may lack flow
Length control: Balancing detail and brevity
Multi-document: Handling conflicting information
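For the redundancy challenge, a simple mitigation is to greedily drop candidate sentences that overlap too heavily with sentences already selected. A sketch using the same Jaccard overlap as the TextRank section (the 0.6 threshold is an assumption to tune):

```python
from typing import List

def deduplicate_sentences(sentences: List[str],
                          threshold: float = 0.6) -> List[str]:
    """Greedily keep sentences whose Jaccard word overlap with every
    already-kept sentence stays below the threshold."""
    kept = []
    for sent in sentences:
        words = set(sent.lower().split())
        redundant = False
        for prev in kept:
            prev_words = set(prev.lower().split())
            union = words | prev_words
            if union and len(words & prev_words) / len(union) >= threshold:
                redundant = True
                break
        if not redundant:
            kept.append(sent)
    return kept
```

Applied before joining the top-ranked sentences, this prevents near-duplicate sentences from crowding out distinct information.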
Evaluation MetricsΒΆ
ROUGE-1: Unigram overlap (content coverage)
ROUGE-2: Bigram overlap (fluency)
ROUGE-L: Longest common subsequence (structure)
BERTScore: Semantic similarity using embeddings
Human evaluation: Gold standard but expensive
Key TakeawaysΒΆ
✅ Extractive is simpler, abstractive is more flexible
✅ BART and PEGASUS are state-of-the-art for abstractive
✅ TextRank provides good extractive baselines
✅ ROUGE is the standard metric but has limitations
✅ Factual consistency is a major challenge
✅ Chunk long documents for better results
Next: 04_sentiment_analysis.ipynb - Sentiment Analysis