Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 44 additions & 54 deletions steer/models/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,43 @@ def _decoder_layers(self):
For text models this is exactly self.model.model.layers — identical to before."""
return self._base_model().layers

@staticmethod
def _copy_activation_value(value):
if isinstance(value, t.Tensor):
return value.clone().detach()
try:
return copy.deepcopy(value)
except Exception:
return value

def _save_and_clear_generation_state(self, model_layers):
saved_layer_state = {}

for i, layer in enumerate(model_layers):
layer_state = {}
if hasattr(layer, 'add_activations_dict') and layer.add_activations_dict:
layer_state['add_activations_dict'] = {
key: self._copy_activation_value(value)
for key, value in layer.add_activations_dict.items()
}
layer.add_activations_dict = {}

if hasattr(layer, 'intervention_dict') and layer.intervention_dict:
layer_state['intervention_dict'] = dict(layer.intervention_dict)
layer.intervention_dict = {}

if layer_state:
saved_layer_state[i] = layer_state

return saved_layer_state

def _restore_generation_state(self, model_layers, saved_layer_state):
for i, layer_state in saved_layer_state.items():
if 'add_activations_dict' in layer_state:
model_layers[i].add_activations_dict = layer_state['add_activations_dict']
if 'intervention_dict' in layer_state:
model_layers[i].intervention_dict = layer_state['intervention_dict']

def _lm_head(self):
"""The output projection. Override if a subclass nests it elsewhere."""
return self.model.lm_head
Expand Down Expand Up @@ -587,24 +624,8 @@ def get_logits(self, tokens):
return logits

def ori_generate(self, input_ids, **kwargs):
# Save activation dictionaries
saved_activations = {}
model_layers = self._decoder_layers()

for i, layer in enumerate(model_layers):
if hasattr(layer, 'add_activations_dict') and layer.add_activations_dict:
saved_dict = {}
for key, value in layer.add_activations_dict.items():
if isinstance(value, t.Tensor):
saved_dict[key] = value.clone().detach()
else:
try:
saved_dict[key] = copy.deepcopy(value)
except:
saved_dict[key] = value

saved_activations[i] = saved_dict
layer.add_activations_dict = {}
saved_layer_state = self._save_and_clear_generation_state(model_layers)

# Save steer value if exists
saved_steer_values = t.zeros(1)
Expand All @@ -619,9 +640,7 @@ def ori_generate(self, input_ids, **kwargs):
**kwargs
)
finally:
# Restore activation dictionaries
for i, activations_dict in saved_activations.items():
model_layers[i].add_activations_dict = activations_dict
self._restore_generation_state(model_layers, saved_layer_state)

# Restore steer value
if saved_steer_values is not None and hasattr(self, 'steer'):
Expand All @@ -630,24 +649,8 @@ def ori_generate(self, input_ids, **kwargs):
return output

def ori_vllm_generate(self, input_batch, vllm_sampling_params):
# Save activation dictionaries
saved_activations = {}
model_layers = self._decoder_layers()

for i, layer in enumerate(model_layers):
if hasattr(layer, 'add_activations_dict') and layer.add_activations_dict:
saved_dict = {}
for key, value in layer.add_activations_dict.items():
if isinstance(value, t.Tensor):
saved_dict[key] = value.clone().detach()
else:
try:
saved_dict[key] = copy.deepcopy(value)
except:
saved_dict[key] = value

saved_activations[i] = saved_dict
layer.add_activations_dict = {}
saved_layer_state = self._save_and_clear_generation_state(model_layers)

# Save steer value if exists
saved_steer_values = t.zeros(1)
Expand All @@ -663,9 +666,7 @@ def ori_vllm_generate(self, input_batch, vllm_sampling_params):
)

finally:
# Restore activation dictionaries
for i, activations_dict in saved_activations.items():
model_layers[i].add_activations_dict = activations_dict
self._restore_generation_state(model_layers, saved_layer_state)

# Restore steer value
if saved_steer_values is not None and hasattr(self, 'steer'):
Expand Down Expand Up @@ -877,17 +878,8 @@ def reset(self, method_name):
raise ValueError(f"Method {method_name} not supported to reset")

def ori_generate(self, input_ids, **kwargs):
# Save activation dictionaries
saved_activations = {}
if hasattr(self.model, 'transformer') and isinstance(self.model.transformer, Hack_no_grad):
model_layers = self.model.transformer.module.h
else:
model_layers = self.model.transformer.h

for i, layer in enumerate(model_layers):
if hasattr(layer, 'add_activations_dict') and layer.add_activations_dict:
saved_activations[i] = copy.deepcopy(layer.add_activations_dict)
layer.add_activations_dict = {}
model_layers = self._decoder_layers()
saved_layer_state = self._save_and_clear_generation_state(model_layers)

# Save steer value if exists
saved_steer_value = 0
Expand All @@ -902,9 +894,7 @@ def ori_generate(self, input_ids, **kwargs):
**kwargs
)
finally:
# Restore activation dictionaries
for i, activations_dict in saved_activations.items():
model_layers[i].add_activations_dict = activations_dict
self._restore_generation_state(model_layers, saved_layer_state)

# Restore steer value
if saved_steer_value is not None and hasattr(self, 'steer'):
Expand Down