diff --git a/util/file.go b/util/file.go index e4314bc..fc7c18f 100644 --- a/util/file.go +++ b/util/file.go @@ -15,7 +15,9 @@ package util import ( + "io" "os" + "path/filepath" "strings" "github.com/b3log/wide/log" @@ -80,3 +82,67 @@ func (*myfile) IsDir(path string) bool { return fio.IsDir() } + +// CopyFile copies the source file to the dest file. +func (*myfile) CopyFile(source string, dest string) (err error) { + sourcefile, err := os.Open(source) + if err != nil { + return err + } + + defer sourcefile.Close() + + destfile, err := os.Create(dest) + if err != nil { + return err + } + + defer destfile.Close() + + _, err = io.Copy(destfile, sourcefile) + if err == nil { + sourceinfo, err := os.Stat(source) + if err != nil { + err = os.Chmod(dest, sourceinfo.Mode()) + } + } + + return nil +} + +// CopyDir copies the source directory to the dest directory. +func (*myfile) CopyDir(source string, dest string) (err error) { + sourceinfo, err := os.Stat(source) + if err != nil { + return err + } + + // create dest dir + err = os.MkdirAll(dest, sourceinfo.Mode()) + if err != nil { + return err + } + + directory, _ := os.Open(source) + objects, err := directory.Readdir(-1) + + for _, obj := range objects { + srcFilePath := filepath.Join(source, obj.Name()) + destFilePath := filepath.Join(dest, obj.Name()) + + if obj.IsDir() { + // create sub-directories - recursively + err = File.CopyDir(srcFilePath, destFilePath) + if err != nil { + fileLogger.Error(err) + } + } else { + err = File.CopyFile(srcFilePath, destFilePath) + if err != nil { + fileLogger.Error(err) + } + } + } + + return nil +} diff --git a/util/file_test.go b/util/file_test.go index 10b9d65..0b3d631 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -15,6 +15,8 @@ package util import ( + "os" + "path/filepath" "strconv" "testing" ) @@ -56,3 +58,35 @@ func TestIsDir(t *testing.T) { return } } + +func TestCopyDir(t *testing.T) { + home, _ := OS.Home() + + testDir := filepath.Join(home, "wide-test") + os.Mkdir(testDir, 0644) + + dest := filepath.Join(testDir, "util") + + err := File.CopyDir(".", dest) + if nil != err { + t.Error("Copy dir error: ", err) + + return + } +} + +func TestCopyFile(t *testing.T) { + home, _ := OS.Home() + + testDir := filepath.Join(home, "wide-test") + os.Mkdir(testDir, 0644) + + dest := filepath.Join(testDir, "file.go") + + err := File.CopyFile("./file.go", dest) + if nil != err { + t.Error("Copy file error: ", err) + + return + } +} diff --git a/util/zip_test.go b/util/zip_test.go index 8c02752..e7382e3 100644 --- a/util/zip_test.go +++ b/util/zip_test.go @@ -16,10 +16,13 @@ package util import ( "os" + "path/filepath" "testing" ) -var packageName = "test_zip" +var home, _ = OS.Home() +var testDir = filepath.Join(home, "wide-test") +var packageName = filepath.Join(testDir, "test_zip") func TestCreate(t *testing.T) { zipFile, err := Zip.Create(packageName + ".zip") @@ -54,11 +57,12 @@ func TestUnzip(t *testing.T) { } func TestMain(m *testing.M) { + os.Mkdir(testDir, 0644) + retCode := m.Run() // clean test data - os.RemoveAll(packageName + ".zip") - os.RemoveAll(packageName) + os.RemoveAll(testDir) os.Exit(retCode) }