diff --git a/core/webrtc/whep.go b/core/webrtc/whep.go new file mode 100644 index 0000000..87e07ac --- /dev/null +++ b/core/webrtc/whep.go @@ -0,0 +1,93 @@ +package webrtc + +import ( + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + + "github.com/pion/webrtc/v4" +) + +// WHEPHandler serves the WebRTC-HTTP Egress Protocol POST. +type WHEPHandler struct { + registry *Registry + factory *PeerFactory + config Config + + mu sync.Mutex + peers map[string]*Peer // resourceID → Peer + peersCount int64 // atomic, for cap check without lock +} + +// NewWHEPHandler constructs a handler with the given dependencies. +func NewWHEPHandler(r *Registry, f *PeerFactory, c Config) *WHEPHandler { + return &WHEPHandler{ + registry: r, + factory: f, + config: c, + peers: make(map[string]*Peer), + } +} + +// ServeHTTP handles POST /whep/{stream_id}. Other methods and paths return 405. +func (h *WHEPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.Header().Set("Allow", "POST") + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // Extract stream_id from path: /whep/{stream_id} + streamID := strings.TrimPrefix(r.URL.Path, "/whep/") + if streamID == "" || strings.Contains(streamID, "/") { + http.Error(w, "invalid stream id", http.StatusBadRequest) + return + } + + // Peer cap enforcement (happy path still respects the cap). + if atomic.LoadInt64(&h.peersCount) >= int64(h.config.MaxPeersTotal) { + http.Error(w, ErrPeerCapReached.Error(), http.StatusServiceUnavailable) + return + } + + handle, ok := h.registry.Lookup(streamID) + if !ok { + http.Error(w, ErrStreamNotFound.Error(), http.StatusNotFound) + return + } + src, ok := handle.(*Source) + if !ok { + http.Error(w, "registered source is not a *Source", http.StatusInternalServerError) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "read body: "+err.Error(), http.StatusBadRequest) + return + } + if len(body) == 0 { + http.Error(w, ErrInvalidSDP.Error(), http.StatusBadRequest) + return + } + + offer := webrtc.SessionDescription{Type: webrtc.SDPTypeOffer, SDP: string(body)} + + peer, err := h.factory.CreatePeer(r.Context(), src, offer) + if err != nil { + http.Error(w, "create peer: "+err.Error(), http.StatusInternalServerError) + return + } + + h.mu.Lock() + h.peers[peer.ResourceID()] = peer + h.mu.Unlock() + atomic.AddInt64(&h.peersCount, 1) + + w.Header().Set("Content-Type", "application/sdp") + w.Header().Set("Location", "/whep/"+streamID+"/"+peer.ResourceID()) + w.WriteHeader(http.StatusCreated) + _, _ = io.WriteString(w, peer.Answer().SDP) +} diff --git a/core/webrtc/whep_test.go b/core/webrtc/whep_test.go new file mode 100644 index 0000000..fe5d755 --- /dev/null +++ b/core/webrtc/whep_test.go @@ -0,0 +1,64 @@ +package webrtc + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/pion/webrtc/v4" +) + +func TestWHEP_POSTReturns201WithSDP(t *testing.T) { + // Set up a Source and register it. + src, _ := NewSource("streamA", 0) + defer src.Close() + src.Start() + + reg := NewRegistry() + _ = reg.Register("streamA", src) + + factory, _ := NewPeerFactory(DefaultConfig()) + + handler := NewWHEPHandler(reg, factory, DefaultConfig()) + + // Build an offer using a throwaway PC. + me := &webrtc.MediaEngine{} + _ = me.RegisterDefaultCodecs() + api := webrtc.NewAPI(webrtc.WithMediaEngine(me)) + pc, _ := api.NewPeerConnection(webrtc.Configuration{}) + defer pc.Close() + _, _ = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, + webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionRecvonly}) + _, _ = pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, + webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionRecvonly}) + offer, _ := pc.CreateOffer(nil) + + req := httptest.NewRequest(http.MethodPost, "/whep/streamA", + strings.NewReader(offer.SDP)) + req.Header.Set("Content-Type", "application/sdp") + // Give the handler generous ICE gathering time in tests. + ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) + defer cancel() + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + body, _ := io.ReadAll(rr.Result().Body) + t.Fatalf("status = %d, want 201. body=%s", rr.Code, string(body)) + } + if ct := rr.Header().Get("Content-Type"); ct != "application/sdp" { + t.Errorf("Content-Type = %q, want application/sdp", ct) + } + if loc := rr.Header().Get("Location"); !strings.HasPrefix(loc, "/whep/streamA/") { + t.Errorf("Location = %q, want /whep/streamA/", loc) + } + if !strings.Contains(rr.Body.String(), "v=0") { + t.Errorf("body does not look like SDP: %s", rr.Body.String()) + } +}