update ImportDB and enhancement

This commit is contained in:
Hamidreza Ghavami 2023-05-06 04:47:57 +04:30
parent 058ab5f901
commit 83c853ffb6
2 changed files with 46 additions and 18 deletions

View File

@ -154,13 +154,15 @@ func (a *ServerController) importDB(c *gin.Context) {
defer file.Close() defer file.Close()
// Always restart Xray before return // Always restart Xray before return
defer a.serverService.RestartXrayService() defer a.serverService.RestartXrayService()
defer func() {
a.lastGetStatusTime = time.Now()
}()
// Import it // Import it
err = a.serverService.ImportDB(file) err = a.serverService.ImportDB(file)
if err != nil { if err != nil {
jsonMsg(c, "", err) jsonMsg(c, "", err)
return return
} }
a.lastGetStatusTime = time.Now()
jsonObj(c, "Import DB", nil) jsonObj(c, "Import DB", nil)
} }

View File

@ -409,23 +409,33 @@ func (s *ServerService) ImportDB(file multipart.File) error {
return common.NewError("Invalid db file format") return common.NewError("Invalid db file format")
} }
// Save the file as temporary file
tempPath := fmt.Sprintf("%s.temp", config.GetDBPath())
// remove temp file before return
defer os.Remove(tempPath)
tempFile, err := os.Create(tempPath)
if err != nil {
return common.NewErrorf("Error creating temporary db file: %v", err)
}
defer tempFile.Close()
// Reset the file reader to the beginning // Reset the file reader to the beginning
_, err = file.Seek(0, 0) _, err = file.Seek(0, 0)
if err != nil { if err != nil {
return common.NewErrorf("Error resetting file reader: %v", err) return common.NewErrorf("Error resetting file reader: %v", err)
} }
// Save temp file // Save the file as temporary file
tempPath := fmt.Sprintf("%s.temp", config.GetDBPath())
// Remove the existing fallback file (if any) before creating one
_, err = os.Stat(tempPath)
if err == nil {
errRemove := os.Remove(tempPath)
if errRemove != nil {
return common.NewErrorf("Error removing existing temporary db file: %v", errRemove)
}
}
// Create the temporary file
tempFile, err := os.Create(tempPath)
if err != nil {
return common.NewErrorf("Error creating temporary db file: %v", err)
}
defer tempFile.Close()
// Remove temp file before returning
defer os.Remove(tempPath)
// Save uploaded file to temporary file
_, err = io.Copy(tempFile, file) _, err = io.Copy(tempFile, file)
if err != nil { if err != nil {
return common.NewErrorf("Error saving db: %v", err) return common.NewErrorf("Error saving db: %v", err)
@ -440,26 +450,42 @@ func (s *ServerService) ImportDB(file multipart.File) error {
// Stop Xray // Stop Xray
s.StopXrayService() s.StopXrayService()
// Backup db for fallback // Backup the current database for fallback
fallbackPath := fmt.Sprintf("%s.backup", config.GetDBPath()) fallbackPath := fmt.Sprintf("%s.backup", config.GetDBPath())
// remove fallback file before return // Remove the existing fallback file (if any)
defer os.Remove(fallbackPath) _, err = os.Stat(fallbackPath)
if err == nil {
errRemove := os.Remove(fallbackPath)
if errRemove != nil {
return common.NewErrorf("Error removing existing fallback db file: %v", errRemove)
}
}
// Move the current database to the fallback location
err = os.Rename(config.GetDBPath(), fallbackPath) err = os.Rename(config.GetDBPath(), fallbackPath)
if err != nil { if err != nil {
return common.NewErrorf("Error backup temporary db file: %v", err) return common.NewErrorf("Error backing up temporary db file: %v", err)
} }
// Remove the temporary file before returning
defer os.Remove(fallbackPath)
// Move temp to DB path // Move temp to DB path
err = os.Rename(tempPath, config.GetDBPath()) err = os.Rename(tempPath, config.GetDBPath())
if err != nil { if err != nil {
os.Rename(fallbackPath, config.GetDBPath()) errRename := os.Rename(fallbackPath, config.GetDBPath())
if errRename != nil {
return common.NewErrorf("Error moving db file and restoring fallback: %v", errRename)
}
return common.NewErrorf("Error moving db file: %v", err) return common.NewErrorf("Error moving db file: %v", err)
} }
// Migrate DB // Migrate DB
err = database.InitDB(config.GetDBPath()) err = database.InitDB(config.GetDBPath())
if err != nil { if err != nil {
os.Rename(fallbackPath, config.GetDBPath()) errRename := os.Rename(fallbackPath, config.GetDBPath())
if errRename != nil {
return common.NewErrorf("Error migrating db and restoring fallback: %v", errRename)
}
return common.NewErrorf("Error migrating db: %v", err) return common.NewErrorf("Error migrating db: %v", err)
} }
s.inboundService.MigrateDB() s.inboundService.MigrateDB()