CAR/wly_experiment/dta_codes/translator/switch_cpu.py
2025-04-14 22:25:58 +08:00

423 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#This is the switch-local controller for the DTA translator
#Written by Jonatan Langlet for Direct Telemetry Access
import datetime
import ipaddress
import hashlib
import struct
import os
p4 = bfrt.dta_translator.pipe
mirror = bfrt.mirror
pre = bfrt.pre
logfile = "/root/wly_experiment/dta_results/dta_translator.log"
# add_with_XXX() 函数的关键字必须为小写, 否则无法正常运行 (识别不出来)
# 根据测试平台拓扑添加静态转发规则
forwardingRules = [
("192.168.1.91", 64), # Tofino CPU
("192.168.3.3", 148), # Generator
("192.168.4.3", 180) # Collector
]
# 将收集器的目标 IP 映射到出口端口 (确保所有这些端口都存在 mcRules)(用于 Key-Write 和 Key-Increment 原语)
collectorIPtoPorts = [
("192.168.4.3", 180),
]
# Key-Write 原语中每个插槽的大小, 默认为 8 字节 (4B+4B), 并确保这块与 P4 代码中的定义一致
keywrite_slot_size_B = 8
# Postcarder 原语中每个插槽的大小, 默认为 32 字节 (5x4B+12B), 其中 12 字节为 padding
# 这块的 padding 是为了实现 P4我们需要它是 2 的幂次
postcarder_slot_size_B = 32
# data list 的数量
num_data_lists = 4
# 多播规则, 用于将出口端口 (egress port) 和冗余 (redundancy) 映射到多播组 ID
mcRules = [
{
"mgid":1,
"egressPort":148,
"redundancy":1
},
{
"mgid":2,
"egressPort":148,
"redundancy":2
},
{
"mgid":3,
"egressPort":148,
"redundancy":3
},
{
"mgid":4,
"egressPort":148,
"redundancy":4
},
{
"mgid":5,
"egressPort":64,
"redundancy":1
},
{
"mgid":6,
"egressPort":64,
"redundancy":2
},
{
"mgid":7,
"egressPort":64,
"redundancy":3
},
{
"mgid":8,
"egressPort":64,
"redundancy":4
},
{
"mgid":9,
"egressPort":180,
"redundancy":1
},
{
"mgid":10,
"egressPort":180,
"redundancy":2
},
{
"mgid":11,
"egressPort":180,
"redundancy":3
},
{
"mgid":12,
"egressPort":180,
"redundancy":4
}
]
def log(text):
global logfile, datetime
line = "%s \t DigProc: %s" %(str(datetime.datetime.now()), str(text))
print(line)
f = open(logfile, "a")
f.write(line + "\n")
f.close()
def digest_callback(dev_id, pipe_id, direction, parser_id, session, msg):
global p4, log, Digest
# smac = p4.Ingress.smac
log("Received message from data plane!")
for dig in msg:
print(dig)
return 0
def bindDigestCallback():
global digest_callback, log, p4
try:
p4.SwitchIngressDeparser.debug_digest.callback_deregister()
except:
pass
finally:
log("Deregistering old callback function (if any)")
#Register as callback for digests (bind to DMA?)
log("Registering callback...")
p4.SwitchIngressDeparser.debug_digest.callback_register(digest_callback)
log("Bound callback to digest")
def insertForwardingRules():
''' 下发转发规则(DstIP -> Egress Port) '''
global p4, log, ipaddress, forwardingRules
log("Inserting forwarding rules...")
for dstAddr, egrPort in forwardingRules:
dstIP = ipaddress.ip_address(dstAddr)
log("%s->%i" %(dstIP, egrPort))
print(type(dstIP))
p4.SwitchIngress.tbl_forward.add_with_forward(dstaddr=dstIP, port=egrPort)
def insertKeyWriteRules():
''' 下发 Key-Write 规则(CollectorIP <-> EgressPort, EgressPort <-> redundancy, mgid) '''
global p4, log, ipaddress, collectorIPtoPorts, mcRules
log("Inserting KeyWrite rules...")
maxRedundancyLevel = 4
for collectorIP, egrPort in collectorIPtoPorts:
collectorIP_bin = ipaddress.ip_address(collectorIP)
for redundancyLevel in range(1, maxRedundancyLevel+1):
log("%s,%i,%i" %(collectorIP,egrPort,redundancyLevel))
# 从 mcRules 列表中查找到正确的多播组 ID (匹配 redundancy 和 egressPort)
rule = [ r for r in mcRules if r["redundancy"]==redundancyLevel and r["egressPort"]==egrPort ]
log(rule[0])
multicastGroupID = rule[0]["mgid"]
# multicastGroupID = 1 # Static for now. Update to match created multicast groups
log("Adding multiwrite rule %s, N=%i - %i" % (collectorIP, redundancyLevel, multicastGroupID))
p4.SwitchIngress.ProcessDTAPacket.tbl_Prep_KeyWrite.add_with_prep_MultiWrite(dstaddr=collectorIP_bin, redundancylevel=redundancyLevel, mcast_grp=multicastGroupID)
def getCollectorMetadata(port):
''' 获取 Collector 的元数据信息 '''
global log, os
# 存放 Metadata 的目录
metadata_dir = "/root/wly_experiment/dta_results/rdma_metadata/%i" % port
log("Setting up a new RDMA connection from virtual client... port %i dir %s" % (port, metadata_dir))
os.system("python3 /root/wly_experiment/dta_codes/translator/init_rdma_connection.py --port %i --dir %s" %(port, metadata_dir))
log("Reading collector metadata from disk...")
try:
# 队列对
f = open("%s/tmp_qpnum" % metadata_dir, "r")
queue_pair = int(f.read())
f.close()
# 起始的数据包序列号
f = open("%s/tmp_psn" % metadata_dir, "r")
start_psn = int(f.read())
f.close()
# 起始内存地址
f = open("%s/tmp_memaddr" % metadata_dir, "r")
memory_start = int(f.read())
f.close()
# 能够用于存放数据的长度
f = open("%s/tmp_memlen" % metadata_dir, "r")
memory_length = int(f.read())
f.close()
# 远程键(用于获取访问远端主机内存的权限)
f = open("%s/tmp_rkey" % metadata_dir, "r")
remote_key = int(f.read())
f.close()
except:
log(" !!! !!! Failed to read RDMA metadata !!! !!! ")
log("Collector metadata read from disk!")
return queue_pair, start_psn, memory_start, memory_length, remote_key
psn_reg_index = 0
def setupKeyvalConnection(port=1337):
''' 对 Keyval 连接所需要的信息进行设置 (端口号为 1337) '''
global p4, log, ipaddress, collectorIPtoPorts, getCollectorMetadata, psn_reg_index, keywrite_slot_size_B
# 端口号作为源队列对的编号
source_qp = port
print("Setting up KeyVal connection...")
# 初始化与键值存储相关的 RDMA 连接(首先根据端口号获取收集器的元数据信息)
queue_pair, start_psn, memory_start, memory_length, remote_key = getCollectorMetadata(port)
print("queue_pair", queue_pair)
for dstAddr, _ in collectorIPtoPorts:
dstIP = ipaddress.ip_address(dstAddr)
# 计算收集器中分配了多少个数据插槽,即 memory_length / (csum+data) (size in bytes)
collector_num_storage_slots = int(memory_length/keywrite_slot_size_B)
# 填充存放数据包序列号的寄存器
p4.SwitchEgress.CraftRDMA.reg_rdma_sequence_number.mod(f1=start_psn, REGISTER_INDEX=psn_reg_index)
log("Populating PSN-resynchronization lookup table for QP->regIndex mapping")
p4.SwitchEgress.RDMARatelimit.tbl_get_qp_reg_num.add_with_set_qp_reg_num(queue_pair=source_qp, qp_reg_index=psn_reg_index)
log("Inserting KeyWrite RDMA lookup rule for collector ip %s" %dstAddr)
print("psn_reg_index", psn_reg_index)
# 生成关于收集器的元数据信息的表项, 并将其填充到对应的表中
p4.SwitchEgress.PrepareKeyWrite.tbl_getCollectorMetadataFromIP.add_with_set_server_info(dstaddr=dstIP, remote_key=remote_key, queue_pair=queue_pair, memory_address_start=memory_start, collector_num_storage_slots=collector_num_storage_slots, qp_reg_index=psn_reg_index)
psn_reg_index += 1
def setupDatalistConnection():
''' 对 Append 连接所需要的信息进行设置 (端口号为 1338-1341) '''
global p4, log, getCollectorMetadata, psn_reg_index, num_data_lists
# 在此你需要指定与多少个 dataLists 建立连接,并用 (listID, rdmaCMPort) 元组列表的元数据来填充 ASIC
# lists = [(1,1338),(2,1339),(3,1340),(4,1341)] # 4 lists
# lists = [(1,1338),(2,1339),(3,1340)] # 3 lists
# lists = [(1,1338),(2,1339)] # 2 lists
# lists = [(1,1338)] # 1 list
# Append 原语中列表中每个插槽的大小,以字节为单位 (数据大小)
listSlotSize = 4
# Append 原语中列表的起始端口号
list_start_port = 1338
# for listID, port in lists:
for listID in range(num_data_lists):
# 根据起始端口号和列表 ID 生成当前列表的端口号(范围为 1338 到 1338 + lists_num - 1)
port = list_start_port + listID
print("Setting up dataList connection to list %i port %i..." % (listID, port))
# 初始化与 Append 相关的 RDMA 连接(首先根据端口号获取收集器的元数据信息)
queue_pair, start_psn, memory_start, memory_length, remote_key = getCollectorMetadata(port)
# 端口号作为源队列对的编号
source_qp = port
# 填充存放数据包序列号的寄存器
p4.SwitchEgress.CraftRDMA.reg_rdma_sequence_number.mod(f1=start_psn, REGISTER_INDEX=psn_reg_index)
log("Populating PSN-resynchronization lookup table for QP->regIndex mapping")
p4.SwitchEgress.RDMARatelimit.tbl_get_qp_reg_num.add_with_set_qp_reg_num(queue_pair=source_qp, qp_reg_index=psn_reg_index)
# 计算收集器中分配了多少个数据插槽,即 memory_length / (slot data size in bytes)
collector_num_storage_slots = int(memory_length / listSlotSize)
psn_reg_index = int(psn_reg_index)
log("Inserting Append-to-List RDMA lookup rule for listID %i" % listID)
print("psn_reg_index", psn_reg_index)
print("collector_num_storage_slots", collector_num_storage_slots)
# 生成关于收集器的元数据信息的表项, 并将其填充到对应的表中 (这里拆成两部分进行填充)
p4.SwitchEgress.PrepareAppend.tbl_getCollectorMetadataFromListID_1.add_with_set_server_info_1(listid=listID, remote_key=remote_key, queue_pair=queue_pair, memory_address_start=memory_start)
p4.SwitchEgress.PrepareAppend.tbl_getCollectorMetadataFromListID_2.add_with_set_server_info_2(listid=listID, collector_num_storage_slots=collector_num_storage_slots, qp_reg_index=psn_reg_index)
psn_reg_index += 1
def setupPostcarderConnection(port=1336):
global p4, log, ipaddress, collectorIPtoPorts, getCollectorMetadata, psn_reg_index, postcarder_slot_size_B
# 端口号作为源队列对的编号
source_qp = port
print("Setting up Postcarder connection...")
# 初始化与 Postcarder 相关的 RDMA 连接(首先根据端口号获取收集器的元数据信息)
queue_pair, start_psn, memory_start, memory_length, remote_key = getCollectorMetadata(port)
print("queue_pair", queue_pair)
for dstAddr, _ in collectorIPtoPorts:
dstIP = ipaddress.ip_address(dstAddr)
# 计算收集器中分配了多少个数据插槽,即 memory_length / (slot data size in bytes, i.e 32 Bytes)
collector_num_storage_slots = int(memory_length / postcarder_slot_size_B)
# 填充存放数据包序列号的寄存器
p4.SwitchEgress.CraftRDMA.reg_rdma_sequence_number.mod(f1=start_psn, REGISTER_INDEX=psn_reg_index)
log("Populating PSN-resynchronization lookup table for QP->regIndex mapping")
p4.SwitchEgress.RDMARatelimit.tbl_get_qp_reg_num.add_with_set_qp_reg_num(queue_pair=source_qp, qp_reg_index=psn_reg_index)
log("Inserting Postcarder RDMA lookup rule for collector ip %s" %dstAddr)
print("psn_reg_index", psn_reg_index)
# 生成关于收集器的元数据信息的表项, 并将其填充到对应的表中
p4.SwitchEgress.PreparePostcarder.tbl_getCollectorMetadataFromIP.add_with_set_server_info(dstaddr=dstIP, remote_key=remote_key, queue_pair=queue_pair, memory_address_start=memory_start, collector_num_storage_slots=collector_num_storage_slots, qp_reg_index=psn_reg_index)
psn_reg_index += 1
def insertCollectorMetadataRules():
''' 将收集器的 RDMA 元数据规则插入到 ASIC 中 (Append, KeyVal, Postcarder) '''
global p4, log, ipaddress, collectorIPtoPorts, getCollectorMetadata, setupKeyvalConnection, setupDatalistConnection, setupPostcarderConnection
log("Inserting RDMA metadata into ASIC...")
setupPostcarderConnection()
setupKeyvalConnection()
setupDatalistConnection()
# NOTE: this might break ALL rules about multicasting. Very hacky
def configMulticasting():
global p4, pre, log, mcRules
log("Configuring mirroring sessions...")
lastNodeID = 0
for mcastGroup in mcRules:
mgid = mcastGroup["mgid"]
egressPort = mcastGroup["egressPort"]
redundancy = mcastGroup["redundancy"]
log("Setting up multicast %i, egress port:%i, redundancy:%i" %(mgid, egressPort, redundancy))
nodeIDs = []
log("Adding multicast nodes...")
for i in range(redundancy):
lastNodeID += 1
log("Creating node %i" %lastNodeID)
# pre.node.add(DEV_PORT=[egressPort], MULTICAST_NODE_ID=lastNodeID)
pre.node.add(dev_port=[egressPort], multicast_node_id=lastNodeID)
nodeIDs.append(lastNodeID)
log("Creating the multicast group")
# pre.mgid.add(MGID=mgid, MULTICAST_NODE_ID=nodeIDs, MULTICAST_NODE_L1_XID=[0]*redundancy, MULTICAST_NODE_L1_XID_VALID=[False]*redundancy)
pre.mgid.add(mgid=mgid, multicast_node_id=nodeIDs, multicast_node_l1_xid=[0]*redundancy, multicast_node_l1_xid_valid=[False]*redundancy)
def configMirrorSessions():
global mirror, log
log("Configuring mirroring sessions...")
#TODO: fix truncation length
mirror.cfg.add_with_normal(sid=1, session_enable=True, ucast_egress_port=65, ucast_egress_port_valid=True, direction="BOTH", max_pkt_len=43) #Mirror header+Ethernet+IP
def populateTables():
global p4, log, insertForwardingRules, insertKeyWriteRules, insertCollectorMetadataRules
log("Populating the P4 tables...")
insertForwardingRules()
insertKeyWriteRules()
insertCollectorMetadataRules()
log("Starting")
# 已注释
# configMulticasting()
populateTables()
configMirrorSessions()
bindDigestCallback()
# log("Starting periodic injection of DTA write packet (keeping system alive)")
# os.system("watch \"sudo /home/sde/dta/translator/inject_dta.py keywrite --data 10000 --key 0 --redundancy 1\" &")
print("*** Now start period WRITE function manually")
print("*** Now start period WRITE function manually")
log("Bootstrap complete")