diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ff76526
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,24 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+*.egg-info/
+dist/
+build/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# Checkpoints
+checkpoints*/
+*.pkl
+
+# OS
+.DS_Store
+Thumbs.db
diff --git a/exploration/rpc.py b/exploration/rpc.py
index 60b33c1..723d68f 100644
--- a/exploration/rpc.py
+++ b/exploration/rpc.py
@@ -57,8 +57,56 @@
 }
 
 def split_into_sentences(text):
-    sentences = re.split(r'(?<=[.!?])\s+', text)
-    return [s.strip() for s in sentences if s.strip() and len(s.strip()) > 3]
+    """Split text into sentences and return both sentences and their character positions.
+
+    Returns:
+        tuple: (sentences, positions) where:
+            - sentences: list of stripped sentence strings (for classification/display)
+            - positions: list of (start, end) tuples indicating character positions
+              in the original text (unstripped, for accurate slicing that
+              preserves formatting)
+
+    Note: Positions track the unstripped sentence boundaries in the original text,
+    allowing us to slice the original text with preserved formatting. This is
+    intentionally different from the stripped sentences which are used for
+    classification and don't need the extra whitespace.
+    """
+    sentences = []
+    positions = []
+
+    # Split on sentence boundaries, keeping track of positions
+    # Use re.split with capturing groups to preserve separators and calculate positions
+    parts = re.split(r'((?<=[.!?])\s+)', text)
+
+    current_pos = 0
+    accumulated_text = ""
+
+    for i, part in enumerate(parts):
+        # Even indices are sentence content, odd indices are separators
+        if i % 2 == 0:
+            accumulated_text = part
+        else:
+            # We have a complete sentence with its separator
+            stripped = accumulated_text.strip()
+            if stripped and len(stripped) > 3:
+                # Calculate actual positions in original text
+                start = current_pos
+                end = current_pos + len(accumulated_text)
+                sentences.append(stripped)
+                positions.append((start, end))
+            current_pos += len(accumulated_text) + len(part)
+
+    # Handle the last sentence (no separator after it)
+    if parts and len(parts) % 2 == 1:
+        accumulated_text = parts[-1]
+        stripped = accumulated_text.strip()
+        if stripped and len(stripped) > 3:
+            start = current_pos
+            end = current_pos + len(accumulated_text)
+            sentences.append(stripped)
+            positions.append((start, end))
+
+    return sentences, positions
 
 def get_sentence_token_positions(text, sentences, tokenizer):
     input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
@@ -180,7 +228,7 @@ def get_hidden_state(text, layer=-1):
     )
     full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    cot = full_text[len(problem):].strip() if full_text.startswith(problem) else full_text
-    sentences = split_into_sentences(cot)
+    sentences, sentence_positions = split_into_sentences(cot)
 
     if len(sentences) < 2:
         print(f" ✗ Only {len(sentences)} sentences, skipping")
@@ -193,6 +241,7 @@ def get_hidden_state(text, layer=-1):
         'problem': problem,
         'cot': cot,
         'sentences': sentences,
+        'sentence_positions': sentence_positions,
         'causal_matrix': causal_matrix
     }
     del output_ids, input_ids
@@ -328,12 +377,18 @@ def classify_sentence(sentence):
 for pid, anchors in tqdm(all_anchors.items(), desc="Extracting features"):
     data = all_data[pid]
     problem = data['problem']
-    sentences = data['sentences']
+    cot = data['cot']
+    sentence_positions = data['sentence_positions']
     causal_matrix = data['causal_matrix']
 
     for anchor in anchors:
         idx = anchor['idx']
-        text_before = problem + " " + " ".join(sentences[:idx])
+        # Use character positions to slice original text instead of reconstructing
+        if idx > 0:
+            end_pos = sentence_positions[idx - 1][1]
+            text_before = problem + " " + cot[:end_pos]
+        else:
+            text_before = problem
         hidden_state = get_hidden_state(text_before)
         outgoing_feature = np.sum(np.abs(causal_matrix[idx, :]))
         all_features.append({
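Review note (not part of the patch): below is a minimal sanity-check sketch of the contract the new split_into_sentences establishes and the feature-extraction hunk relies on. The stripped sentences feed classification, while each (start, end) pair slices the original cot text with its whitespace preserved. The import path, the toy problem/cot strings, and the idx value are assumptions for illustration only; if importing exploration/rpc.py loads the model at import time, paste the function body into a scratch file instead.

    # Hypothetical usage check; assumes split_into_sentences can be imported without side effects.
    from exploration.rpc import split_into_sentences

    problem = "What is 2 + 2?"  # toy stand-in for a real problem
    cot = "First we compute 2 + 2.  That gives 4!\nSo the answer is 4."

    sentences, positions = split_into_sentences(cot)

    # Stripped sentences are what gets classified/displayed ...
    assert sentences == ["First we compute 2 + 2.", "That gives 4!", "So the answer is 4."]

    # ... while the positions slice the original text, preserving the exact
    # whitespace (double space, newline) between sentences.
    for (start, end), sent in zip(positions, sentences):
        assert cot[start:end].strip() == sent

    # Rebuilding "everything before anchor idx" the way the feature loop now does:
    idx = 2  # hypothetical anchor index
    if idx > 0:
        text_before = problem + " " + cot[:positions[idx - 1][1]]
    else:
        text_before = problem
    print(text_before)  # -> problem + first two sentences, original formatting intact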