11from __future__ import annotations
22
33import abc
4- import datetime as dt
54import itertools
65import logging
76from collections .abc import Generator
3029from databento .common .data import SCHEMA_DTYPES_MAP
3130from databento .common .data import SCHEMA_STRUCT_MAP
3231from databento .common .error import BentoError
33- from databento .common .symbology import InstrumentIdMappingInterval
32+ from databento .common .symbology import InstrumentMap
3433from databento .common .validation import validate_file_write_path
3534from databento .common .validation import validate_maybe_enum
3635from databento .live import DBNRecord
@@ -98,7 +97,6 @@ def format_dataframe(
9897 schema : Schema ,
9998 pretty_px : bool ,
10099 pretty_ts : bool ,
101- instrument_id_index : dict [dt .date , dict [int , str ]],
102100) -> pd .DataFrame :
103101 struct = SCHEMA_STRUCT_MAP [schema ]
104102
@@ -122,13 +120,6 @@ def format_dataframe(
122120 index_column = "ts_event" if schema .value .startswith ("ohlcv" ) else "ts_recv"
123121 df .set_index (index_column , inplace = True )
124122
125- if instrument_id_index :
126- df_index = df .index if pretty_ts else pd .to_datetime (df .index , utc = True )
127- dates = [ts .date () for ts in df_index ]
128- df ["symbol" ] = [
129- instrument_id_index [dates [i ]][p ] for i , p in enumerate (df ["instrument_id" ])
130- ]
131-
132123 return df
133124
134125
@@ -252,7 +243,12 @@ class MemoryDataSource(DataSource):
252243 """
253244
254245 def __init__ (self , source : BytesIO | bytes | IO [bytes ]):
255- initial_data = source if isinstance (source , bytes ) else source .read ()
246+ if isinstance (source , bytes ):
247+ initial_data = source
248+ else :
249+ source .seek (0 )
250+ initial_data = source .read ()
251+
256252 if len (initial_data ) == 0 :
257253 raise ValueError (
258254 f"Cannot create data source from empty { type (source ).__name__ } " ,
@@ -397,11 +393,7 @@ def __init__(self, data_source: DataSource) -> None:
397393 metadata_bytes .getvalue (),
398394 )
399395
400- # This is populated when _map_symbols is called
401- self ._instrument_id_index : dict [
402- dt .date ,
403- dict [int , str ],
404- ] = {}
396+ self ._instrument_map = InstrumentMap ()
405397
406398 def __iter__ (self ) -> Generator [DBNRecord , None , None ]:
407399 reader = self .reader
@@ -417,6 +409,8 @@ def __iter__(self) -> Generator[DBNRecord, None, None]:
417409 for record in records :
418410 if isinstance (record , databento_dbn .Metadata ):
419411 continue
412+ if isinstance (record , databento_dbn .SymbolMappingMsg ):
413+ self ._instrument_map .insert_symbol_mapping_msg (record )
420414 yield record
421415 else :
422416 if len (decoder .buffer ()) > 0 :
@@ -429,38 +423,6 @@ def __repr__(self) -> str:
429423 name = self .__class__ .__name__
430424 return f"<{ name } (schema={ self .schema } )>"
431425
432- def _build_instrument_id_index (self ) -> dict [dt .date , dict [int , str ]]:
433- intervals : list [InstrumentIdMappingInterval ] = []
434- for raw_symbol , i in self .mappings .items ():
435- for row in i :
436- symbol = row ["symbol" ]
437- if symbol == "" :
438- continue
439- intervals .append (
440- InstrumentIdMappingInterval (
441- start_date = row ["start_date" ],
442- end_date = row ["end_date" ],
443- raw_symbol = raw_symbol ,
444- instrument_id = int (row ["symbol" ]),
445- ),
446- )
447-
448- instrument_id_index : dict [dt .date , dict [int , str ]] = {}
449- for interval in intervals :
450- for ts in pd .date_range (
451- start = interval .start_date ,
452- end = interval .end_date ,
453- # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.date_range.html
454- ** {"inclusive" if pd .__version__ >= "1.4.0" else "closed" : "left" },
455- ):
456- d : dt .date = ts .date ()
457- date_map : dict [int , str ] = instrument_id_index .get (d , {})
458- if not date_map :
459- instrument_id_index [d ] = date_map
460- date_map [interval .instrument_id ] = interval .raw_symbol
461-
462- return instrument_id_index
463-
464426 @property
465427 def compression (self ) -> Compression :
466428 """
@@ -808,13 +770,20 @@ def request_symbology(self, client: Historical) -> dict[str, Any]:
808770 date range.
809771
810772 """
773+ if self .end is None :
774+ end_date = None
775+ elif self .start .date () == self .end .date ():
776+ end_date = (self .start + pd .Timedelta (days = 1 )).date ()
777+ else :
778+ end_date = self .end
779+
811780 return client .symbology .resolve (
812781 dataset = self .dataset ,
813782 symbols = self .symbols ,
814783 stype_in = self .stype_in ,
815784 stype_out = self .stype_out ,
816785 start_date = self .start .date (),
817- end_date = self . end . date () if self . end else None ,
786+ end_date = end_date ,
818787 )
819788
820789 def to_csv (
@@ -877,7 +846,7 @@ def to_df(
877846 self ,
878847 pretty_px : bool = ...,
879848 pretty_ts : bool = ...,
880- map_symbols : bool | None = ...,
849+ map_symbols : bool = ...,
881850 schema : Schema | str | None = ...,
882851 count : None = ...,
883852 ) -> pd .DataFrame :
@@ -888,7 +857,7 @@ def to_df(
888857 self ,
889858 pretty_px : bool = ...,
890859 pretty_ts : bool = ...,
891- map_symbols : bool | None = ...,
860+ map_symbols : bool = ...,
892861 schema : Schema | str | None = ...,
893862 count : int = ...,
894863 ) -> DataFrameIterator :
@@ -898,7 +867,7 @@ def to_df(
898867 self ,
899868 pretty_px : bool = True ,
900869 pretty_ts : bool = True ,
901- map_symbols : bool | None = None ,
870+ map_symbols : bool = True ,
902871 schema : Schema | str | None = None ,
903872 count : int | None = None ,
904873 ) -> pd .DataFrame | DataFrameIterator :
@@ -945,29 +914,22 @@ def to_df(
945914 raise ValueError ("a schema must be specified for mixed DBN data" )
946915 schema = self .schema
947916
948- if map_symbols is None :
949- map_symbols = self .stype_out == SType .INSTRUMENT_ID
950-
951- if map_symbols :
952- if self .stype_out != SType .INSTRUMENT_ID :
953- raise ValueError (
954- "`map_symbols` is not supported when `stype_out` is not 'instrument_id'" ,
955- )
956- if not self ._instrument_id_index :
957- self ._instrument_id_index = self ._build_instrument_id_index ()
958-
959917 if count is None :
960918 records = iter ([self .to_ndarray (schema )])
961919 else :
962920 records = self .to_ndarray (schema , count )
963921
922+ if map_symbols :
923+ self ._instrument_map .insert_metadata (self .metadata )
924+
964925 df_iter = DataFrameIterator (
965926 records = records ,
966927 schema = schema ,
967928 count = count ,
929+ instrument_map = self ._instrument_map ,
968930 pretty_px = pretty_px ,
969931 pretty_ts = pretty_ts ,
970- instrument_id_index = self . _instrument_id_index if map_symbols else {} ,
932+ map_symbols = map_symbols ,
971933 )
972934
973935 if count is None :
@@ -1111,7 +1073,7 @@ def to_ndarray(
11111073
11121074 dtype = SCHEMA_DTYPES_MAP [schema ]
11131075 ndarray_iter = NDArrayIterator (
1114- filter (lambda r : isinstance (r , SCHEMA_STRUCT_MAP [schema ]), self ), # type: ignore [arg-type]
1076+ filter (lambda r : isinstance (r , SCHEMA_STRUCT_MAP [schema ]), self ),
11151077 dtype ,
11161078 count ,
11171079 )
@@ -1163,30 +1125,38 @@ def __init__(
11631125 records : Iterator [np .ndarray [Any , Any ]],
11641126 count : int | None ,
11651127 schema : Schema ,
1128+ instrument_map : InstrumentMap ,
11661129 pretty_px : bool = True ,
11671130 pretty_ts : bool = True ,
1168- instrument_id_index : dict [ dt . date , dict [ int , str ]] | None = None ,
1131+ map_symbols : bool = True ,
11691132 ):
11701133 self ._records = records
11711134 self ._schema = schema
11721135 self ._count = count
11731136 self ._pretty_px = pretty_px
11741137 self ._pretty_ts = pretty_ts
1175- self ._instrument_id_index = (
1176- instrument_id_index if instrument_id_index is not None else {}
1177- )
1138+ self ._map_symbols = map_symbols
1139+ self ._instrument_map = instrument_map
11781140
11791141 def __iter__ (self ) -> DataFrameIterator :
11801142 return self
11811143
11821144 def __next__ (self ) -> pd .DataFrame :
1183- return format_dataframe (
1145+ df = format_dataframe (
11841146 pd .DataFrame (
11851147 next (self ._records ),
11861148 columns = SCHEMA_COLUMNS [self ._schema ],
11871149 ),
11881150 schema = self ._schema ,
11891151 pretty_px = self ._pretty_px ,
11901152 pretty_ts = self ._pretty_ts ,
1191- instrument_id_index = self ._instrument_id_index ,
11921153 )
1154+
1155+ if self ._map_symbols :
1156+ df_index = df .index if self ._pretty_ts else pd .to_datetime (df .index , utc = True )
1157+ dates = [ts .date () for ts in df_index ]
1158+ df ["symbol" ] = [
1159+ self ._instrument_map .resolve (inst , dates [i ]) for i , inst in enumerate (df ["instrument_id" ])
1160+ ]
1161+
1162+ return df
0 commit comments