Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
dist/
build/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Checkpoints
checkpoints*/
*.pkl

# OS
.DS_Store
Thumbs.db
65 changes: 60 additions & 5 deletions exploration/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,56 @@
}

def split_into_sentences(text):
    """Split text into sentences and report where each one sits in ``text``.

    A sentence boundary is a run of whitespace that directly follows ``.``,
    ``!`` or ``?``. Segments whose stripped form is 3 characters or shorter
    are dropped from both returned lists.

    Args:
        text: The string to split.

    Returns:
        tuple: ``(sentences, positions)`` where

        - ``sentences`` is a list of stripped sentence strings.
        - ``positions`` is a parallel list of ``(start, end)`` character
          offsets into ``text``. The offsets bound the *unstripped* segment
          (the separating whitespace is not included), so ``text[:end]``
          reproduces the original formatting up to and including that
          sentence, unlike re-joining the stripped sentences with spaces.
    """
    sentences = []
    positions = []

    def _keep(start, end):
        # Record text[start:end] unless it is too short to count as a sentence.
        stripped = text[start:end].strip()
        if len(stripped) > 3:
            sentences.append(stripped)
            positions.append((start, end))

    segment_start = 0
    # Walk the separators instead of re.split()-ing the text: every match
    # already carries exact offsets, so no running position counter has to be
    # kept in sync with the pieces.
    for separator in re.finditer(r'(?<=[.!?])\s+', text):
        _keep(segment_start, separator.start())
        segment_start = separator.end()

    # Whatever follows the last separator (or the whole text if there is none).
    _keep(segment_start, len(text))

    return sentences, positions

def get_sentence_token_positions(text, sentences, tokenizer):
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
Expand Down Expand Up @@ -180,7 +228,7 @@ def get_hidden_state(text, layer=-1):
)
full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
cot = full_text[len(problem):].strip() if full_text.startswith(problem) else full_text
sentences = split_into_sentences(cot)
sentences, sentence_positions = split_into_sentences(cot)

if len(sentences) < 2:
print(f" ✗ Only {len(sentences)} sentences, skipping")
Expand All @@ -193,6 +241,7 @@ def get_hidden_state(text, layer=-1):
'problem': problem,
'cot': cot,
'sentences': sentences,
'sentence_positions': sentence_positions,
'causal_matrix': causal_matrix
}
del output_ids, input_ids
Expand Down Expand Up @@ -328,12 +377,18 @@ def classify_sentence(sentence):
for pid, anchors in tqdm(all_anchors.items(), desc="Extracting features"):
data = all_data[pid]
problem = data['problem']
sentences = data['sentences']
cot = data['cot']
sentence_positions = data['sentence_positions']
causal_matrix = data['causal_matrix']

for anchor in anchors:
idx = anchor['idx']
text_before = problem + " " + " ".join(sentences[:idx])
# Use character positions to slice original text instead of reconstructing
if idx > 0:
end_pos = sentence_positions[idx - 1][1]
text_before = problem + " " + cot[:end_pos]
else:
text_before = problem
hidden_state = get_hidden_state(text_before)
outgoing_feature = np.sum(np.abs(causal_matrix[idx, :]))
all_features.append({
Expand Down