arXiv Paper Classifier (distributed demo)¶
A full ML workflow that exercises every differentiated feature of Strata Notebook at once:
- Distributed workers. Nine cells, three different workers: local, df-cluster (CPU-heavy DataFusion), and gpu-fly (GPU for embeddings). Each cell declares its target worker with a single annotation; no deployment code.
- Content-addressed caching. Re-run the notebook after an edit; every unchanged cell hits cache instantly, regardless of which worker it runs on.
- Prompt cells. Two cells use the LLM assistant with {{ variable }} injection to narrate the data at key points.
- DAG invalidation. Change the model in the training cell, and every upstream cell (load, aggregate, embed) still hits cache; only training and evaluation re-execute.
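The caching and invalidation behavior above can be pictured as a content-addressed key: hash a cell's source together with the keys of its inputs, so an edit changes the keys of exactly that cell and its downstream dependents. The sketch below is illustrative only (`cell_cache_key` is a hypothetical name, not Strata's actual API):

```python
import hashlib

def cell_cache_key(source: str, input_keys: list[str]) -> str:
    """Content-addressed key: hash of the cell source plus the keys of
    its upstream inputs. Illustrative sketch; Strata's real key
    derivation may include more (worker identity, library versions)."""
    h = hashlib.sha256()
    h.update(source.encode())
    for key in input_keys:  # keys of upstream cells, in DAG order
        h.update(key.encode())
    return h.hexdigest()

# Editing the training cell changes its key (and evaluation's, which
# depends on it), while upstream cells keep theirs -- so load,
# aggregate, and embed still hit cache.
load_key = cell_cache_key("papers = load()", [])
train_v1 = cell_cache_key("model = fit(X)", [load_key])
train_v2 = cell_cache_key("model = fit(X, C=0.5)", [load_key])
assert train_v1 != train_v2
assert load_key == cell_cache_key("papers = load()", [])
```

Because each key folds in its inputs' keys, invalidation propagates down the DAG automatically: nothing upstream of the edited cell ever recomputes.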
The workflow¶
| # | Cell | Worker | What it does |
|---|---|---|---|
| 1 | load | local | Load arXiv metadata |
| 2 | aggregate | df-cluster | Group by category and year via DataFusion |
| 3 | themes | local (LLM) | Prompt: identify research themes from stats |
| 4 | sample | df-cluster | Stratified sample for the training set |
| 5 | embed | gpu-fly | Generate sentence-transformer embeddings |
| 6 | clusters | local (LLM) | Prompt: describe each paper cluster |
| 7 | train | local | Train a classifier on the embeddings |
| 8 | evaluate | local | Compute accuracy and classification report |
| 9 | plot | local | Visualize the confusion matrix |
Cells 1, 3, 6, 7, 8, and 9 run on the user's machine. Cells 2 and 4 run on the DataFusion worker, close to the data, which avoids round-tripping the full dataset. Cell 5 runs on the GPU worker.
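Each cell's target comes from its `# @worker` annotation (visible in the cells below). A minimal sketch of how such an annotation could be read off a cell's source before dispatch; the parser here is hypothetical, only the annotation syntax comes from this demo:

```python
import re

def parse_worker(cell_source: str, default: str = "local") -> str:
    """Return the worker named by a '# @worker NAME' annotation,
    falling back to the default when no annotation is present.
    Illustrative only; not Strata's actual parser."""
    for line in cell_source.splitlines():
        m = re.match(r"#\s*@worker\s+(\S+)", line.strip())
        if m:
            return m.group(1)
    return default

cell = "# @name Embed Abstracts\n# @worker gpu-fly\nimport numpy as np\n"
assert parse_worker(cell) == "gpu-fly"
assert parse_worker("import pandas as pd\n") == "local"
```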
Running it locally¶
Two local HTTP executors stand in for the cloud workers during development.
# Terminal 1 — start the two local worker processes
./examples/arxiv_classifier/run_local_workers.sh
# Terminal 2 — start Strata
uv run strata-server
# Terminal 3 — open the notebook in the UI
# http://localhost:8765
# Click "Open Notebook" and select examples/arxiv_classifier
Once the notebook is open, run all cells. You should see:
- The worker: df-cluster badge next to cells 2 and 4
- The worker: gpu-fly badge next to cell 5
- A live dispatching → df-cluster or dispatching → gpu-fly badge (pulsing yellow) while each remote cell is executing
- A green ✓ status on every cell after the first run
- A green ✓ cached status on every cell on the second run
Day 1 vs. real workload¶
As of Day 1 (April 2026), the cells run placeholder workloads — tiny DataFrames, fake embeddings, a trivial classifier — so we can validate the distributed plumbing without downloading a 500 MB dataset or setting up GPU deps. The workload is swapped in over the following days as the cloud workers come online. The worker annotations and cell structure stay the same.
Load arXiv Papers¶
kind python · worker local
# @name Load arXiv Papers
# @worker local
# Load arXiv ML papers from Hugging Face and assign topic categories.
# 118K real papers with titles and abstracts. We sample 20K and assign
# topics via keyword matching — a common first step when you have text
# data but no labels.
import re
import pandas as pd
DATASET_URL = (
    "https://huggingface.co/api/datasets/CShorten/ML-ArXiv-Papers"
    "/parquet/default/train/0.parquet"
)
SAMPLE_SIZE = 20_000
raw = pd.read_parquet(DATASET_URL, columns=["title", "abstract"])
papers = raw.dropna(subset=["abstract"]).head(SAMPLE_SIZE).reset_index(drop=True)
TOPIC_RULES = [
    ("reinforcement-learning", r"reinforcement|reward|policy gradient|Q-learning|MDP"),
    ("nlp", r"\bNLP\b|language model|translation|transformer|text classif|sentiment"),
    ("computer-vision", r"image|object detection|segmentation|convolutional|visual"),
    ("optimization", r"convex|gradient descent|convergence|optimization|stochastic"),
    ("generative", r"generative|GAN|diffusion|variational|autoencoder|VAE"),
]
def _assign_topic(text: str) -> str:
    for topic, pattern in TOPIC_RULES:
        if re.search(pattern, text, re.IGNORECASE):
            return topic
    return "other"
papers["topic"] = papers["abstract"].apply(_assign_topic)
print(f"Loaded {len(papers):,} papers, {papers['topic'].nunique()} topics")
print(papers["topic"].value_counts().to_string())
papers
Aggregate by Topic¶
kind python · worker df-cluster
# @name Aggregate by Topic
# @worker df-cluster
# Aggregate paper counts per topic using DataFusion SQL.
# Dispatched to the df-cluster worker which has DataFusion installed.
import pyarrow as pa
from datafusion import SessionContext
ctx = SessionContext()
table = pa.Table.from_pandas(papers)
ctx.register_record_batches("papers", [table.to_batches()])
category_stats = ctx.sql(
    """
    SELECT
        topic,
        COUNT(*) AS paper_count,
        ROUND(100.0 * COUNT(*) / SUM(COUNT(*)) OVER (), 1) AS pct
    FROM papers
    GROUP BY topic
    ORDER BY paper_count DESC
    """
).to_pandas()
print("Topic distribution (DataFusion SQL):")
print(category_stats.to_string(index=False))
category_stats
research_themes¶
kind prompt · worker local
Prompt cell — response intentionally excluded from export.
# @name research_themes
# @worker local
# Prompt cell: LLM analysis of topic distribution from the DataFusion aggregation.
Given these arXiv ML paper topic counts derived from keyword classification:
{{ category_stats }}
For each topic, write one sentence describing what kinds of papers fall into it
and why this topic matters for the ML research community. Return as a numbered list.
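The {{ category_stats }} placeholder above is filled with the output of the upstream aggregation cell before the prompt is sent. That injection can be pictured as simple template substitution; the function below is a hypothetical sketch, not Strata's implementation (which may format DataFrames more richly):

```python
import re

def render_prompt(template: str, scope: dict) -> str:
    """Replace each {{ name }} placeholder with the string form of the
    matching upstream value. Illustrative only."""
    def substitute(match: re.Match) -> str:
        return str(scope[match.group(1)])
    return re.sub(r"\{\{\s*(\w+)\s*\}\}", substitute, template)

prompt = "Given these topic counts:\n{{ category_stats }}\nSummarize them."
print(render_prompt(prompt, {"category_stats": "nlp 8,000\nother 5,200"}))
```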
Stratified Sample¶
kind python · worker df-cluster
# @name Stratified Sample
# @worker df-cluster
# Stratified sample: take up to 500 papers per topic so the embedding
# step runs in seconds, not minutes. Uses DataFusion window functions.
import pyarrow as pa
from datafusion import SessionContext
ctx = SessionContext()
table = pa.Table.from_pandas(papers)
ctx.register_record_batches("papers", [table.to_batches()])
PER_TOPIC = 500
sampled_papers = ctx.sql(
    f"""
    WITH ranked AS (
        SELECT
            title,
            abstract,
            topic,
            ROW_NUMBER() OVER (PARTITION BY topic ORDER BY title) AS rn
        FROM papers
    )
    SELECT title, abstract, topic
    FROM ranked
    WHERE rn <= {PER_TOPIC}
    ORDER BY topic, title
    """
).to_pandas()
print(
    f"Sampled {len(sampled_papers):,} papers across "
    f"{sampled_papers['topic'].nunique()} topics (DataFusion SQL)"
)
print(sampled_papers["topic"].value_counts().to_string())
sampled_papers
Embed Abstracts¶
kind python · worker gpu-fly
# @name Embed Abstracts
# @worker gpu-fly
# @timeout 300
# Generate sentence-transformer embeddings for each paper's abstract.
# This is the expensive step: ~3K abstracts × 384-dim on an A10G GPU
# takes ~5 seconds. On CPU it would take ~90 seconds. On re-run it's
# instant — the artifact store caches the result keyed by the exact
# input data + model identity.
import numpy as np
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-MiniLM-L6-v2")
texts = sampled_papers["abstract"].tolist()
embeddings = model.encode(texts, show_progress_bar=True, batch_size=256)
embeddings = np.array(embeddings, dtype="float32")
print(f"Generated embeddings: {embeddings.shape} ({embeddings.nbytes / 1e6:.1f} MB)")
embeddings
cluster_descriptions¶
kind prompt · worker local
Prompt cell — response intentionally excluded from export.
# @name cluster_descriptions
# @worker local
# Prompt cell: generate a practical summary of the research themes.
The following research themes were identified from analyzing {{ category_stats }}:
{{ research_themes }}
Write a short paragraph (4-5 sentences) for a technical audience explaining
the practical significance of these themes. Focus on how they relate to each
other and what trends they suggest for the ML research community.
Train Classifier¶
kind python · worker local
# @name Train Classifier
# @worker local
# Train a logistic regression classifier: embeddings → topic label.
# Runs locally — logistic regression on 3K × 384 takes <1s on CPU,
# no reason to ship the data to a GPU worker and back.
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
X = embeddings
y = sampled_papers["topic"].to_numpy()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
classifier = LogisticRegression(max_iter=1000, n_jobs=-1)
classifier.fit(X_train, y_train)
train_acc = classifier.score(X_train, y_train)
test_acc = classifier.score(X_test, y_test)
print(f"Train accuracy: {train_acc:.3f}")
print(f"Test accuracy: {test_acc:.3f}")
print(f"Classes: {list(classifier.classes_)}")
train_test_split_info = {
    "train_size": len(X_train),
    "test_size": len(X_test),
    "train_acc": round(train_acc, 4),
    "test_acc": round(test_acc, 4),
}
train_test_split_info
Evaluate Model¶
kind python · worker local
# @name Evaluate Model
# @worker local
# Evaluation on the held-out test set. Small data, runs locally.
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
X = embeddings
y = sampled_papers["topic"].to_numpy()
_, X_test, _, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
predictions = classifier.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
report = classification_report(y_test, predictions, output_dict=True, zero_division=0)
print(f"Test accuracy: {accuracy:.3f}")
print()
print(classification_report(y_test, predictions, zero_division=0))
eval_results = {"accuracy": round(accuracy, 4), "per_class": report}
eval_results
Confusion Matrix¶
kind python · worker local
# @name Confusion Matrix
# @worker local
# Visualize the confusion matrix on the held-out test set.
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.model_selection import train_test_split
X = embeddings
y = sampled_papers["topic"].to_numpy()
_, X_test, _, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
classes = sorted(set(y))
predictions = classifier.predict(X_test)
cm = confusion_matrix(y_test, predictions, labels=classes)
fig, ax = plt.subplots(figsize=(8, 6))
ConfusionMatrixDisplay(cm, display_labels=classes).plot(
    ax=ax, cmap="Blues", colorbar=False, xticks_rotation=30
)
ax.set_title("arXiv Topic Classification — Confusion Matrix")
plt.tight_layout()
fig