123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- // Copyright (c) 2016 Arista Networks, Inc.
- // Use of this source code is governed by the Apache License 2.0
- // that can be found in the COPYING file.
- package netns
- import (
- "io/ioutil"
- "os"
- "path/filepath"
- "testing"
- )
- type mockHandle int
- func (mh mockHandle) close() error {
- return nil
- }
- func (mh mockHandle) fd() int {
- return 0
- }
- func TestNetNs(t *testing.T) {
- setNsCallCount := 0
- // Mock getNs
- oldGetNs := getNs
- getNs = func(nsName string) (handle, error) {
- return mockHandle(1), nil
- }
- defer func() {
- getNs = oldGetNs
- }()
- // Mock setNs
- oldSetNs := setNs
- setNs = func(fd handle) error {
- setNsCallCount++
- return nil
- }
- defer func() {
- setNs = oldSetNs
- }()
- // Create a tempfile so we can use its name for the network namespace
- tmpfile, err := ioutil.TempFile("", "")
- if err != nil {
- t.Fatalf("Failed to create a temp file: %s", err)
- }
- defer os.Remove(tmpfile.Name())
- nsName := filepath.Base(tmpfile.Name())
- // Map of network namespace name to the number of times it should call setNs
- cases := map[string]int{"": 0, "default": 2, nsName: 2}
- for name, callCount := range cases {
- var cbResult string
- err = Do(name, func() error {
- cbResult = "Hello" + name
- return nil
- })
- if err != nil {
- t.Fatalf("Error calling function in different network namespace: %s", err)
- }
- if cbResult != "Hello"+name {
- t.Fatalf("Failed to call the callback function")
- }
- if setNsCallCount != callCount {
- t.Fatalf("setNs should have been called %d times for %s, but was called %d times",
- callCount, name, setNsCallCount)
- }
- setNsCallCount = 0
- }
- }
|