Phase 2 · Model Modification
Priority: High · Status: Pending · Depends on: Phase 1 (event encoder)
Context
- Kronos `forward()` at `kronos.py:239` uses an additive temporal-embedding pattern
- Temporal embedding is added at line 257: `x = x + time_embedding`
- Event embedding follows the identical pattern: additive to the token embedding
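The additive pattern means every embedding contributes one term to the same residual stream, so a new embedding slots in without touching the others. A stripped-down illustration (shapes and names are illustrative, not the real Kronos internals):

```python
import torch

B, T, d_model = 2, 16, 256
token_emb = torch.randn(B, T, d_model)
time_emb  = torch.randn(B, T, d_model)
event_emb = torch.randn(B, T, d_model)   # the new term this phase adds

# All embeddings share d_model, so they combine by elementwise addition
x = token_emb + time_emb + event_emb     # (B, T, 256)
```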
Overview
Add an `EventEmbedding` module to the Kronos predictor. Minimal code change: a new embedding layer plus three lines in `forward()`. The BSQ tokenizer, HierarchicalEmbedding, DualHead, and temporal embeddings remain frozen.
Requirements
Functional
- Accept an optional `(B, T, 20)` event tensor in `forward()`, `decode_s1()`, `decode_s2()`
- Project event features to `d_model` dimensions via a linear layer
- Add the event embedding to the combined embedding (same pattern as the temporal embedding)
- Maintain backward compatibility: `events=None` produces output identical to the base model
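The backward-compatibility requirement falls out naturally if events are an optional keyword argument, so existing call sites never change. A minimal sketch with a stand-in module (not the real Kronos signature):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Stand-in illustrating the optional-events API; not the real Kronos class.
    def __init__(self, d_model=256):
        super().__init__()
        self.tok = nn.Linear(8, d_model)
        self.event_proj = nn.Linear(20, d_model)

    def forward(self, x, events=None):
        h = self.tok(x)
        if events is not None:          # events branch skipped entirely when None
            h = h + self.event_proj(events.float())
        return h

m = TinyModel()
x = torch.randn(2, 16, 8)
out_old = m(x)                                    # existing call sites keep working
out_new = m(x, events=torch.randn(2, 16, 20))     # new call sites pass events
```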
Non-Functional
- Zero overhead when `events=None` (the projection is skipped entirely)
- New params: ~5.4K (`Linear(20, d_model=256)`: 20 × 256 weights + 256 biases = 5,376): a negligible VRAM increase
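The parameter count can be verified directly from the layer itself:

```python
import torch.nn as nn

proj = nn.Linear(20, 256)
n_params = sum(p.numel() for p in proj.parameters())
# 20 * 256 weights + 256 biases = 5376 parameters (~21 KB in fp32)
```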
Architecture
New Module: EventEmbedding

```python
class EventEmbedding(nn.Module):
    def __init__(self, num_event_channels=20, d_model=256):
        super().__init__()
        self.proj = nn.Linear(num_event_channels, d_model)
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, events):
        # events: (B, T, 20) or None
        if events is None:
            return 0  # additive identity
        return self.proj(events.float())
```

Modified Kronos.forward()
```python
# Current (lines 254-258):
x = self.embedding([s1_ids, s2_ids])
if stamp is not None:
    time_embedding = self.time_emb(stamp)
    x = x + time_embedding
x = self.token_drop(x)

# Modified:
x = self.embedding([s1_ids, s2_ids])
if stamp is not None:
    time_embedding = self.time_emb(stamp)
    x = x + time_embedding
event_embedding = self.event_emb(events)  # NEW: returns 0 if events is None
x = x + event_embedding                   # NEW
x = self.token_drop(x)
```

Modified auto_regressive_inference()
- Accept an optional `events` parameter: `(B, T, 20)` for the context window plus `(B, pred_len, 20)` for the prediction window
- Concatenate the context and prediction event tensors, aligned with `full_stamp`
- Pass events through to `model.decode_s1()` and `model.decode_s2()`
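The concatenation step can be sketched as follows (tensor names are illustrative; the real code aligns against `full_stamp`):

```python
import torch

B, T_ctx, pred_len, C = 2, 64, 16, 20
ctx_events  = torch.randn(B, T_ctx, C)      # events for the context window
pred_events = torch.randn(B, pred_len, C)   # events for the prediction window

# Concatenate along the time axis so full_events lines up with full_stamp
full_events = torch.cat([ctx_events, pred_events], dim=1)  # (B, T_ctx + pred_len, 20)
```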
Implementation Steps
Section titled “Implementation Steps”- Add
EventEmbeddingclass tokronos-service/kronos_lib/model/module.py - Add
self.event_emb = EventEmbedding(20, d_model)toKronos.__init__() - Modify
Kronos.forward()— add event embedding injection - Modify
Kronos.decode_s1()— pass events through - Modify
Kronos.decode_s2()— unchanged (receives context from decode_s1) - Modify
auto_regressive_inference()— accept and route events - Modify
KronosPredictor.predict()— accept events DataFrame - Test: verify events=None produces identical output to base model (regression check)
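The final regression check (`events=None` must reproduce the base model exactly) can be sketched with toy modules standing in for the real Kronos components:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed = nn.Linear(8, 32)         # stand-in for the frozen token embedding
event_proj = nn.Linear(20, 32)   # stand-in for the new EventEmbedding projection

def forward_base(x):
    return embed(x)

def forward_with_events(x, events=None):
    h = embed(x)
    if events is not None:
        h = h + event_proj(events.float())
    return h

x = torch.randn(2, 16, 8)
# events=None must be bit-for-bit identical to the base model
assert torch.equal(forward_with_events(x, events=None), forward_base(x))
# a real event tensor must change the output
assert not torch.equal(forward_with_events(x, torch.randn(2, 16, 20)), forward_base(x))
```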
Key Files
- Modify: `kronos-service/kronos_lib/model/module.py` (add the `EventEmbedding` class)
- Modify: `kronos-service/kronos_lib/model/kronos.py` (modify the `Kronos` class + inference)
Success Criteria
- `model.forward(s1, s2, stamp, events=None)` produces output identical to unmodified Kronos
- `model.forward(s1, s2, stamp, events=tensor)` produces different output
- No VRAM increase when `events=None`
- All existing tests pass without modification (backward compatible)
- New unit test: verify that the event-embedding gradient flows during training
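The gradient-flow unit test reduces to checking that a backward pass deposits nonzero gradients on the new projection. A minimal sketch (using a bare `Linear(20, 256)` in place of the full model):

```python
import torch
import torch.nn as nn

proj = nn.Linear(20, 256)          # stand-in for the EventEmbedding projection
events = torch.randn(2, 16, 20)
loss = proj(events).sum()          # toy loss through the event path
loss.backward()

# Gradient must reach the new parameters
grad_ok = proj.weight.grad is not None and proj.weight.grad.abs().sum() > 0
```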