init
This commit is contained in:
143
game_server/net/gateway.py
Normal file
143
game_server/net/gateway.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
from nazo_rand import randrange
|
||||
from game_server.net.kcp import get_conv
|
||||
from game_server.net.packet import NetOperation
|
||||
from game_server.net.session import PlayerSession
|
||||
from utils.logger import Info,Warn,Debug
|
||||
from database.mongodb import get_database
|
||||
|
||||
|
||||
class KCPGateway(asyncio.DatagramProtocol):
|
||||
def __init__(self, db):
|
||||
self.id_counter = 0
|
||||
self.sessions : dict[int, PlayerSession] = {}
|
||||
self.running = True
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.db = db
|
||||
self.timeout_check_interval = 5
|
||||
self.save_duration = 60
|
||||
|
||||
async def check_sessions_timeout(self):
|
||||
while self.running:
|
||||
for conv_id, session in list(self.sessions.items()):
|
||||
if session.is_timeout():
|
||||
self.drop_kcp_session(conv_id)
|
||||
await asyncio.sleep(self.timeout_check_interval)
|
||||
|
||||
async def periodic_save_sessions(self):
|
||||
while self.running:
|
||||
start = asyncio.get_event_loop().time()
|
||||
|
||||
for session in self.sessions.values():
|
||||
session.player.save_all()
|
||||
|
||||
elapsed = asyncio.get_event_loop().time() - start
|
||||
Info(f"Database saved in {elapsed:.2f}s")
|
||||
|
||||
await asyncio.sleep(self.save_duration)
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
Info(f"Listening on {self.transport.get_extra_info('sockname')}")
|
||||
|
||||
asyncio.create_task(self.check_sessions_timeout())
|
||||
asyncio.create_task(self.periodic_save_sessions())
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
data_len = len(data)
|
||||
|
||||
Debug(f"Received {data_len} bytes from {addr}: {[b for b in data]}")
|
||||
|
||||
if data_len == 20:
|
||||
self.process_net_operation(NetOperation.from_bytes(data), addr)
|
||||
elif data_len >= 28:
|
||||
self.process_kcp_payload(data, addr)
|
||||
else:
|
||||
Warn("Unknown data length received")
|
||||
|
||||
def process_net_operation(self, op: NetOperation, addr):
|
||||
if (op.head, op.tail) == (0xFF, 0xFFFFFFFF):
|
||||
self.establish_kcp_session(op.data, addr)
|
||||
elif (op.head, op.tail) == (0x194, 0x19419494):
|
||||
self.drop_kcp_session(op.conv_id)
|
||||
else:
|
||||
Warn(f"Unknown magic pair: {op.head}-{op.tail}")
|
||||
|
||||
def establish_kcp_session(self, data, addr):
|
||||
conv_id, session_token = self.next_conv_pair()
|
||||
session_id = conv_id << 32 | session_token
|
||||
|
||||
Info(f"New connection: {addr} with conv_id: {conv_id}")
|
||||
|
||||
self.sessions[conv_id] = PlayerSession(
|
||||
self.transport, session_id, addr, self.db
|
||||
)
|
||||
|
||||
net_op = NetOperation(
|
||||
head=0x145,
|
||||
conv_id=conv_id,
|
||||
session_token=session_token,
|
||||
data=data,
|
||||
tail=0x14514545,
|
||||
).to_bytes()
|
||||
|
||||
self.transport.sendto(net_op, addr)
|
||||
|
||||
def drop_kcp_session(self, conv_id):
|
||||
if conv_id in self.sessions:
|
||||
self.sessions[conv_id].player.save_all()
|
||||
del self.sessions[conv_id]
|
||||
Info(f"Dropped KCP session with conv_id: {conv_id}")
|
||||
else:
|
||||
Warn(
|
||||
f"Attempted to drop non-existent KCP session with conv_id: {conv_id}"
|
||||
)
|
||||
|
||||
def process_kcp_payload(self, data, addr):
|
||||
conv_id = get_conv(data)
|
||||
session = self.sessions.get(conv_id)
|
||||
|
||||
if session:
|
||||
session.update_last_received()
|
||||
asyncio.create_task(session.consume(data))
|
||||
else:
|
||||
Warn(f"Received data for unknown session conv_id: {conv_id}")
|
||||
|
||||
def next_conv_pair(self):
|
||||
self.id_counter += 1
|
||||
return self.id_counter, randrange(0, 0xFF)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
Info("UDP connection lost, shutting down...")
|
||||
|
||||
for session in self.sessions.values():
|
||||
session.player.save_all()
|
||||
|
||||
self.running = False
|
||||
self.shutdown_event.set()
|
||||
|
||||
def shutdown(self):
|
||||
Info("Shutting down server...")
|
||||
|
||||
for session in self.sessions.values():
|
||||
session.player.save_all()
|
||||
|
||||
self.running = False
|
||||
self.shutdown_event.set()
|
||||
|
||||
@staticmethod
|
||||
async def new(host, port):
|
||||
loop = asyncio.get_running_loop()
|
||||
db = get_database()
|
||||
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: KCPGateway(db), local_addr=(host, port)
|
||||
)
|
||||
|
||||
try:
|
||||
await protocol.shutdown_event.wait()
|
||||
except asyncio.CancelledError:
|
||||
Info("Server tasks cancelled.")
|
||||
finally:
|
||||
transport.close()
|
||||
Info("Server stopped.")
|
||||
657
game_server/net/kcp.py
Normal file
657
game_server/net/kcp.py
Normal file
@@ -0,0 +1,657 @@
|
||||
import struct
|
||||
from collections import deque
|
||||
from typing import Callable
|
||||
|
||||
# thanks mero for this kcp lib :3
|
||||
# this file is excluded from the CC0-1.0 License
|
||||
# this file is licensed under GPL-3.0 from the author written below
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Mero <mero@crepe.moe>"
|
||||
|
||||
|
||||
IKCP_RTO_NDL = 30
|
||||
IKCP_RTO_MIN = 100
|
||||
IKCP_RTO_DEF = 200
|
||||
IKCP_RTO_MAX = 60000
|
||||
IKCP_CMD_PUSH = 81
|
||||
IKCP_CMD_ACK = 82
|
||||
IKCP_CMD_WASK = 83
|
||||
IKCP_CMD_WINS = 84
|
||||
IKCP_ASK_SEND = 1
|
||||
IKCP_ASK_TELL = 2
|
||||
IKCP_WND_SND = 32
|
||||
IKCP_WND_RCV = 128
|
||||
IKCP_MTU_DEF = 1400
|
||||
IKCP_ACK_FAST = 3
|
||||
IKCP_INTERVAL = 100
|
||||
IKCP_DEADLINK = 20
|
||||
IKCP_THRESH_INIT = 2
|
||||
IKCP_THRESH_MIN = 2
|
||||
IKCP_PROBE_INIT = 7000
|
||||
IKCP_PROBE_LIMIT = 120000
|
||||
IKCP_FASTACK_LIMIT = 5
|
||||
|
||||
IKCP_PACKET_HEAD_FORMAT = "<IIBBHIIII"
|
||||
IKCP_OVERHEAD = struct.calcsize(IKCP_PACKET_HEAD_FORMAT)
|
||||
|
||||
|
||||
def get_conv(buf: bytes) -> int:
|
||||
assert len(buf) >= IKCP_OVERHEAD
|
||||
return struct.unpack_from("<I", buf, 0)[0]
|
||||
|
||||
|
||||
class KcpException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class KcpSegment:
|
||||
__slots__ = (
|
||||
"session_id",
|
||||
"cmd",
|
||||
"frg",
|
||||
"wnd",
|
||||
"ts",
|
||||
"sn",
|
||||
"una",
|
||||
"data",
|
||||
"resendts",
|
||||
"rto",
|
||||
"fastack",
|
||||
"xmit",
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = b""
|
||||
|
||||
self.resendts = 0
|
||||
self.rto = 0
|
||||
self.fastack = 0
|
||||
self.xmit = 0
|
||||
|
||||
def parse(self, data):
|
||||
conv, token, cmd, frg, wnd, ts, sn, una, len = struct.unpack(
|
||||
IKCP_PACKET_HEAD_FORMAT, data[:IKCP_OVERHEAD]
|
||||
)
|
||||
|
||||
self.session_id = conv << 32 | token
|
||||
self.cmd = cmd
|
||||
self.frg = frg
|
||||
self.wnd = wnd
|
||||
|
||||
self.ts = ts
|
||||
self.sn = sn
|
||||
self.una = una
|
||||
self.data = data[IKCP_OVERHEAD : IKCP_OVERHEAD + len]
|
||||
|
||||
return IKCP_OVERHEAD + len
|
||||
|
||||
def encode(self) -> bytes:
|
||||
conv = self.session_id >> 32
|
||||
token = self.session_id & 0xFFFFFFFF
|
||||
|
||||
return (
|
||||
struct.pack(
|
||||
IKCP_PACKET_HEAD_FORMAT,
|
||||
conv,
|
||||
token,
|
||||
self.cmd,
|
||||
self.frg,
|
||||
self.wnd,
|
||||
int(self.ts),
|
||||
self.sn,
|
||||
self.una,
|
||||
int(len(self.data)),
|
||||
)
|
||||
+ self.data
|
||||
)
|
||||
|
||||
|
||||
class Kcp:
|
||||
__slots__ = (
|
||||
"session_id",
|
||||
"current",
|
||||
"rx_srtt",
|
||||
"rx_rttval",
|
||||
"snd_wnd",
|
||||
"interval",
|
||||
"rx_minrto",
|
||||
"snd_nxt",
|
||||
"rmt_wnd",
|
||||
"snd_buf",
|
||||
"snd_una",
|
||||
"snd_queue",
|
||||
"updated",
|
||||
"ts_flush",
|
||||
"xmit",
|
||||
"state",
|
||||
"ts_probe",
|
||||
"probe_wait",
|
||||
"use_fastask_conserve",
|
||||
"rcv_nxt",
|
||||
"rcv_wnd",
|
||||
"rcv_buf",
|
||||
"rcv_queue",
|
||||
"probe",
|
||||
"acklist",
|
||||
"cwnd",
|
||||
"mtu",
|
||||
"mss",
|
||||
"ssthresh",
|
||||
"incr",
|
||||
"rx_rto",
|
||||
"stream",
|
||||
"output",
|
||||
"nodelay",
|
||||
"nocwnd",
|
||||
"dead_link",
|
||||
"fastresend",
|
||||
"fastlimit",
|
||||
)
|
||||
|
||||
def __init__(self, session_id: int, output: Callable[[bytes], None]):
|
||||
assert session_id < 1 << 64
|
||||
|
||||
self.use_fastask_conserve = False
|
||||
self.session_id = session_id
|
||||
self.output = output
|
||||
|
||||
self.snd_una = 0
|
||||
self.snd_nxt = 0
|
||||
self.rcv_nxt = 0
|
||||
|
||||
self.ts_probe = 0
|
||||
self.probe_wait = 0
|
||||
|
||||
self.snd_wnd = IKCP_WND_SND
|
||||
self.rcv_wnd = IKCP_WND_RCV
|
||||
self.rmt_wnd = IKCP_WND_RCV
|
||||
|
||||
self.cwnd = 0
|
||||
self.incr = 0
|
||||
self.probe = 0
|
||||
|
||||
self.mtu = IKCP_MTU_DEF
|
||||
self.mss = self.mtu - IKCP_OVERHEAD
|
||||
self.stream = False
|
||||
|
||||
self.snd_buf = deque()
|
||||
self.rcv_buf = deque()
|
||||
self.rcv_queue = deque()
|
||||
self.snd_queue = deque()
|
||||
|
||||
self.state = 0
|
||||
self.acklist = deque()
|
||||
|
||||
self.rx_srtt = 0
|
||||
self.rx_rttval = 0
|
||||
self.rx_rto = IKCP_RTO_DEF
|
||||
self.rx_minrto = IKCP_RTO_MIN
|
||||
|
||||
self.current = 0
|
||||
self.interval = IKCP_INTERVAL
|
||||
self.ts_flush = IKCP_INTERVAL
|
||||
self.nodelay = 0
|
||||
self.updated = False
|
||||
|
||||
self.ssthresh = IKCP_THRESH_INIT
|
||||
self.fastresend = 0
|
||||
self.fastlimit = IKCP_FASTACK_LIMIT
|
||||
self.nocwnd = False
|
||||
self.xmit = 0
|
||||
self.dead_link = IKCP_DEADLINK
|
||||
|
||||
def parse_una(self, una):
|
||||
while self.snd_buf:
|
||||
seg = self.snd_buf[0]
|
||||
|
||||
if seg.sn >= una:
|
||||
break
|
||||
|
||||
self.snd_buf.popleft()
|
||||
|
||||
def shrink_buf(self):
|
||||
self.snd_una = self.snd_buf[0].sn if self.snd_buf else self.snd_nxt
|
||||
|
||||
def update_ack(self, rtt):
|
||||
if self.rx_srtt == 0:
|
||||
self.rx_srtt = rtt
|
||||
self.rx_rttval = rtt // 2
|
||||
else:
|
||||
delta = abs(rtt - self.rx_srtt)
|
||||
self.rx_rttval = (3 * self.rx_rttval + delta) // 4
|
||||
self.rx_srtt = max((7 * self.rx_srtt + rtt) // 8, 1)
|
||||
|
||||
rto = self.rx_srtt + max(self.interval, 4 * self.rx_rttval)
|
||||
self.rx_rto = min(max(self.rx_minrto, rto), IKCP_RTO_MAX)
|
||||
|
||||
def parse_ack(self, sn):
|
||||
if self.snd_una > sn or self.snd_nxt <= sn:
|
||||
return
|
||||
|
||||
for seg in self.snd_buf:
|
||||
if sn == seg.sn:
|
||||
self.snd_buf.remove(seg)
|
||||
break
|
||||
|
||||
if seg.sn > sn:
|
||||
break
|
||||
|
||||
def move_buf(self):
|
||||
while self.rcv_buf:
|
||||
seg = self.rcv_buf[0]
|
||||
if seg.sn != self.rcv_nxt or len(self.rcv_queue) >= self.rcv_wnd:
|
||||
break
|
||||
|
||||
self.rcv_nxt += 1
|
||||
self.rcv_queue.append(self.rcv_buf.popleft())
|
||||
|
||||
def parse_data(self, newseg):
|
||||
if (self.rcv_nxt + self.rcv_wnd) <= newseg.sn or self.rcv_nxt > newseg.sn:
|
||||
return
|
||||
|
||||
repeat = False
|
||||
new_index = len(self.rcv_buf)
|
||||
|
||||
for seg in reversed(self.rcv_buf):
|
||||
if seg.sn == newseg.sn:
|
||||
repeat = True
|
||||
break
|
||||
|
||||
if seg.sn < newseg.sn:
|
||||
break
|
||||
|
||||
new_index -= 1
|
||||
|
||||
if not repeat:
|
||||
self.rcv_buf.insert(new_index, newseg)
|
||||
|
||||
self.move_buf()
|
||||
|
||||
def parse_fastack(self, sn, ts):
|
||||
if self.snd_una > sn or self.snd_nxt <= sn:
|
||||
return
|
||||
|
||||
for seg in self.snd_buf:
|
||||
if seg.sn > sn:
|
||||
break
|
||||
elif sn != seg.sn and (self.use_fastask_conserve or seg.ts <= ts):
|
||||
seg.fastack += 1
|
||||
|
||||
def input(self, data: bytes):
|
||||
if not data or len(data) < IKCP_OVERHEAD:
|
||||
raise KcpException(f"data size must be greater than {IKCP_OVERHEAD}")
|
||||
|
||||
maxack = 0
|
||||
latest_ts = 0
|
||||
flag = False
|
||||
prev_una = self.snd_una
|
||||
|
||||
while len(data) >= IKCP_OVERHEAD:
|
||||
seg = KcpSegment()
|
||||
data = data[seg.parse(data) :]
|
||||
|
||||
if seg.session_id != self.session_id:
|
||||
raise KcpException(
|
||||
f"wrong session id, got {seg.session_id} but {self.session_id} was expected"
|
||||
)
|
||||
|
||||
if seg.cmd not in (
|
||||
IKCP_CMD_PUSH,
|
||||
IKCP_CMD_ACK,
|
||||
IKCP_CMD_WASK,
|
||||
IKCP_CMD_WINS,
|
||||
):
|
||||
raise KcpException(f"unknown kcp cmd {seg.cmd}")
|
||||
|
||||
self.rmt_wnd = seg.wnd
|
||||
|
||||
self.parse_una(seg.una)
|
||||
self.shrink_buf()
|
||||
|
||||
if seg.cmd == IKCP_CMD_ACK:
|
||||
rtt = self.current - seg.ts
|
||||
if rtt >= 0:
|
||||
self.update_ack(rtt)
|
||||
|
||||
self.parse_ack(seg.sn)
|
||||
self.shrink_buf()
|
||||
|
||||
if not flag:
|
||||
flag = True
|
||||
maxack = seg.sn
|
||||
latest_ts = seg.ts
|
||||
elif maxack < seg.sn and (
|
||||
self.use_fastask_conserve or latest_ts > seg.ts
|
||||
):
|
||||
maxack = seg.sn
|
||||
latest_ts = seg.ts
|
||||
|
||||
elif seg.cmd == IKCP_CMD_PUSH:
|
||||
if self.rcv_nxt + self.rcv_wnd > seg.sn:
|
||||
self.acklist.append((seg.sn, seg.ts))
|
||||
if self.rcv_nxt <= seg.sn:
|
||||
self.parse_data(seg)
|
||||
|
||||
elif seg.cmd == IKCP_CMD_WASK:
|
||||
self.probe |= IKCP_ASK_TELL
|
||||
|
||||
if flag:
|
||||
self.parse_fastack(maxack, latest_ts)
|
||||
|
||||
if self.snd_una - prev_una > 0 and self.cwnd < self.rmt_wnd:
|
||||
mss = self.mss
|
||||
if self.cwnd < self.ssthresh:
|
||||
self.cwnd += 1
|
||||
self.incr += mss
|
||||
else:
|
||||
if self.incr < mss:
|
||||
self.incr = mss
|
||||
self.incr += mss * mss // self.incr + mss // 16
|
||||
if (self.cwnd + 1) * mss <= self.incr:
|
||||
self.cwnd = (self.incr + mss - 1) // mss if mss > 0 else 1
|
||||
if self.cwnd > self.rmt_wnd:
|
||||
self.cwnd = self.rmt_wnd
|
||||
self.incr = self.rmt_wnd * mss
|
||||
|
||||
def peeksize(self):
|
||||
if not self.rcv_queue:
|
||||
return -1
|
||||
|
||||
seg = self.rcv_queue[0]
|
||||
if seg.frg == 0:
|
||||
return len(seg.data)
|
||||
if len(self.rcv_queue) < seg.frg + 1:
|
||||
return -1
|
||||
|
||||
length = 0
|
||||
for seg in self.rcv_queue:
|
||||
length += len(seg.data)
|
||||
if seg.frg == 0:
|
||||
break
|
||||
|
||||
return length
|
||||
|
||||
def recv(self) -> bytes | None:
|
||||
if not self.rcv_queue:
|
||||
return None
|
||||
|
||||
peeksize = self.peeksize()
|
||||
if peeksize < 0:
|
||||
return None
|
||||
|
||||
recover = len(self.rcv_queue) >= self.rcv_wnd
|
||||
data = b""
|
||||
|
||||
while seg := self.rcv_queue.popleft():
|
||||
data += seg.data
|
||||
if seg.frg == 0:
|
||||
break
|
||||
|
||||
assert len(data) == peeksize
|
||||
self.move_buf()
|
||||
|
||||
if len(self.rcv_queue) < self.rcv_wnd and recover:
|
||||
self.probe |= IKCP_ASK_TELL
|
||||
|
||||
return data
|
||||
|
||||
def send(self, data: bytes):
|
||||
assert self.mss > 0
|
||||
|
||||
if self.stream:
|
||||
if self.snd_queue:
|
||||
seg = self.snd_queue[-1]
|
||||
if len(seg.data) < self.mss:
|
||||
capacity = self.mss - len(seg.data)
|
||||
extend = min(len(data), capacity)
|
||||
|
||||
seg.data += data[:extend]
|
||||
data = data[extend:]
|
||||
seg.frg = 0
|
||||
|
||||
if not data:
|
||||
return
|
||||
|
||||
count = (len(data) + self.mss - 1) // self.mss if len(data) > self.mss else 1
|
||||
if count >= IKCP_WND_RCV:
|
||||
raise KcpException("user buffer is too long")
|
||||
count = max(count, 1)
|
||||
|
||||
for i in range(count):
|
||||
size = min(self.mss, len(data))
|
||||
|
||||
newseg = KcpSegment()
|
||||
newseg.data = data[:size]
|
||||
newseg.frg = 0 if self.stream else count - i - 1
|
||||
|
||||
data = data[size:]
|
||||
self.snd_queue.append(newseg)
|
||||
|
||||
def update(self, current: int):
|
||||
assert current < 1 << 32
|
||||
self.current = current
|
||||
|
||||
if not self.updated:
|
||||
self.updated = True
|
||||
self.ts_flush = self.current
|
||||
|
||||
slap = self.current - self.ts_flush
|
||||
|
||||
if slap >= 10000 or slap < -10000:
|
||||
self.ts_flush = self.current
|
||||
slap = 0
|
||||
|
||||
if slap >= 0:
|
||||
self.ts_flush += self.interval
|
||||
if self.ts_flush <= self.current:
|
||||
self.ts_flush = self.current + self.interval
|
||||
self.flush()
|
||||
|
||||
def wnd_unused(self):
|
||||
return max(self.rcv_wnd - len(self.rcv_queue), 0)
|
||||
|
||||
def flush(self):
|
||||
if not self.updated:
|
||||
return
|
||||
|
||||
seg = KcpSegment()
|
||||
seg.session_id = self.session_id
|
||||
seg.cmd = IKCP_CMD_ACK
|
||||
seg.frg = 0
|
||||
seg.wnd = self.wnd_unused()
|
||||
seg.una = self.rcv_nxt
|
||||
seg.sn = 0
|
||||
seg.ts = 0
|
||||
|
||||
data = b""
|
||||
for sn, ts in self.acklist:
|
||||
if len(data) + IKCP_OVERHEAD > self.mtu:
|
||||
self.output(data)
|
||||
data = b""
|
||||
seg.sn = sn
|
||||
seg.ts = ts
|
||||
data += seg.encode()
|
||||
self.acklist.clear()
|
||||
|
||||
if self.rmt_wnd == 0:
|
||||
if self.probe_wait == 0:
|
||||
self.probe_wait = IKCP_PROBE_INIT
|
||||
self.ts_probe = self.current + self.probe_wait
|
||||
elif self.ts_probe <= self.current:
|
||||
self.probe_wait = min(
|
||||
self.probe_wait + max(self.probe_wait, IKCP_PROBE_INIT) // 2,
|
||||
IKCP_PROBE_LIMIT,
|
||||
)
|
||||
self.ts_probe = self.current + self.probe_wait
|
||||
self.probe |= IKCP_ASK_SEND
|
||||
else:
|
||||
self.ts_probe = 0
|
||||
self.probe_wait = 0
|
||||
|
||||
if self.probe & IKCP_ASK_SEND:
|
||||
seg.cmd = IKCP_CMD_WASK
|
||||
if len(data) + IKCP_OVERHEAD > self.mtu:
|
||||
self.output(data)
|
||||
data = b""
|
||||
data += seg.encode()
|
||||
|
||||
if self.probe & IKCP_ASK_TELL:
|
||||
seg.cmd = IKCP_CMD_WINS
|
||||
if len(data) + IKCP_OVERHEAD > self.mtu:
|
||||
self.output(data)
|
||||
data = b""
|
||||
data += seg.encode()
|
||||
|
||||
self.probe = 0
|
||||
|
||||
cwnd = min(self.snd_wnd, self.rmt_wnd)
|
||||
if not self.nocwnd:
|
||||
cwnd = min(self.cwnd, cwnd)
|
||||
|
||||
while self.snd_una + cwnd > self.snd_nxt:
|
||||
if not self.snd_queue:
|
||||
break
|
||||
|
||||
newseg = self.snd_queue.popleft()
|
||||
self.snd_buf.append(newseg)
|
||||
|
||||
newseg.session_id = self.session_id
|
||||
newseg.cmd = IKCP_CMD_PUSH
|
||||
newseg.wnd = seg.wnd
|
||||
newseg.ts = self.current
|
||||
|
||||
newseg.sn = self.snd_nxt
|
||||
self.snd_nxt += 1
|
||||
|
||||
newseg.una = self.rcv_nxt
|
||||
newseg.resendts = self.current
|
||||
newseg.rto = self.rx_rto
|
||||
newseg.fastack = 0
|
||||
newseg.xmit = 0
|
||||
|
||||
resent = 0xFFFFFFFF
|
||||
if self.fastresend > 0:
|
||||
resent = self.fastresend
|
||||
|
||||
rtomin = 0
|
||||
if not self.nodelay:
|
||||
rtomin = self.rx_rto >> 3
|
||||
|
||||
lost = False
|
||||
change = False
|
||||
|
||||
for segment in self.snd_buf:
|
||||
needsend = False
|
||||
|
||||
if segment.xmit == 0:
|
||||
needsend = True
|
||||
segment.xmit += 1
|
||||
segment.rto = self.rx_rto
|
||||
segment.resendts = self.current + segment.rto + rtomin
|
||||
elif segment.resendts <= self.current:
|
||||
needsend = True
|
||||
segment.xmit += 1
|
||||
self.xmit += 1
|
||||
if not self.nodelay:
|
||||
segment.rto += max(segment.rto, self.rx_rto)
|
||||
else:
|
||||
step = segment.rto if self.nodelay < 2 else self.rx_rto
|
||||
segment.rto += step // 2
|
||||
segment.resendts = self.current + segment.rto
|
||||
lost = True
|
||||
elif segment.fastack >= resent and (
|
||||
segment.xmit <= self.fastlimit or self.fastlimit <= 0
|
||||
):
|
||||
needsend = True
|
||||
segment.xmit += 1
|
||||
segment.fastack = 0
|
||||
segment.resendts = self.current + segment.rto
|
||||
change = True
|
||||
|
||||
if needsend:
|
||||
segment.ts = self.current
|
||||
segment.wnd = seg.wnd
|
||||
segment.una = self.rcv_nxt
|
||||
|
||||
if len(data) + IKCP_OVERHEAD + len(segment.data) > self.mtu:
|
||||
self.output(data)
|
||||
data = b""
|
||||
|
||||
data += segment.encode()
|
||||
if segment.xmit >= self.dead_link:
|
||||
self.state = -1
|
||||
|
||||
if data:
|
||||
self.output(data)
|
||||
|
||||
if change:
|
||||
inflight = self.snd_nxt - self.snd_una
|
||||
self.ssthresh = max(inflight // 2, IKCP_THRESH_MIN)
|
||||
self.cwnd = self.ssthresh + resent
|
||||
self.incr = self.cwnd * self.mss
|
||||
|
||||
if lost:
|
||||
self.ssthresh = max(cwnd // 2, IKCP_THRESH_MIN)
|
||||
self.cwnd = 1
|
||||
self.incr = self.mss
|
||||
|
||||
if self.cwnd < 1:
|
||||
self.cwnd = 1
|
||||
self.incr = self.mss
|
||||
|
||||
def check(self, current):
|
||||
assert current < 1 << 32
|
||||
|
||||
ts_flush = self.ts_flush
|
||||
tm_packet = 0x7FFFFFFF
|
||||
|
||||
if not self.updated:
|
||||
return current
|
||||
|
||||
if current - self.ts_flush >= 10000 or current - self.ts_flush < -10000:
|
||||
ts_flush = current
|
||||
|
||||
if ts_flush <= current:
|
||||
return current
|
||||
|
||||
tm_flush = ts_flush - current
|
||||
|
||||
for seg in self.snd_buf:
|
||||
diff = seg.resendts - current
|
||||
if diff <= 0:
|
||||
return current
|
||||
if diff < tm_packet:
|
||||
tm_packet = diff
|
||||
|
||||
return current + min(tm_flush, tm_packet, self.interval)
|
||||
|
||||
def set_mtu(self, mtu: int):
|
||||
if mtu < 50 or mtu < IKCP_OVERHEAD:
|
||||
raise KcpException("invalid mtu")
|
||||
|
||||
self.mtu = mtu
|
||||
self.mss = self.mtu - IKCP_OVERHEAD
|
||||
|
||||
def set_nodelay(self, nodelay: int, interval: int, resend: int, nc: int):
|
||||
if nodelay >= 0:
|
||||
self.nodelay = nodelay
|
||||
if nodelay:
|
||||
self.rx_minrto = IKCP_RTO_NDL
|
||||
else:
|
||||
self.rx_minrto = IKCP_RTO_MIN
|
||||
|
||||
if interval >= 0:
|
||||
self.interval = min(max(10, interval), 5000)
|
||||
|
||||
if resend >= 0:
|
||||
self.fastresend = resend
|
||||
|
||||
if nc >= 0:
|
||||
self.nocwnd = nc
|
||||
|
||||
def set_wndsize(self, sndwnd: int, rcvwnd: int):
|
||||
if sndwnd > 0:
|
||||
self.snd_wnd = sndwnd
|
||||
if rcvwnd > 0:
|
||||
self.rcv_wnd = max(rcvwnd, IKCP_WND_RCV)
|
||||
93
game_server/net/packet.py
Normal file
93
game_server/net/packet.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from dataclasses import dataclass
|
||||
import struct
|
||||
|
||||
HEAD_MAGIC = 0x9D74C714
|
||||
TAIL_MAGIC = 0xD7A152C8
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetPacket:
|
||||
cmd_type: int
|
||||
head: bytes
|
||||
body: bytes
|
||||
|
||||
def to_message(self, m) -> "NetPacket":
|
||||
return m.parse(self.body)
|
||||
|
||||
@staticmethod
|
||||
def from_message(c, m) -> "NetPacket":
|
||||
return NetPacket(cmd_type=c, head=[], body=bytes(m))
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
# packet_length = 12 + len(self.head) + len(self.body) + 4
|
||||
b = bytearray()
|
||||
|
||||
b.extend(struct.pack(">I", HEAD_MAGIC))
|
||||
b.extend(struct.pack(">H", self.cmd_type))
|
||||
b.extend(struct.pack(">H", len(self.head)))
|
||||
b.extend(struct.pack(">I", len(self.body)))
|
||||
b.extend(self.head)
|
||||
b.extend(self.body)
|
||||
b.extend(struct.pack(">I", TAIL_MAGIC))
|
||||
|
||||
return bytes(b)
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(b: bytes) -> "NetPacket":
|
||||
if len(b) < 16:
|
||||
raise ValueError("len(b) < 16")
|
||||
|
||||
head_magic = struct.unpack_from(">I", b, 0)[0]
|
||||
|
||||
if head_magic != HEAD_MAGIC:
|
||||
raise ValueError("Invalid head magic")
|
||||
|
||||
cmd_type = struct.unpack_from(">H", b, 4)[0]
|
||||
head_length = struct.unpack_from(">H", b, 6)[0]
|
||||
body_length = struct.unpack_from(">I", b, 8)[0]
|
||||
|
||||
head_start = 12
|
||||
head_end = head_start + head_length
|
||||
|
||||
if head_end > len(b):
|
||||
raise ValueError("Head data > packet length")
|
||||
|
||||
head = b[head_start:head_end]
|
||||
|
||||
body_start = head_end
|
||||
body_end = body_start + body_length
|
||||
|
||||
if body_end + 4 > len(b):
|
||||
raise ValueError("Body data > packet length")
|
||||
|
||||
body = b[body_start:body_end]
|
||||
|
||||
tail_magic = struct.unpack_from(">I", b, body_end)[0]
|
||||
|
||||
if tail_magic != TAIL_MAGIC:
|
||||
raise ValueError("Invalid tail magic")
|
||||
|
||||
return NetPacket(cmd_type, head, body)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetOperation:
|
||||
head: int
|
||||
conv_id: int
|
||||
session_token: int
|
||||
data: int
|
||||
tail: int
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return struct.pack(
|
||||
">IIIII", self.head, self.conv_id, self.session_token, self.data, self.tail
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(b: bytes) -> "NetOperation":
|
||||
if len(b) != 20:
|
||||
raise ValueError("len(b) != 20")
|
||||
|
||||
head, conv_id, session_token, data, tail = struct.unpack(">IIIII", b)
|
||||
|
||||
return NetOperation(head, conv_id, session_token, data, tail)
|
||||
133
game_server/net/session.py
Normal file
133
game_server/net/session.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import betterproto
|
||||
from game_server.net.kcp import Kcp
|
||||
from game_server.net.packet import NetPacket
|
||||
from utils.logger import Info,Error,Warn
|
||||
from rail_proto import cmd
|
||||
from rail_proto import lib as protos
|
||||
from game_server.game.player.player_manager import PlayerManager
|
||||
from game_server.dummy import dummyprotolist
|
||||
import traceback
|
||||
|
||||
class PlayerSession:
|
||||
def __init__(self, transport, session_id, client_addr, db):
|
||||
self.transport = transport
|
||||
self.session_id = session_id
|
||||
self.client_addr = client_addr
|
||||
self.kcp = Kcp(session_id, self.send_output)
|
||||
self.kcp.set_nodelay(1, 5, 2, 0)
|
||||
self.is_destroyed = False
|
||||
self.db = db
|
||||
self.pending_notifies = list()
|
||||
self.player = PlayerManager()
|
||||
self.active = False
|
||||
self.last_received = asyncio.get_event_loop().time()
|
||||
|
||||
def update_last_received(self):
|
||||
self.last_received = asyncio.get_event_loop().time()
|
||||
|
||||
def is_timeout(self, timeout=15):
|
||||
return asyncio.get_event_loop().time() - self.last_received > timeout
|
||||
|
||||
def pending_notify(self, data: betterproto.Message, delay=0):
|
||||
"""
|
||||
This can be used to queue packet to be sent after response inside a handler
|
||||
"""
|
||||
self.pending_notifies.append((data, delay))
|
||||
|
||||
async def notify(self, data: betterproto.Message):
|
||||
msg_name = data.__class__.__name__
|
||||
cmd_id = getattr(cmd.CmdID, msg_name, None)
|
||||
if not cmd_id:
|
||||
Warn(f"Server tried to send notify with unsupported message: {msg_name}")
|
||||
return
|
||||
|
||||
response_packet = NetPacket.from_message(cmd_id, data)
|
||||
await self.send(response_packet)
|
||||
|
||||
def send_output(self, data):
|
||||
self.transport.sendto(data, self.client_addr)
|
||||
|
||||
async def consume(self, data):
|
||||
self.kcp.input(data)
|
||||
self.kcp.update(asyncio.get_running_loop().time())
|
||||
|
||||
while True:
|
||||
packet = self.kcp.recv()
|
||||
if packet is None:
|
||||
break
|
||||
await self.handle_packet(packet)
|
||||
|
||||
self.kcp.update(asyncio.get_running_loop().time())
|
||||
|
||||
async def handle_packet(self, packet):
|
||||
net_packet = NetPacket.from_bytes(packet)
|
||||
cmd_id = net_packet.cmd_type
|
||||
|
||||
request_name = cmd.get_key_by_value(cmd_id)
|
||||
if not request_name:
|
||||
Warn(
|
||||
f"Request doesn't have registered message_id: {cmd_id}"
|
||||
)
|
||||
return
|
||||
if request_name[:-5] in dummyprotolist:
|
||||
dummy_cmd_id = getattr(cmd.CmdID, f"{request_name[:-5]}ScRsp", None)
|
||||
dummy_response = NetPacket.from_message(dummy_cmd_id, b'')
|
||||
await self.send(dummy_response)
|
||||
return
|
||||
try:
|
||||
try:
|
||||
req: betterproto.Message = getattr(protos, request_name)()
|
||||
req.parse(net_packet.body)
|
||||
except Exception:
|
||||
req = betterproto.Message()
|
||||
|
||||
try:
|
||||
handle_result: betterproto.Message = await importlib.import_module(
|
||||
f"game_server.handlers.{request_name}"
|
||||
).handle(self, req)
|
||||
if not handle_result:
|
||||
return
|
||||
except ModuleNotFoundError:
|
||||
Error(f"Unhandled request {request_name}")
|
||||
return
|
||||
except Exception:
|
||||
Error(f"Handler {request_name} returns error.")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
Info(f"Received cmd: {request_name}({cmd_id})")
|
||||
|
||||
response_name = handle_result.__class__.__name__
|
||||
cmd_type = getattr(cmd.CmdID, response_name, None)
|
||||
if not cmd_type:
|
||||
Warn(
|
||||
f"Server tried to send response with unsupported message: {response_name}"
|
||||
)
|
||||
return
|
||||
response_packet = NetPacket.from_message(cmd_type, handle_result)
|
||||
await self.send(response_packet)
|
||||
|
||||
|
||||
asyncio.create_task(self.send_pending_notifies(
|
||||
self.pending_notifies.copy()
|
||||
))
|
||||
|
||||
self.pending_notifies.clear()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def send_pending_notifies(
|
||||
self,
|
||||
pending_notifies: list[tuple[betterproto.Message, int]]
|
||||
):
|
||||
for notify, delay in pending_notifies:
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
await self.notify(notify)
|
||||
|
||||
async def send(self, packet):
|
||||
self.kcp.send(packet.to_bytes())
|
||||
self.kcp.flush()
|
||||
self.kcp.update(asyncio.get_running_loop().time())
|
||||
Reference in New Issue
Block a user